* WARPS_PER_BLOCK_M * WARPS_PER_BLOCK_K + warpIdBlock; rowStartBlock = ((warpIdBlock / WARPS_PER_BLOCK_K) * THREADS_PER_WARP_M + (laneId % THREADS_PER_WARP_M)) * M_PER_THREAD_MAT; // rowStart = ((warpIdGlobal / WARPS_PER_BLOCK_K) * THREADS_PER_WARP_M + (laneId % THREADS_PER_WARP_M)) * // M_PER_THREAD_MAT; // rowStart = ((warpIdGlobal * THREADS_PER_WARP_M) + (laneId % THREADS_PER_WARP_M)) * M_PER_THREAD_MAT; // rowStart -= k_block * (gridDim.x / MATRIX_SPLIT_K_FACTOR) * WARPS_PER_BLOCK_M * WARPS_PER_BLOCK_K * // THREADS_PER_WARP_M * M_PER_THREAD_MAT; rowStart = ((warpIdM * THREADS_PER_WARP_M) + (laneId % THREADS_PER_WARP_M)) * M_PER_THREAD_MAT; ; rowStartBlockAct = threadIdx.x * M_PER_THREAD_ACT; rowStartAct = blockIdx.x * BLOCK_DI_LENGTH_ACT + threadIdx.x * M_PER_THREAD_ACT; colStart = (laneId / THREADS_PER_WARP_M) * INNER_UNROLL; T_MATH rMat_reg[K_PER_THREAD_MAT][M_PER_THREAD_MAT]; RNN_persist_load_rMat(rMat_reg, rMat + k_block * (HIDDEN_SIZE * HIDDEN_SIZE * NUM_MATS) / MATRIX_SPLIT_K_FACTOR, rowStart, colStart, 1); if (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) { int blockOffset = BLOCK_DI_LENGTH_ACT * blockIdx.x; int stepOffset = DIRECTION == 1 ? 0 : (seqLength - 1) * TOTAL_MINIBATCH * HIDDEN_SIZE; for (int batch = 0; batch < MINIBATCH; batch++) { int batchOffset = batch * BATCH_STRIDE; #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT && i + blockOffset < HIDDEN_SIZE) { if (RNN_MODE == CUDNN_LSTM) { smemdcx[batch][i] = cuGet(c_data[i + blockOffset + batchOffset + stepOffset]); } else { smemdcx[batch][i] = cuGet(tmp_dh[i + blockOffset + batchOffset]); } } } } } for (int i_ = 0; i_ < WARPS_PER_BLOCK_K * THREADS_PER_WARP_K * K_PER_THREAD_MAT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < WARPS_PER_BLOCK_K * THREADS_PER_WARP_K * K_PER_THREAD_MAT) { smemGates1[0][i] = cuGet(0); smemGates1[1][i] = cuGet(0); } } int absStep = seqLength - 1; int step; if (DIRECTION == 1) step = 0; else step = seqLength - 1; T_MATH readBufferGates1[(VEC_LENGTH_MAT + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferGates2[(BLOCK_DI_LENGTH_ACT * NUM_LINS + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferi[(BLOCK_DI_LENGTH_ACT + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferc[(BLOCK_DI_LENGTH_ACT + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferTmpResults[(BLOCK_DI_LENGTH_ACT * MATRIX_SPLIT_K_FACTOR + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; #pragma unroll for (int i = 0; i < (VEC_LENGTH_MAT + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferGates1[i] = 0; } #pragma unroll for (int i = 0; i < (BLOCK_DI_LENGTH_ACT * NUM_LINS + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferGates2[i] = 0; } #pragma unroll for (int i = 0; i < (BLOCK_DI_LENGTH_ACT + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferi[i] = 0; if (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) readBufferc[i] = 0; } #pragma unroll for (int i = 0; i < (BLOCK_DI_LENGTH_ACT * MATRIX_SPLIT_K_FACTOR + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferTmpResults[i] = 0; } #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK * NUM_MATS >= VEC_LENGTH_MAT && i >= (HIDDEN_SIZE * NUM_MATS) / MATRIX_SPLIT_K_FACTOR) break; int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; int index = step * TOTAL_MINIBATCH * HIDDEN_SIZE; index *= NUM_LINS; index += k_block * NUM_MATS * HIDDEN_SIZE / MATRIX_SPLIT_K_FACTOR; index += typeOffset + i; readBufferGates1[i_ / (THREADS_PER_BLOCK)] = cuGet(storedResults[index]); } #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH_MAT && i >= VEC_LENGTH_MAT) break; smemGates1[1 - (absStep * MINIBATCH) % 2][i] = readBufferGates1[i_ / (THREADS_PER_BLOCK)]; } if (absStep > 0) { if (RNN_MODE == CUDNN_LSTM) { for (int batch = 0; batch < MINIBATCH; batch++) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { int blockOffset = BLOCK_DI_LENGTH_ACT * blockIdx.x; int batchOffset = batch * BATCH_STRIDE; int stepOffset = (step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE; if (i + blockOffset < HIDDEN_SIZE) { readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(c_data[i + blockOffset + stepOffset + batchOffset]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { smemcr[batch][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } } } __syncthreads(); #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK * NUM_MATS >= VEC_LENGTH_MAT && i >= (HIDDEN_SIZE * NUM_MATS) / MATRIX_SPLIT_K_FACTOR) break; int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; int index = step * TOTAL_MINIBATCH * HIDDEN_SIZE + BATCH_STRIDE; index *= NUM_LINS; index += k_block * NUM_MATS * HIDDEN_SIZE / MATRIX_SPLIT_K_FACTOR; index += typeOffset + i; readBufferGates1[i_ / (THREADS_PER_BLOCK)] = cuGet(storedResults[index]); } T_MATH accumulator[M_PER_THREAD_MAT][INNER_UNROLL]; RNN_persist_GEMM( rMat_reg, accumulator, smemGates1[1 - ((0 + absStep * MINIBATCH) % 2)], rowStartBlock, colStart); #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH_MAT && i >= VEC_LENGTH_MAT) break; smemGates1[(0 + absStep * MINIBATCH) % 2][i] = readBufferGates1[i_ / (THREADS_PER_BLOCK)]; } if (colStart == 0) { #pragma unroll for (int j = 0; j < M_PER_THREAD_MAT; j++) { int row = rowStart + j; int batchOffset = 0 * BATCH_STRIDE * PERSISTENT_DGRAD_MAX_SPLIT_K; int stepOffset = step * TOTAL_MINIBATCH * HIDDEN_SIZE * PERSISTENT_DGRAD_MAX_SPLIT_K; if (row >= HIDDEN_SIZE) break; tmp_results[row + k_block * HIDDEN_SIZE + batchOffset + stepOffset] = getSafeOutput(accumulator[j][0]); } } __syncthreads(); for (int absStep = seqLength - 1; absStep >= 0; absStep--) { if (DIRECTION == 1) step = seqLength - absStep - 1; else step = absStep; for (int batch = 0; batch < MINIBATCH; batch++) { int final = absStep == 0 && batch == MINIBATCH - 1; if (!final) { if (!(absStep == 0 && batch >= MINIBATCH - 2)) { #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (K_PER_THREAD % MATRIX_SPLIT_K_FACTOR == 0) { if (i_ + THREADS_PER_BLOCK * NUM_MATS >= VEC_LENGTH_MAT && i >= HIDDEN_SIZE * NUM_MATS / MATRIX_SPLIT_K_FACTOR) break; } else { if (i_ + THREADS_PER_BLOCK + THREADS_PER_BLOCK * NUM_MATS >= VEC_LENGTH_MAT && i >= HIDDEN_SIZE * NUM_MATS / MATRIX_SPLIT_K_FACTOR) break; } int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; int index; if (DIRECTION == 1) { index = (step + 1) * TOTAL_MINIBATCH * HIDDEN_SIZE + BATCH_STRIDE * (2 - (MINIBATCH - batch)); if (batch < MINIBATCH - 2) { index += (batch + 2) * BATCH_STRIDE - TOTAL_MINIBATCH * HIDDEN_SIZE - BATCH_STRIDE * (2 - (MINIBATCH - batch)); } } else { index = (step - 1) * TOTAL_MINIBATCH * HIDDEN_SIZE + BATCH_STRIDE * (2 - (MINIBATCH - batch)); if (batch < MINIBATCH - 2) { index += (batch + 2) * BATCH_STRIDE + TOTAL_MINIBATCH * HIDDEN_SIZE - BATCH_STRIDE * (2 - (MINIBATCH - batch)); } } index *= NUM_LINS; index += k_block * NUM_MATS * HIDDEN_SIZE / MATRIX_SPLIT_K_FACTOR; index += typeOffset + i; readBufferGates1[i_ / (THREADS_PER_BLOCK)] = cuGet(storedResults[index]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT * MATRIX_SPLIT_K_FACTOR; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; int j = i % MATRIX_SPLIT_K_FACTOR; i = i / MATRIX_SPLIT_K_FACTOR; int blockOffset = BLOCK_DI_LENGTH_ACT * blockIdx.x; if (i + blockOffset < HIDDEN_SIZE) { int batchOffset = batch * BATCH_STRIDE * PERSISTENT_DGRAD_MAX_SPLIT_K; int stepOffset = step * TOTAL_MINIBATCH * HIDDEN_SIZE * PERSISTENT_DGRAD_MAX_SPLIT_K; readBufferTmpResults[i_ / (THREADS_PER_BLOCK)] = cuGet(tmp_results[i + j * HIDDEN_SIZE + blockOffset + stepOffset + batchOffset]); } } if (absStep > 0) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT * NUM_LINS; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT * NUM_LINS) { int blockOffset = BLOCK_DI_LENGTH_ACT * blockIdx.x; int batchOffset = (batch)*BATCH_STRIDE * NUM_LINS; int stepOffset = ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE) * NUM_LINS; int j = (i % NUM_LINS); if ((i / NUM_LINS) + j * HIDDEN_SIZE + blockOffset < HIDDEN_SIZE * NUM_LINS) { readBufferGates2[i_ / (THREADS_PER_BLOCK)] = cuGet( tmp_gates[(i / NUM_LINS) + j * HIDDEN_SIZE + blockOffset + stepOffset + batchOffset]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { int blockOffset = BLOCK_DI_LENGTH_ACT * blockIdx.x; int batchOffset = (batch)*BATCH_OUTPUT_STRIDE; int stepOffset = (step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * OUTPUT_STRIDE; if (i + blockOffset < HIDDEN_SIZE) { readBufferi[i_ / (THREADS_PER_BLOCK)] = cuGet(dy[i + blockOffset + stepOffset + batchOffset]); } } } if (RNN_MODE == CUDNN_LSTM) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { int blockOffset = BLOCK_DI_LENGTH_ACT * blockIdx.x; int batchOffset = (batch)*BATCH_STRIDE; int stepOffset = (step - (DIRECTION == 1 ? -2 : 2)) * TOTAL_MINIBATCH * HIDDEN_SIZE; if (i + blockOffset < HIDDEN_SIZE) { if (absStep == 1) { readBufferc[i_ / (THREADS_PER_BLOCK)] = cx == NULL ? cuGet(0) : cuGet(cx[i + blockOffset + batchOffset]); } else { readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(c_data[i + blockOffset + batchOffset + stepOffset]); } } } } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { int blockOffset = BLOCK_DI_LENGTH_ACT * blockIdx.x; if (i + blockOffset < HIDDEN_SIZE) { if (absStep == 1) { int batchOffset = batch * BATCH_STRIDE; readBufferc[i_ / (THREADS_PER_BLOCK)] = cx == NULL ? cuGet(0) : cuGet(cx[i + blockOffset + batchOffset]); } else { int stepOffset = (step - (DIRECTION == 1 ? -2 : 2)) * TOTAL_MINIBATCH * OUTPUT_STRIDE; int batchOffset = batch * BATCH_OUTPUT_STRIDE; readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(h_data[i + blockOffset + batchOffset + stepOffset]); } } } } } } RNN_persist_GEMM( rMat_reg, accumulator, smemGates1[((batch + absStep * MINIBATCH) % 2)], rowStartBlock, colStart); if (!final) { if (colStart == 0) { #pragma unroll for (int j = 0; j < M_PER_THREAD_MAT; j++) { int row = rowStart + j; int batchOffset; if (batch == MINIBATCH - 1) { if (DIRECTION == 0) batchOffset = -TOTAL_MINIBATCH * HIDDEN_SIZE * PERSISTENT_DGRAD_MAX_SPLIT_K; else batchOffset = TOTAL_MINIBATCH * HIDDEN_SIZE * PERSISTENT_DGRAD_MAX_SPLIT_K; } else { batchOffset = (batch + 1) * BATCH_STRIDE * PERSISTENT_DGRAD_MAX_SPLIT_K; } int stepOffset = step * TOTAL_MINIBATCH * HIDDEN_SIZE * PERSISTENT_DGRAD_MAX_SPLIT_K; if (row >= HIDDEN_SIZE) break; tmp_results[row + k_block * HIDDEN_SIZE + batchOffset + stepOffset] = getSafeOutput(accumulator[j][0]); } } } int reloadGates = false; #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { if (isNegativeZero(readBufferGates1[i_ / THREADS_PER_BLOCK])) { reloadGates = true; break; } } if (!final) { if (!(absStep == 0 && batch >= MINIBATCH - 2)) { if (absStep < seqLength - 1 || batch >= MINIBATCH - 2) { while (reloadGates) { #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (K_PER_THREAD % MATRIX_SPLIT_K_FACTOR == 0) { if (i_ + THREADS_PER_BLOCK * NUM_MATS >= VEC_LENGTH_MAT && i >= HIDDEN_SIZE * NUM_MATS / MATRIX_SPLIT_K_FACTOR) break; } else { if (i_ + THREADS_PER_BLOCK + THREADS_PER_BLOCK * NUM_MATS >= VEC_LENGTH_MAT && i >= HIDDEN_SIZE * NUM_MATS / MATRIX_SPLIT_K_FACTOR) break; } int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; int index; if (DIRECTION == 1) { index = (step + 1) * TOTAL_MINIBATCH * HIDDEN_SIZE + BATCH_STRIDE * (2 - (MINIBATCH - batch)); if (batch < MINIBATCH - 2) { index += (batch + 2) * BATCH_STRIDE - TOTAL_MINIBATCH * HIDDEN_SIZE - BATCH_STRIDE * (2 - (MINIBATCH - batch)); } } else { index = (step - 1) * TOTAL_MINIBATCH * HIDDEN_SIZE + BATCH_STRIDE * (2 - (MINIBATCH - batch)); if (batch < MINIBATCH - 2) { index += (batch + 2) * BATCH_STRIDE + TOTAL_MINIBATCH * HIDDEN_SIZE - BATCH_STRIDE * (2 - (MINIBATCH - batch)); } } index *= NUM_LINS; index += k_block * NUM_MATS * HIDDEN_SIZE / MATRIX_SPLIT_K_FACTOR; index += typeOffset + i; readBufferGates1[i_ / THREADS_PER_BLOCK] = cuGet(loadVolatile(storedResults, index)); } reloadGates = false; #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { if (isNegativeZero(readBufferGates1[i_ / THREADS_PER_BLOCK])) { reloadGates = true; break; } } } } } } bool reloadTmpResults = false; #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT * MATRIX_SPLIT_K_FACTOR; i_ += THREADS_PER_BLOCK) { if (isNegativeZero(readBufferTmpResults[i_ / (THREADS_PER_BLOCK)])) { reloadTmpResults = true; break; } } while (reloadTmpResults) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT * MATRIX_SPLIT_K_FACTOR; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; int j = i % MATRIX_SPLIT_K_FACTOR; i = i / MATRIX_SPLIT_K_FACTOR; int blockOffset = BLOCK_DI_LENGTH_ACT * blockIdx.x; if (i + blockOffset < HIDDEN_SIZE) { int batchOffset = batch * BATCH_STRIDE * PERSISTENT_DGRAD_MAX_SPLIT_K; int stepOffset = step * TOTAL_MINIBATCH * HIDDEN_SIZE * PERSISTENT_DGRAD_MAX_SPLIT_K; readBufferTmpResults[i_ / (THREADS_PER_BLOCK)] = cuGet( loadVolatile(tmp_results, i + j * HIDDEN_SIZE + blockOffset + stepOffset + batchOffset)); } } reloadTmpResults = false; #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT * MATRIX_SPLIT_K_FACTOR; i_ += THREADS_PER_BLOCK) { if (isNegativeZero(readBufferTmpResults[i_ / (THREADS_PER_BLOCK)])) { reloadTmpResults = true; break; } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT * MATRIX_SPLIT_K_FACTOR; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; int j = i % MATRIX_SPLIT_K_FACTOR; i = i / MATRIX_SPLIT_K_FACTOR; if (i < BLOCK_DI_LENGTH_ACT) { smemtmp[getSmemSectionTmp(batch, absStep)][j * BLOCK_DI_LENGTH_ACT + i] = readBufferTmpResults[i_ / (THREADS_PER_BLOCK)]; } } #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH_MAT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH_MAT && i >= VEC_LENGTH_MAT) break; smemGates1[1 - ((batch + absStep * MINIBATCH) % 2)][i] = readBufferGates1[i_ / (THREADS_PER_BLOCK)]; } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT * NUM_LINS; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT * NUM_LINS) { smemGates2[getSmemSectionGates2(batch, absStep)][i] = readBufferGates2[i_ / (THREADS_PER_BLOCK)]; } } if (M_PER_THREAD_ACT > 1 || GROUP_BATCH_SIZE >= 1) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { smemi[getSmemSectionI(batch, absStep)][i] = readBufferi[i_ / (THREADS_PER_BLOCK)]; } } } if (RNN_MODE == CUDNN_LSTM) { if (absStep != seqLength - 1 || batch == MINIBATCH - 1) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { smemcr[batch + 1 == MINIBATCH ? 0 : (batch + 1)][i] = smemcl[batch + 1 == MINIBATCH ? 0 : (batch + 1)][i]; } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { smemcl[batch][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH_ACT; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH_ACT) { smemcl[batch][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } __syncthreads(); T_MATH results[M_PER_THREAD_ACT]; #if (defined(GROUP_BATCH_SIZE) && GROUP_BATCH_SIZE >= 1) if (batch % GROUP_BATCH_SIZE == GROUP_BATCH_SIZE - 1) { int writeBatch = batch - (GROUP_BATCH_SIZE - 1) + rowStartBlockAct / BLOCK_DI_LENGTH_ACT; int rowStartBlockGroup = rowStartBlockAct % BLOCK_DI_LENGTH_ACT; if ((rowStartBlockAct / BLOCK_DI_LENGTH_ACT) < GROUP_BATCH_SIZE && rowStartBlockGroup < BLOCK_DI_LENGTH_ACT) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] = 0; } #pragma unroll for (int i = 0; i < MATRIX_SPLIT_K_FACTOR; i++) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] += smemtmp[getSmemSectionTmp(writeBatch, absStep)] [i * BLOCK_DI_LENGTH_ACT + rowStartBlockGroup + j]; } } } if (absStep > 0) { if ((rowStartBlockAct / BLOCK_DI_LENGTH_ACT) < GROUP_BATCH_SIZE && rowStartBlockGroup < BLOCK_DI_LENGTH_ACT) { if (RNN_MODE == CUDNN_LSTM) { // if (M_PER_THREAD_ACT > 1) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] += smemi[getSmemSectionI(writeBatch, absStep)][rowStartBlockGroup + j]; } // } // else { // results[0] += readBufferi[0]; // } #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = blockIdx.x * BLOCK_DI_LENGTH_ACT + rowStartBlockGroup + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE) * 4; T_MATH linear_in_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)] [(rowStartBlockGroup + j) * NUM_MATS + 0]; T_MATH linear_forget_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)] [(rowStartBlockGroup + j) * NUM_MATS + 1]; T_MATH linear_in_gate2 = smemGates2[getSmemSectionGates2(writeBatch, absStep)] [(rowStartBlockGroup + j) * NUM_MATS + 2]; T_MATH linear_out_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)] [(rowStartBlockGroup + j) * NUM_MATS + 3]; T_MATH deltaY = results[j]; T_MATH _cy = cuGet(smemcr[writeBatch][rowStartBlockGroup + j]); T_MATH dclip_r = dclip(_cy, clipopt, nanopt, lclip, rclip); T_MATH dclip_l = dclip(smemcl[writeBatch][rowStartBlockGroup + j], clipopt, nanopt, lclip, rclip); T_MATH deltaC = deltaY * (linear_out_gate)*dtanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dclip_r + smemdcx[writeBatch][rowStartBlockGroup + j]; T_MATH out0 = deltaC * (linear_in_gate2)*dsigmoid_2(linear_in_gate); T_MATH out1 = deltaC * clip(smemcl[writeBatch][rowStartBlockGroup + j], clipopt, nanopt, lclip, rclip) * dsigmoid_2(linear_forget_gate); T_MATH out2 = deltaC * (linear_in_gate)*dtanh_2(linear_in_gate2); T_MATH out3 = deltaY * _tanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dsigmoid_2(linear_out_gate); storedResults[linGatesBaseIndex] = getSafeOutput(out0); storedResults[linGatesBaseIndex + HIDDEN_SIZE] = getSafeOutput(out1); storedResults[linGatesBaseIndex + HIDDEN_SIZE * 2] = getSafeOutput(out2); storedResults[linGatesBaseIndex + HIDDEN_SIZE * 3] = getSafeOutput(out3); smemdcx[writeBatch][rowStartBlockGroup + j] = deltaC * (linear_forget_gate)*dclip_l; } } else if (RNN_MODE == CUDNN_GRU) { // if (M_PER_THREAD_ACT > 1) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] += smemi[getSmemSectionI(writeBatch, absStep)][rowStartBlockGroup + j]; } // } // else { // results[0] += readBufferi[0]; // } #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = blockIdx.x * BLOCK_DI_LENGTH_ACT + rowStartBlockGroup + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE) * 6; T_MATH reset_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)] [(rowStartBlockGroup + j) * NUM_LINS + 0]; T_MATH update_gate = smemGates2[getSmemSectionGates2(writeBatch, absStep)] [(rowStartBlockGroup + j) * NUM_LINS + 1]; T_MATH new_gate_i = smemGates2[getSmemSectionGates2(writeBatch, absStep)] [(rowStartBlockGroup + j) * NUM_LINS + 2]; T_MATH new_gate_h = smemGates2[getSmemSectionGates2(writeBatch, absStep)] [(rowStartBlockGroup + j) * NUM_LINS + 3]; T_MATH deltaY = results[j] + smemdcx[writeBatch][rowStartBlockGroup + j]; smemdcx[writeBatch][rowStartBlockGroup + j] = deltaY * (update_gate); T_MATH lin_output_new_gate = (reset_gate)*new_gate_h + new_gate_i; T_MATH output_new_gate = _tanh(lin_output_new_gate); T_MATH delta_not = deltaY * output_new_gate; T_MATH hl = cuGet(smemcl[writeBatch][rowStartBlockGroup + j]); T_MATH delta_update_gate = (deltaY * hl - delta_not) * dsigmoid_2(update_gate); T_MATH dh_ = deltaY * (cuGet(1) - (update_gate)); T_MATH dtanh_new_gate = dh_ * dtanh(lin_output_new_gate); T_MATH delta_new_gate_i = dtanh_new_gate; T_MATH delta_new_gate_h = dtanh_new_gate * (reset_gate); T_MATH delta_reset_gate = dtanh_new_gate * new_gate_h * dsigmoid_2(reset_gate); storedResults[linGatesBaseIndex + 0 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 1 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 2 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_i); storedResults[linGatesBaseIndex + 3 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 4 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 5 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_h); } } else { // if (M_PER_THREAD_ACT > 1) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] += smemi[getSmemSectionI(writeBatch, absStep)][rowStartBlockGroup + j]; } // } // else { // results[0] += readBufferi[0]; // } #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { if (RNN_MODE == CUDNN_RNN_RELU) { results[j] = smemGates2[getSmemSectionGates2(writeBatch, absStep)][rowStartBlockGroup + j] < cuGet(0) ? cuGet(0) : results[j]; } else if (RNN_MODE == CUDNN_RNN_TANH) { results[j] *= dtanh_2(smemGates2[getSmemSectionGates2(writeBatch, absStep)][rowStartBlockGroup + j]); } } #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { // int row = rowStartAct + j; int row = blockIdx.x * BLOCK_DI_LENGTH_ACT + rowStartBlockGroup + j; if (row >= HIDDEN_SIZE) break; storedResults[row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE)] = getSafeOutput(results[j]); } } } } else { if ((rowStartBlockAct / BLOCK_DI_LENGTH_ACT) < GROUP_BATCH_SIZE && rowStartBlockGroup < BLOCK_DI_LENGTH_ACT && dhx != NULL) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = blockIdx.x * BLOCK_DI_LENGTH_ACT + rowStartBlockGroup + j; if (row >= HIDDEN_SIZE) break; if (RNN_MODE == CUDNN_GRU) { dhx[row + writeBatch * BATCH_STRIDE] = cuGet(results[j] + smemdcx[writeBatch][rowStartBlockGroup + j]); } else { dhx[row + writeBatch * BATCH_STRIDE] = cuGet(results[j]); } } } if (RNN_MODE == CUDNN_LSTM) { if (seqLength > 1) { if ((rowStartBlockAct / BLOCK_DI_LENGTH_ACT) < GROUP_BATCH_SIZE && rowStartBlockGroup < BLOCK_DI_LENGTH_ACT && dcx != NULL) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = blockIdx.x * BLOCK_DI_LENGTH_ACT + rowStartBlockGroup + j; if (row >= HIDDEN_SIZE) break; dcx[row + writeBatch * BATCH_STRIDE] = cuGet(smemdcx[writeBatch][rowStartBlockGroup + j]); } } } } } __syncthreads(); } #else if (rowStartBlockAct < BLOCK_DI_LENGTH_ACT) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] = 0; } #pragma unroll for (int i = 0; i < MATRIX_SPLIT_K_FACTOR; i++) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] += smemtmp[getSmemSectionTmp(batch, absStep)][i * BLOCK_DI_LENGTH_ACT + rowStartBlockAct + j]; } } } if (absStep > 0) { if (rowStartBlockAct < BLOCK_DI_LENGTH_ACT) { if (RNN_MODE == CUDNN_LSTM) { if (M_PER_THREAD_ACT > 1) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] += smemi[getSmemSectionI(batch, absStep)][rowStartBlockAct + j]; } } else { results[0] += readBufferi[0]; } #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = rowStartAct + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 4; T_MATH linear_in_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlockAct + j) * NUM_MATS + 0]; T_MATH linear_forget_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlockAct + j) * NUM_MATS + 1]; T_MATH linear_in_gate2 = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlockAct + j) * NUM_MATS + 2]; T_MATH linear_out_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlockAct + j) * NUM_MATS + 3]; T_MATH deltaY = results[j]; T_MATH _cy = cuGet(smemcr[batch][rowStartBlockAct + j]); T_MATH dclip_r = dclip(_cy, clipopt, nanopt, lclip, rclip); T_MATH dclip_l = dclip(smemcl[batch][rowStartBlockAct + j], clipopt, nanopt, lclip, rclip); T_MATH deltaC = deltaY * (linear_out_gate)*dtanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dclip_r + smemdcx[batch][rowStartBlockAct + j]; T_MATH out0 = deltaC * (linear_in_gate2)*dsigmoid_2(linear_in_gate); T_MATH out1 = deltaC * clip(smemcl[batch][rowStartBlockAct + j], clipopt, nanopt, lclip, rclip) * dsigmoid_2(linear_forget_gate); T_MATH out2 = deltaC * (linear_in_gate)*dtanh_2(linear_in_gate2); T_MATH out3 = deltaY * _tanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dsigmoid_2(linear_out_gate); storedResults[linGatesBaseIndex] = getSafeOutput(out0); storedResults[linGatesBaseIndex + HIDDEN_SIZE] = getSafeOutput(out1); storedResults[linGatesBaseIndex + HIDDEN_SIZE * 2] = getSafeOutput(out2); storedResults[linGatesBaseIndex + HIDDEN_SIZE * 3] = getSafeOutput(out3); smemdcx[batch][rowStartBlockAct + j] = deltaC * (linear_forget_gate)*dclip_l; } } else if (RNN_MODE == CUDNN_GRU) { if (M_PER_THREAD_ACT > 1) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] += smemi[getSmemSectionI(batch, absStep)][rowStartBlockAct + j]; } } else { results[0] += readBufferi[0]; } #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = rowStartAct + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 6; T_MATH reset_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlockAct + j) * NUM_LINS + 0]; T_MATH update_gate = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlockAct + j) * NUM_LINS + 1]; T_MATH new_gate_i = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlockAct + j) * NUM_LINS + 2]; T_MATH new_gate_h = smemGates2[getSmemSectionGates2(batch, absStep)][(rowStartBlockAct + j) * NUM_LINS + 3]; T_MATH deltaY = results[j] + smemdcx[batch][rowStartBlockAct + j]; smemdcx[batch][rowStartBlockAct + j] = deltaY * (update_gate); T_MATH lin_output_new_gate = (reset_gate)*new_gate_h + new_gate_i; T_MATH output_new_gate = _tanh(lin_output_new_gate); T_MATH delta_not = deltaY * output_new_gate; T_MATH hl = cuGet(smemcl[batch][rowStartBlockAct + j]); T_MATH delta_update_gate = (deltaY * hl - delta_not) * dsigmoid_2(update_gate); T_MATH dh_ = deltaY * (cuGet(1) - (update_gate)); T_MATH dtanh_new_gate = dh_ * dtanh(lin_output_new_gate); T_MATH delta_new_gate_i = dtanh_new_gate; T_MATH delta_new_gate_h = dtanh_new_gate * (reset_gate); T_MATH delta_reset_gate = dtanh_new_gate * new_gate_h * dsigmoid_2(reset_gate); storedResults[linGatesBaseIndex + 0 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 1 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 2 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_i); storedResults[linGatesBaseIndex + 3 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 4 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 5 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_h); } } else { if (M_PER_THREAD_ACT > 1) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { results[j] += smemi[getSmemSectionI(batch, absStep)][rowStartBlockAct + j]; } } else { results[0] += readBufferi[0]; } #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { if (RNN_MODE == CUDNN_RNN_RELU) { results[j] = smemGates2[getSmemSectionGates2(batch, absStep)][rowStartBlockAct + j] < cuGet(0) ? cuGet(0) : results[j]; } else if (RNN_MODE == CUDNN_RNN_TANH) { results[j] *= dtanh_2(smemGates2[getSmemSectionGates2(batch, absStep)][rowStartBlockAct + j]); } } #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = rowStartAct + j; if (row >= HIDDEN_SIZE) break; storedResults[row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE)] = getSafeOutput(results[j]); } } } } else { if (rowStartBlockAct < BLOCK_DI_LENGTH_ACT && dhx != NULL) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = rowStartAct + j; if (row >= HIDDEN_SIZE) break; if (RNN_MODE == CUDNN_GRU) { dhx[row + batch * BATCH_STRIDE] = cuGet(results[j] + smemdcx[batch][rowStartBlockAct + j]); } else { dhx[row + batch * BATCH_STRIDE] = cuGet(results[j]); } } } if (RNN_MODE == CUDNN_LSTM) { if (seqLength > 1) { if (rowStartBlockAct < BLOCK_DI_LENGTH_ACT && dcx != NULL) { #pragma unroll for (int j = 0; j < M_PER_THREAD_ACT; j++) { int row = rowStartAct + j; if (row >= HIDDEN_SIZE) break; dcx[row + batch * BATCH_STRIDE] = cuGet(smemdcx[batch][rowStartBlockAct + j]); } } } } } #endif } } }