UDNN_RNN_TANH; __shared__ T_MATH smemi[GROUP_BATCH_SIZE > 1 ? MINIBATCH : 2][BLOCK_DI_LENGTH]; __shared__ T_MATH smemGates1[2][WARPS_PER_BLOCK_K * THREADS_PER_WARP_K * K_PER_THREAD]; __shared__ T_MATH smemGates2[GROUP_BATCH_SIZE > 1 ? MINIBATCH : 2][BLOCK_DI_LENGTH * NUM_LINS]; __shared__ T_MATH smemcl[BASIC_RNN ? 1 : MINIBATCH][BASIC_RNN ? 1 : BLOCK_DI_LENGTH]; __shared__ T_MATH smemcr[BASIC_RNN ? 1 : MINIBATCH][BASIC_RNN ? 1 : BLOCK_DI_LENGTH]; __shared__ T_MATH smemdcx[BASIC_RNN ? 1 : MINIBATCH][BASIC_RNN ? 1 : BLOCK_DI_LENGTH]; int warpIdBlock = (threadIdx.x) / 32; int warpIdGlobal = blockIdx.x * WARPS_PER_BLOCK_M * WARPS_PER_BLOCK_K + warpIdBlock; int laneId = (threadIdx.x) % 32; int rowStartBlock; int rowStart; int colStart; rowStartBlock = ((warpIdBlock / WARPS_PER_BLOCK_K) * THREADS_PER_WARP_M + (laneId % THREADS_PER_WARP_M)) * M_PER_THREAD; rowStart = ((warpIdGlobal / WARPS_PER_BLOCK_K) * THREADS_PER_WARP_M + (laneId % THREADS_PER_WARP_M)) * M_PER_THREAD; colStart = (laneId / THREADS_PER_WARP_M) * INNER_UNROLL; colStart += (warpIdBlock % WARPS_PER_BLOCK_K) * (VEC_LENGTH / WARPS_PER_BLOCK_K); const int rowStride = 1; T_MATH T_reg[K_PER_THREAD][M_PER_THREAD]; RNN_persist_load_rMat( T_reg, T, rowStart, colStart, rowStride); if (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int stepOffset = DIRECTION == 1 ? 0 : ((seqLength - 1) * TOTAL_MINIBATCH) * HIDDEN_SIZE; for (int batch = 0; batch < MINIBATCH; batch++) { int batchOffset = batch * BATCH_STRIDE; #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH && i + BLOCK_DI_LENGTH * blockIdx.x < HIDDEN_SIZE) { if (RNN_MODE == CUDNN_LSTM) { smemdcx[batch][i] = cuGet(c_data[i + blockOffset + batchOffset + stepOffset]); } else { smemdcx[batch][i] = cuGet(tmp_dh[i + blockOffset + batchOffset]); } } } } } for (int i_ = 0; i_ < WARPS_PER_BLOCK_K * THREADS_PER_WARP_K * K_PER_THREAD; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < WARPS_PER_BLOCK_K * THREADS_PER_WARP_K * K_PER_THREAD) { smemGates1[0][i] = cuGet(0); smemGates1[1][i] = cuGet(0); } } __syncthreads(); int absStep = seqLength - 1; int step; if (DIRECTION == 1) step = 0; else step = seqLength - 1; T_MATH readBufferGates1[(VEC_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferGates2[(BLOCK_DI_LENGTH * NUM_LINS + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferi[(BLOCK_DI_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferc[(BLOCK_DI_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; #pragma unroll for (int i = 0; i < (VEC_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferGates1[i] = 0; } #pragma unroll for (int i = 0; i < (BLOCK_DI_LENGTH * NUM_LINS + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferGates2[i] = 0; } #pragma unroll for (int i = 0; i < (BLOCK_DI_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferi[i] = 0; if (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) readBufferc[i] = 0; } #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH / NUM_MATS && i >= HIDDEN_SIZE * NUM_MATS) break; int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; readBufferGates1[i_ / (THREADS_PER_BLOCK)] = cuGet(storedResults[(step * TOTAL_MINIBATCH * HIDDEN_SIZE) * NUM_LINS + typeOffset + i]); } #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= VEC_LENGTH) break; smemGates1[(absStep * MINIBATCH) % 2][i] = readBufferGates1[i_ / (THREADS_PER_BLOCK)]; } if (absStep > 0) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH * NUM_LINS; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH * NUM_LINS) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int stepOffset = ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE) * NUM_LINS; int j = (i % NUM_LINS); if ((i / NUM_LINS) + j * HIDDEN_SIZE + blockOffset < HIDDEN_SIZE * NUM_LINS) { readBufferGates2[i_ / (THREADS_PER_BLOCK)] = cuGet(tmp_gates[(i / NUM_LINS) + j * HIDDEN_SIZE + blockOffset + stepOffset]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int stepOffset = (step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * OUTPUT_STRIDE; if (i + blockOffset < HIDDEN_SIZE) { readBufferi[i_ / (THREADS_PER_BLOCK)] = cuGet(dy[i + blockOffset + stepOffset]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH * NUM_LINS; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH * NUM_LINS) { smemGates2[getSmemSectionGates2(0, absStep)][i] = readBufferGates2[i_ / (THREADS_PER_BLOCK)]; } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemi[getSmemSectionI(0, absStep)][i] = readBufferi[i_ / (THREADS_PER_BLOCK)]; } } if (RNN_MODE == CUDNN_LSTM) { for (int batch = 0; batch < MINIBATCH; batch++) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int batchOffset = batch * BATCH_STRIDE; int stepOffset = (step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE; if (i + blockOffset < HIDDEN_SIZE) { readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(c_data[i + blockOffset + stepOffset + batchOffset]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcr[batch][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int stepOffset = ((step - (DIRECTION == 1 ? -2 : 2)) * TOTAL_MINIBATCH) * HIDDEN_SIZE; if (i + blockOffset < HIDDEN_SIZE) { if (absStep > 1) { readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(c_data[i + blockOffset + stepOffset]); } else { readBufferc[i_ / (THREADS_PER_BLOCK)] = cx == NULL ? cuGet(0) : cuGet(cx[i + blockOffset]); } } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcl[0][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; if (i + blockOffset < HIDDEN_SIZE) { if (absStep > 1) { int stepOffset = ((step - (DIRECTION == 1 ? -2 : 2)) * TOTAL_MINIBATCH) * OUTPUT_STRIDE; readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(h_data[i + blockOffset + stepOffset]); } else { readBufferc[i_ / (THREADS_PER_BLOCK)] = cx == NULL ? cuGet(0) : cuGet(cx[i + blockOffset]); } } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcl[0][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } } __syncthreads(); for (int absStep = seqLength - 1; absStep >= 0; absStep--) { if (DIRECTION == 1) step = seqLength - absStep - 1; else step = absStep; for (int batch = 0; batch < MINIBATCH; batch++) { int final = absStep == 0 && batch == MINIBATCH - 1; if (!final) { #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH / NUM_MATS && i >= HIDDEN_SIZE * NUM_MATS) break; int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; int index; if (DIRECTION == 1) { index = (step + 1) * TOTAL_MINIBATCH * HIDDEN_SIZE; if (batch != MINIBATCH - 1) { index += (batch + 1) * BATCH_STRIDE - TOTAL_MINIBATCH * HIDDEN_SIZE; } } else { index = (step - 1) * TOTAL_MINIBATCH * HIDDEN_SIZE; if (batch != MINIBATCH - 1) { index += (batch + 1) * BATCH_STRIDE + TOTAL_MINIBATCH * HIDDEN_SIZE; } } index *= NUM_LINS; index += typeOffset + i; readBufferGates1[i_ / (THREADS_PER_BLOCK)] = cuGet(storedResults[index]); } } if (absStep > 0) { if (!(absStep == 1 && batch == MINIBATCH - 1)) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH * NUM_LINS; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH * NUM_LINS) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int batchOffset = (batch + 1) * BATCH_STRIDE * NUM_LINS; int stepOffset = ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE) * NUM_LINS; if (batch == MINIBATCH - 1) { if (DIRECTION == 0) batchOffset = -TOTAL_MINIBATCH * HIDDEN_SIZE * NUM_LINS; else batchOffset = TOTAL_MINIBATCH * HIDDEN_SIZE * NUM_LINS; } int j = (i % NUM_LINS); if ((i / NUM_LINS) + j * HIDDEN_SIZE + blockOffset < HIDDEN_SIZE * NUM_LINS) { int index = (i / NUM_LINS) + j * HIDDEN_SIZE + blockOffset + stepOffset + batchOffset; readBufferGates2[i_ / (THREADS_PER_BLOCK)] = cuGet(tmp_gates[index]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int batchOffset = (batch + 1) * BATCH_OUTPUT_STRIDE; int stepOffset = (step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * OUTPUT_STRIDE; if (batch == MINIBATCH - 1) { if (DIRECTION == 0) batchOffset = -TOTAL_MINIBATCH * OUTPUT_STRIDE; else batchOffset = TOTAL_MINIBATCH * OUTPUT_STRIDE; } if (i + blockOffset < HIDDEN_SIZE) { readBufferi[i_ / (THREADS_PER_BLOCK)] = cuGet(dy[i + blockOffset + stepOffset + batchOffset]); } } } if (RNN_MODE == CUDNN_LSTM) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int batchOffset = (batch + 1) * BATCH_STRIDE; int stepOffset = ((step - (DIRECTION == 1 ? -2 : 2)) * TOTAL_MINIBATCH) * HIDDEN_SIZE; if (i + blockOffset < HIDDEN_SIZE) { if (absStep <= 1 || (absStep == 2 && batch == MINIBATCH - 1)) { if (batch == MINIBATCH - 1) { batchOffset = 0; } readBufferc[i_ / (THREADS_PER_BLOCK)] = cx == NULL ? cuGet(0) : cuGet(cx[i + blockOffset + batchOffset]); } else { if (batch == MINIBATCH - 1) { if (DIRECTION == 0) batchOffset = -TOTAL_MINIBATCH * HIDDEN_SIZE; else batchOffset = TOTAL_MINIBATCH * HIDDEN_SIZE; } readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(c_data[i + blockOffset + batchOffset + stepOffset]); } } } } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; if (i + blockOffset < HIDDEN_SIZE) { if (absStep <= 1 || (absStep == 2 && batch == MINIBATCH - 1)) { int batchOffset = (batch + 1) * BATCH_STRIDE; if (batch == MINIBATCH - 1) { batchOffset = 0; } readBufferc[i_ / (THREADS_PER_BLOCK)] = cx == NULL ? cuGet(0) : cuGet(cx[i + blockOffset + batchOffset]); } else { int stepOffset = ((step - (DIRECTION == 1 ? -2 : 2)) * TOTAL_MINIBATCH) * OUTPUT_STRIDE; int batchOffset = (batch + 1) * BATCH_OUTPUT_STRIDE; if (batch == MINIBATCH - 1) { if (DIRECTION == 0) batchOffset = -TOTAL_MINIBATCH * OUTPUT_STRIDE; else batchOffset = TOTAL_MINIBATCH * OUTPUT_STRIDE; } readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(h_data[i + blockOffset + batchOffset + stepOffset]); } } } } } } } T_MATH accumulator[M_PER_THREAD][INNER_UNROLL]; RNN_persist_GEMM( T_reg, accumulator, smemGates1[((batch + absStep * MINIBATCH) % 2)], rowStartBlock, colStart); if (seqLength > 1 && (absStep < seqLength - 1 || batch == MINIBATCH - 1)) { int reloadGates = false; #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { if (isNegativeZero(readBufferGates1[i_ / THREADS_PER_BLOCK])) { reloadGates = true; } } while (reloadGates) { #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH / NUM_MATS && i >= HIDDEN_SIZE * NUM_MATS) break; int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; int index; if (DIRECTION == 1) { index = (step + 1) * TOTAL_MINIBATCH * HIDDEN_SIZE; if (batch != MINIBATCH - 1) { index += (batch + 1) * BATCH_STRIDE - TOTAL_MINIBATCH * HIDDEN_SIZE; } } else { index = (step - 1) * TOTAL_MINIBATCH * HIDDEN_SIZE; if (batch != MINIBATCH - 1) { index += (batch + 1) * BATCH_STRIDE + TOTAL_MINIBATCH * HIDDEN_SIZE; } } index *= NUM_LINS; index += typeOffset + i; readBufferGates1[i_ / THREADS_PER_BLOCK] = cuGet(loadVolatile(storedResults, index)); } reloadGates = false; #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { if (isNegativeZero(readBufferGates1[i_ / THREADS_PER_BLOCK])) { reloadGates = true; break; } } } } #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= VEC_LENGTH) break; smemGates1[1 - ((batch + absStep * MINIBATCH) % 2)][i] = readBufferGates1[i_ / (THREADS_PER_BLOCK)]; } if (absStep > 0) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH * NUM_LINS; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH * NUM_LINS) { smemGates2[getSmemSectionGates2(batch + 1, absStep)][i] = readBufferGates2[i_ / (THREADS_PER_BLOCK)]; } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemi[getSmemSectionI(batch + 1, absStep)][i] = readBufferi[i_ / (THREADS_PER_BLOCK)]; } } if (RNN_MODE == CUDNN_LSTM) { if (absStep != seqLength - 1 || batch == MINIBATCH - 1) { for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcr[batch + 1 == MINIBATCH ? 0 : (batch + 1)][i] = smemcl[batch + 1 == MINIBATCH ? 0 : (batch + 1)][i]; } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcl[batch + 1 == MINIBATCH ? 0 : (batch + 1)][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcl[batch + 1 == MINIBATCH ? 0 : (batch + 1)][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } } if (absStep > 0) { #if (defined(GROUP_BATCH_SIZE) && GROUP_BATCH_SIZE > 1) __shared__ T_MATH smemBatchGroup[GROUP_BATCH_SIZE - 1][BLOCK_DI_LENGTH][M_PER_THREAD]; if (batch % GROUP_BATCH_SIZE == GROUP_BATCH_SIZE - 1 && colStart < GROUP_BATCH_SIZE) { int writeBatch = batch - colStart; if (colStart != 0) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { accumulator[j][0] = smemBatchGroup[(colStart - 1)][rowStartBlock][j]; } } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { accumulator[j][0] += smemi[getSmemSectionI(writeBatch, absStep)][rowStartBlock + j]; } if (RNN_MODE == CUDNN_LSTM) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE) * 4; T_MATH linear_in_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)][(rowStartBlock + j) * NUM_MATS + 0]; T_MATH linear_forget_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)][(rowStartBlock + j) * NUM_MATS + 1]; T_MATH linear_in_gate2 = smemGates2[getSmemSectionGates2(writeBatch, absStep)][(rowStartBlock + j) * NUM_MATS + 2]; T_MATH linear_out_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)][(rowStartBlock + j) * NUM_MATS + 3]; T_MATH deltaY = accumulator[j][0]; T_MATH _cy = cuGet(smemcr[writeBatch][rowStartBlock + j]); T_MATH dclip_r = dclip(_cy, clipopt, nanopt, lclip, rclip); T_MATH dclip_l = dclip(smemcl[writeBatch][rowStartBlock + j], clipopt, nanopt, lclip, rclip); T_MATH deltaC = deltaY * (linear_out_gate)*dtanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dclip_r + smemdcx[writeBatch][rowStartBlock + j]; T_MATH out0 = deltaC * (linear_in_gate2)*dsigmoid_2(linear_in_gate); T_MATH out1 = deltaC * clip(smemcl[writeBatch][rowStartBlock + j], clipopt, nanopt, lclip, rclip) * dsigmoid_2(linear_forget_gate); T_MATH out2 = deltaC * (linear_in_gate)*dtanh_2(linear_in_gate2); T_MATH out3 = deltaY * _tanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dsigmoid_2(linear_out_gate); storedResults[linGatesBaseIndex + 0 * HIDDEN_SIZE] = getSafeOutput(out0); storedResults[linGatesBaseIndex + 1 * HIDDEN_SIZE] = getSafeOutput(out1); storedResults[linGatesBaseIndex + 2 * HIDDEN_SIZE] = getSafeOutput(out2); storedResults[linGatesBaseIndex + 3 * HIDDEN_SIZE] = getSafeOutput(out3); smemdcx[writeBatch][rowStartBlock + j] = deltaC * (linear_forget_gate)*dclip_l; } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE) * 6; T_MATH reset_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)][(rowStartBlock + j) * NUM_LINS + 0]; T_MATH update_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)][(rowStartBlock + j) * NUM_LINS + 1]; T_MATH new_gate_i = smemGates2[getSmemSectionGates2(writeBatch, absStep)][(rowStartBlock + j) * NUM_LINS + 2]; T_MATH new_gate_h = smemGates2[getSmemSectionGates2(writeBatch, absStep)][(rowStartBlock + j) * NUM_LINS + 3]; T_MATH deltaY = accumulator[j][0] + smemdcx[writeBatch][rowStartBlock + j]; smemdcx[writeBatch][rowStartBlock + j] = deltaY * (update_gate); T_MATH lin_output_new_gate = (reset_gate)*new_gate_h + new_gate_i; T_MATH output_new_gate = _tanh(lin_output_new_gate); T_MATH delta_not = deltaY * output_new_gate; T_MATH hl = cuGet(smemcl[writeBatch][rowStartBlock + j]); T_MATH delta_update_gate = (deltaY * hl - delta_not) * dsigmoid_2(update_gate); T_MATH dh_ = deltaY * (cuGet(1) - (update_gate)); T_MATH dtanh_new_gate = dh_ * dtanh(lin_output_new_gate); T_MATH delta_new_gate_i = dtanh_new_gate; T_MATH delta_new_gate_h = dtanh_new_gate * (reset_gate); T_MATH delta_reset_gate = dtanh_new_gate * new_gate_h * dsigmoid_2(reset_gate); storedResults[linGatesBaseIndex + 0 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 1 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 2 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_i); storedResults[linGatesBaseIndex + 3 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 4 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 5 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_h); } } else { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { if (RNN_MODE == CUDNN_RNN_RELU) { accumulator[j][0] = smemGates2[getSmemSectionGates2(writeBatch, absStep)][rowStartBlock + j] < cuGet(0) ? cuGet(0) : accumulator[j][0]; } else if (RNN_MODE == CUDNN_RNN_TANH) { accumulator[j][0] *= dtanh_2(smemGates2[getSmemSectionGates2(writeBatch, absStep)][rowStartBlock + j]); } } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int index = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE); storedResults[index] = getSafeOutput(accumulator[j][0]); } } } else if (colStart == 0) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { smemBatchGroup[GROUP_BATCH_SIZE - 2 - (batch % GROUP_BATCH_SIZE)][rowStartBlock][j] = accumulator[j][0]; } } #else if (colStart == 0) { if (RNN_MODE == CUDNN_LSTM) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { accumulator[j][0] += smemi[getSmemSectionI(batch, absStep)][rowStartBlock + j]; } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 4; T_MATH linear_in_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlock + j) * NUM_MATS + 0]; T_MATH linear_forget_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlock + j) * NUM_MATS + 1]; T_MATH linear_in_gate2 = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlock + j) * NUM_MATS + 2]; T_MATH linear_out_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlock + j) * NUM_MATS + 3]; T_MATH deltaY = accumulator[j][0]; T_MATH _cy = cuGet(smemcr[batch][rowStartBlock + j]); T_MATH dclip_r = dclip(_cy, clipopt, nanopt, lclip, rclip); T_MATH dclip_l = dclip(smemcl[batch][rowStartBlock + j], clipopt, nanopt, lclip, rclip); T_MATH deltaC = deltaY * (linear_out_gate)*dtanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dclip_r + smemdcx[batch][rowStartBlock + j]; T_MATH out0 = deltaC * (linear_in_gate2)*dsigmoid_2(linear_in_gate); T_MATH out1 = deltaC * clip(smemcl[batch][rowStartBlock + j], clipopt, nanopt, lclip, rclip) * dsigmoid_2(linear_forget_gate); T_MATH out2 = deltaC * (linear_in_gate)*dtanh_2(linear_in_gate2); T_MATH out3 = deltaY * _tanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dsigmoid_2(linear_out_gate); storedResults[linGatesBaseIndex + 0 * HIDDEN_SIZE] = getSafeOutput(out0); storedResults[linGatesBaseIndex + 1 * HIDDEN_SIZE] = getSafeOutput(out1); storedResults[linGatesBaseIndex + 2 * HIDDEN_SIZE] = getSafeOutput(out2); storedResults[linGatesBaseIndex + 3 * HIDDEN_SIZE] = getSafeOutput(out3); smemdcx[batch][rowStartBlock + j] = deltaC * (linear_forget_gate)*dclip_l; } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { accumulator[j][0] += smemi[getSmemSectionI(batch, absStep)][rowStartBlock + j]; } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 6; T_MATH reset_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlock + j) * NUM_LINS + 0]; T_MATH update_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlock + j) * NUM_LINS + 1]; T_MATH new_gate_i = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlock + j) * NUM_LINS + 2]; T_MATH new_gate_h = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlock + j) * NUM_LINS + 3]; T_MATH deltaY = accumulator[j][0] + smemdcx[batch][rowStartBlock + j]; smemdcx[batch][rowStartBlock + j] = deltaY * (update_gate); T_MATH lin_output_new_gate = (reset_gate)*new_gate_h + new_gate_i; T_MATH output_new_gate = _tanh(lin_output_new_gate); T_MATH delta_not = deltaY * output_new_gate; T_MATH hl = cuGet(smemcl[batch][rowStartBlock + j]); T_MATH delta_update_gate = (deltaY * hl - delta_not) * dsigmoid_2(update_gate); T_MATH dh_ = deltaY * (cuGet(1) - (update_gate)); T_MATH dtanh_new_gate = dh_ * dtanh(lin_output_new_gate); T_MATH delta_new_gate_i = dtanh_new_gate; T_MATH delta_new_gate_h = dtanh_new_gate * (reset_gate); T_MATH delta_reset_gate = dtanh_new_gate * new_gate_h * dsigmoid_2(reset_gate); storedResults[linGatesBaseIndex + 0 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 1 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 2 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_i); storedResults[linGatesBaseIndex + 3 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 4 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 5 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_h); } } else { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { accumulator[j][0] += smemi[getSmemSectionI(batch, absStep)][rowStartBlock + j]; } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { if (RNN_MODE == CUDNN_RNN_RELU) { accumulator[j][0] = smemGates2[getSmemSectionGates2(batch, absStep)][rowStartBlock + j] < cuGet(0) ? cuGet(0) : accumulator[j][0]; } else if (RNN_MODE == CUDNN_RNN_TANH) { accumulator[j][0] *= dtanh_2(smemGates2[getSmemSectionGates2(batch, absStep)][rowStartBlock + j]); } } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int index = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE); storedResults[index] = getSafeOutput(accumulator[j][0]); } } } #endif } else { if (colStart == 0 && dhx != NULL) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; if (RNN_MODE == CUDNN_GRU) { dhx[row + batch * BATCH_STRIDE] = cuGet(accumulator[j][0] + smemdcx[batch][rowStartBlock + j]); } else { dhx[row + batch * BATCH_STRIDE] = cuGet(accumulator[j][0]); } } } if (RNN_MODE == CUDNN_LSTM) { if (seqLength > 1) { __syncthreads(); if (colStart == 0 && dcx != NULL) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; ; if (row >= HIDDEN_SIZE) break; dcx[row + batch * BATCH_STRIDE] = cuGet(smemdcx[batch][rowStartBlock + j]); } } } } } __syncthreads(); } } }