        T_GEMM_IN * __restrict__ y,
        const T_ELEM * __restrict__ hx,
        T_ELEM * __restrict__ hy,
        const T_ELEM * __restrict__ cx,
        T_ELEM * __restrict__ cy,
        T_ELEM * __restrict__ c_data,
        T_ELEM * __restrict__ tmp_h,
        T_ELEM * __restrict__ storedResults,
        const T_GEMM_IN * __restrict__ T,
        const T_ELEM * __restrict__ bias,
        const int seqLength,
        cudnnRNNClipMode_t clipopt,
        cudnnNanPropagation_t nanopt,
        float lclip,
        float rclip) {
    const int THREADS_PER_BLOCK = WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32;
    // Number of gate matrices: 4 for LSTM, 3 for GRU, 1 for a basic RNN.
    const int NUM_MATS = (RNN_MODE == CUDNN_LSTM) ? 4 : (RNN_MODE == CUDNN_GRU) ? 3 : 1;
    const int BASIC_RNN = RNN_MODE == CUDNN_RNN_RELU || RNN_MODE == CUDNN_RNN_TANH;
    const int THREAD_Y_STRIDE = WARP_SIZE_Y;
    // Hidden units produced by this block per time step.
    const int BLOCK_WRITE_LENGTH = WARPS_PER_BLOCK_X * WARP_SIZE_X * ELE_PER_THREAD_X / NUM_MATS;
    // Pad odd row counts up to even (presumably to keep shared-memory accesses conflict-free).
    const int SMEM_I_SIZE = ((BLOCK_WRITE_LENGTH * NUM_MATS) % 2 == 1) ? BLOCK_WRITE_LENGTH * NUM_MATS + 1
                                                                       : BLOCK_WRITE_LENGTH * NUM_MATS;

    // Staged input-GEMM results (NUM_MATS values per hidden unit), multi-buffered.
    __shared__ T_MATH smemi[GROUP_BATCH_SIZE > 1 ? MINIBATCH : 2][SMEM_I_SIZE];
    // Double-buffered recurrent hidden state consumed by the persistent GEMM.
    __shared__ T_MATH smemh[2][WARPS_PER_BLOCK_Y * WARP_SIZE_Y * ELE_PER_THREAD_Y];
    // Cell state (LSTM) or previous hidden state (GRU); unused for basic RNNs.
    __shared__ T_MATH smemcx[BASIC_RNN ? 1 : BLOCK_WRITE_LENGTH][BASIC_RNN ? 1 : MINIBATCH];
    // GRU only: recurrent bias of the candidate (h) gate.
    __shared__ T_MATH smembias[RNN_MODE == CUDNN_GRU ? BLOCK_WRITE_LENGTH : 1];

    int warpIdBlock = threadIdx.x / 32;
    int warpIdGlobal = blockIdx.x * WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y + warpIdBlock;
    int laneId = threadIdx.x % 32;
    int rowStartBlock;
    int rowStart;
    int colStart;
    // Row coordinates are in hidden-unit space (gate rows divided by NUM_MATS);
    // rowStartBlock is block-relative, rowStart is global.
    rowStartBlock = ((warpIdBlock / WARPS_PER_BLOCK_Y) * WARP_SIZE_X + (laneId % WARP_SIZE_X)) * ELE_PER_THREAD_X / NUM_MATS;
    rowStart = ((warpIdGlobal / WARPS_PER_BLOCK_Y) * WARP_SIZE_X + (laneId % WARP_SIZE_X)) * ELE_PER_THREAD_X / NUM_MATS;
    colStart = (laneId / WARP_SIZE_X) * INNER_UNROLL;
    colStart += (warpIdBlock % WARPS_PER_BLOCK_Y) * (VEC_LENGTH / WARPS_PER_BLOCK_Y);
    // LSTM/GRU weights store the gate matrices with a HIDDEN_SIZE row stride.
    const int rowStride = (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) ? HIDDEN_SIZE : 1;
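    // Worked example of the tiling arithmetic, with illustrative values only
    // (the real numbers are template parameters chosen by the launcher):
    // assume WARPS_PER_BLOCK_Y = 2, WARP_SIZE_X = 8, ELE_PER_THREAD_X = 4,
    // NUM_MATS = 4 (LSTM), INNER_UNROLL = 4. For warpIdBlock = 3, laneId = 10:
    //   rowStartBlock = ((3 / 2) * 8 + (10 % 8)) * 4 / 4 = 10
    //   colStart      = (10 / 8) * 4 + (3 % 2) * (VEC_LENGTH / 2) = 4 + VEC_LENGTH / 2
    // so the thread owns ELE_PER_THREAD_X / NUM_MATS = 1 hidden unit (all four of
    // its gate rows) and a contiguous INNER_UNROLL-wide slice of the input vector.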
    // Per-thread tile of the recurrent weight matrix, kept resident in registers
    // for the whole sequence (this is what makes the kernel "persistent").
    T_MATH T_reg[ELE_PER_THREAD_Y][ELE_PER_THREAD_X];
    RNN_persist_loadT(T_reg, T, rowStart, colStart, rowStride);

    // Zero both sections of the staged gate inputs and of the hidden state.
    for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH * NUM_MATS; i_ += THREADS_PER_BLOCK) {
        int i = i_ + threadIdx.x;
        if (i < BLOCK_WRITE_LENGTH * NUM_MATS) {
            smemi[0][i] = cuGet(0);
            smemi[1][i] = cuGet(0);
        }
    }
    for (int i_ = 0; i_ < WARPS_PER_BLOCK_Y * WARP_SIZE_Y * ELE_PER_THREAD_Y; i_ += THREADS_PER_BLOCK) {
        int i = i_ + threadIdx.x;
        if (i < WARPS_PER_BLOCK_Y * WARP_SIZE_Y * ELE_PER_THREAD_Y) {
            smemh[0][i] = cuGet(0);
            smemh[1][i] = cuGet(0);
        }
    }
    // Stage the initial cell state (LSTM) or initial hidden state (GRU), zero if absent.
    if (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) {
        for (int batch = 0; batch < MINIBATCH; batch++) {
#pragma unroll
            for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH; i_ += THREADS_PER_BLOCK) {
                int i = i_ + threadIdx.x;
                if (i < BLOCK_WRITE_LENGTH && i + BLOCK_WRITE_LENGTH * blockIdx.x < HIDDEN_SIZE) {
                    if (RNN_MODE == CUDNN_LSTM) {
                        if (cx != NULL) smemcx[i][batch] = cuGet(cx[i + BLOCK_WRITE_LENGTH * blockIdx.x + batch * BATCH_STRIDE]);
                        else            smemcx[i][batch] = cuGet(0);
                    } else if (RNN_MODE == CUDNN_GRU) {
                        if (hx != NULL) smemcx[i][batch] = cuGet(hx[i + BLOCK_WRITE_LENGTH * blockIdx.x + batch * BATCH_STRIDE]);
                        else            smemcx[i][batch] = cuGet(0);
                    }
                }
            }
        }
    }
    // GRU: stage the recurrent bias of the candidate gate (the sixth bias vector).
    if (RNN_MODE == CUDNN_GRU) {
#pragma unroll
        for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH; i_ += THREADS_PER_BLOCK) {
            int i = i_ + threadIdx.x;
            if (i < BLOCK_WRITE_LENGTH && i + BLOCK_WRITE_LENGTH * blockIdx.x < HIDDEN_SIZE) {
                smembias[i] = cuGet(bias[5 * HIDDEN_SIZE + i + BLOCK_WRITE_LENGTH * blockIdx.x]);
            }
        }
    }
    __syncthreads();

    // Per-thread staging buffers used to prefetch the next iteration's operands
    // from global memory while the current GEMM is in flight.
    T_MATH readBufferh[(VEC_LENGTH + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK];
    T_MATH readBufferi[(BLOCK_WRITE_LENGTH * NUM_MATS + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK];
#pragma unroll
    for (int i = 0; i < (VEC_LENGTH + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; i++) {
        readBufferh[i] = 0;
    }
#pragma unroll
    for (int i = 0; i < (BLOCK_WRITE_LENGTH * NUM_MATS + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; i++) {
        readBufferi[i] = 0;
    }

    int absStep = 0;
    int step;
    // DIRECTION == 1 walks the sequence backwards.
    if (DIRECTION == 1) step = seqLength - absStep - 1;
    else step = absStep;

    // Prime the pipeline: load the first step's gate inputs (precomputed input-GEMM
    // results) and the initial hidden state.
#pragma unroll
    for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH * NUM_MATS; i_ += THREADS_PER_BLOCK) {
        int i = i_ + threadIdx.x;
        if (i < BLOCK_WRITE_LENGTH * NUM_MATS) {
            int blockOffset = BLOCK_WRITE_LENGTH * blockIdx.x;
            int stepOffset = (step * TOTAL_MINIBATCH + 0) * HIDDEN_SIZE * NUM_MATS;
            int j = (i % NUM_MATS);
            if ((i / NUM_MATS) + j * HIDDEN_SIZE + blockOffset < HIDDEN_SIZE * NUM_MATS) {
                readBufferi[i_ / THREADS_PER_BLOCK] = cuGet(x[(i / NUM_MATS) + j * HIDDEN_SIZE + blockOffset + stepOffset]);
            }
        }
    }
#pragma unroll
    for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
        int i = i_ + threadIdx.x;
        if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= HIDDEN_SIZE) break;
        if (hx != NULL) readBufferh[i_ / THREADS_PER_BLOCK] = cuGet(hx[i]);
        else            readBufferh[i_ / THREADS_PER_BLOCK] = cuGet(0);
    }
#pragma unroll
    for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH * NUM_MATS; i_ += THREADS_PER_BLOCK) {
        int i = i_ + threadIdx.x;
        if (i < BLOCK_WRITE_LENGTH * NUM_MATS) {
            smemi[0][i] = readBufferi[i_ / THREADS_PER_BLOCK];
        }
    }
#pragma unroll
    for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
        int i = i_ + threadIdx.x;
        if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= VEC_LENGTH) break;
        smemh[0][i] = readBufferh[i_ / THREADS_PER_BLOCK];
    }
    __syncthreads();

    // Main loop over time steps; the minibatch is processed in an inner loop so the
    // register-resident weights are reused across all batch elements.
    for (; absStep < seqLength; absStep++) {
        if (DIRECTION == 1) step = seqLength - absStep - 1;
        else step = absStep;
        for (int batch = 0; batch < MINIBATCH; batch++) {
            int final = absStep == seqLength - 1 && batch == MINIBATCH - 1;
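            // Every iteration except the last prefetches the operands of the next
            // (batch, step) pair into readBufferh/readBufferi while the GEMM below
            // runs; 'final' suppresses the out-of-range prefetch on the very last
            // iteration. The next hidden state comes from hx on the first step and
            // from the output buffer y (written by the blocks owning those rows)
            // on every later step.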
            if (!final) {
                if (absStep == 0 && batch < MINIBATCH - 1) {
                    // First step: the next batch element's previous hidden state is hx.
#pragma unroll
                    for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                        int i = i_ + threadIdx.x;
                        if (i_ + THREADS_PER_BLOCK < VEC_LENGTH || i < HIDDEN_SIZE) {
                            if (hx != NULL) readBufferh[i_ / THREADS_PER_BLOCK] = cuGet(hx[i + (batch + 1) * BATCH_STRIDE]);
                        }
                    }
                } else {
                    // Later steps: read the previous hidden state back from the output y.
#pragma unroll
                    for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                        int i = i_ + threadIdx.x;
                        if (i_ + THREADS_PER_BLOCK < VEC_LENGTH || i < HIDDEN_SIZE) {
                            int index;
                            index = step * TOTAL_MINIBATCH * OUTPUT_STRIDE + i;
                            if (DIRECTION == 0) {
                                if (batch != MINIBATCH - 1) {
                                    index += (batch + 1) * BATCH_OUTPUT_STRIDE - TOTAL_MINIBATCH * OUTPUT_STRIDE;
                                }
                            } else {
                                if (batch != MINIBATCH - 1) {
                                    index += (batch + 1) * BATCH_OUTPUT_STRIDE + TOTAL_MINIBATCH * OUTPUT_STRIDE;
                                }
                            }
                            readBufferh[i_ / THREADS_PER_BLOCK] = cuGet(y[index]);
                        }
                    }
                }
                // Prefetch the next iteration's gate inputs; on the last batch element,
                // advance to the time step processed next (backwards when DIRECTION == 1).
#pragma unroll
                for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH * NUM_MATS; i_ += THREADS_PER_BLOCK) {
                    int i = i_ + threadIdx.x;
                    if (i_ + THREADS_PER_BLOCK < BLOCK_WRITE_LENGTH * NUM_MATS || i < BLOCK_WRITE_LENGTH * NUM_MATS) {
                        int blockOffset = BLOCK_WRITE_LENGTH * blockIdx.x;
                        int stepOffset = (step * TOTAL_MINIBATCH) * HIDDEN_SIZE * NUM_MATS;
                        int batchOffset = (batch + 1) * BATCH_STRIDE * NUM_MATS;
                        if (batch == MINIBATCH - 1) {
                            if (DIRECTION == 0) {
                                batchOffset = TOTAL_MINIBATCH * HIDDEN_SIZE * NUM_MATS;
                            } else {
                                batchOffset = -TOTAL_MINIBATCH * HIDDEN_SIZE * NUM_MATS;
                            }
                        }
                        int j = (i % NUM_MATS);
                        if ((i / NUM_MATS) + j * HIDDEN_SIZE + blockOffset < HIDDEN_SIZE * NUM_MATS) {
                            readBufferi[i_ / THREADS_PER_BLOCK] = cuGet(x[(i / NUM_MATS) + j * HIDDEN_SIZE + blockOffset + stepOffset + batchOffset]);
                        }
                    }
                }
            }

            // Recurrent GEMM of the register-resident weights against the staged hidden state.
            T_MATH accumulator[ELE_PER_THREAD_X][INNER_UNROLL];
            RNN_persist_GEMM(T_reg, accumulator, smemh[getSmemSectionH(batch, absStep)], rowStartBlock, colStart);

            // Inter-block synchronization: entries of y that the producer blocks have
            // not written yet read back as negative zero (the buffer is presumably
            // seeded with -0.0f), so spin on volatile reloads until real data arrives.
            int reloadH = false;
#pragma unroll
            for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                if (isNegativeZero(readBufferh[i_ / THREADS_PER_BLOCK])) {
                    reloadH = true;
                }
            }
            if (absStep > 0 || batch == MINIBATCH - 1) {
                while (reloadH) {
#pragma unroll
                    for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                        int i = i_ + threadIdx.x;
                        if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= HIDDEN_SIZE) break;
                        int index;
                        if (DIRECTION == 0) {
                            index = step * TOTAL_MINIBATCH * OUTPUT_STRIDE + i;
                            if (batch != MINIBATCH - 1) {
                                index += (batch + 1) * BATCH_OUTPUT_STRIDE - TOTAL_MINIBATCH * OUTPUT_STRIDE;
                            }
                        } else {
                            index = step * TOTAL_MINIBATCH * OUTPUT_STRIDE + i;
                            if (batch != MINIBATCH - 1) {
                                index += (batch + 1) * BATCH_OUTPUT_STRIDE + TOTAL_MINIBATCH * OUTPUT_STRIDE;
                            }
                        }
                        readBufferh[i_ / THREADS_PER_BLOCK] = cuGet(loadVolatile(y, index));
                    }
                    reloadH = false;
#pragma unroll
                    for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                        if (isNegativeZero(readBufferh[i_ / THREADS_PER_BLOCK])) {
                            reloadH = true;
                            break;
                        }
                    }
                }
            }
            // Publish the prefetched operands into the sections the next iteration reads.
#pragma unroll
            for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                int i = i_ + threadIdx.x;
                if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= VEC_LENGTH) break;
                smemh[getSmemSectionH(batch + 1, absStep)][i] = readBufferh[i_ / THREADS_PER_BLOCK];
            }
#pragma unroll
            for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH * NUM_MATS; i_ += THREADS_PER_BLOCK) {
                int i = i_ + threadIdx.x;
                if (i_ + THREADS_PER_BLOCK < BLOCK_WRITE_LENGTH * NUM_MATS || i < BLOCK_WRITE_LENGTH * NUM_MATS) {
                    smemi[getSmemSectionI(batch + 1, absStep)][i] = readBufferi[i_ / THREADS_PER_BLOCK];
                }
            }
#if (defined(GROUP_BATCH_SIZE) && GROUP_BATCH_SIZE > 1)
            __shared__ T_MATH smemBatchGroup[GROUP_BATCH_SIZE - 1][BLOCK_WRITE_LENGTH][ELE_PER_THREAD_X];
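            // When GROUP_BATCH_SIZE > 1, several batch elements presumably share one
            // pass over the register-resident weights, each warp column producing a
            // partial result for a different element of the group. smemBatchGroup
            // stages those partials so that, below, the first GROUP_BATCH_SIZE thread
            // columns (colStart < GROUP_BATCH_SIZE) can each finalize one element of
            // the group: column c handles batch 'batch - c' (writeBatch).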
            if (batch % GROUP_BATCH_SIZE == GROUP_BATCH_SIZE - 1 && colStart < GROUP_BATCH_SIZE) {
                int writeBatch = batch - colStart;
                // Column 0 finalizes the current batch element from its live accumulator;
                // the other columns pick up the staged partials for earlier group members.
                if (colStart != 0) {
#pragma unroll
                    for (int j = 0; j < ELE_PER_THREAD_X; j++) {
                        accumulator[j][0] = smemBatchGroup[colStart - 1][rowStartBlock][j];
                    }
                }
                if (RNN_MODE == CUDNN_GRU) {
                    // Gate values are interleaved per hidden unit in smemi (NUM_MATS == 3
                    // for GRU): j % 3 selects the gate, j / 3 the unit within this thread's
                    // rows. The first two gates add the precomputed input-GEMM term; the
                    // candidate gate instead adds the staged recurrent bias here, since
                    // that bias belongs inside the reset-gate product (its input-GEMM term
                    // is presumably added after the reset multiply, elsewhere).
#pragma unroll
                    for (int j = 0; j < ELE_PER_THREAD_X; j++) {
                        if (j % 3 != 2) accumulator[j][0] += smemi[getSmemSectionI(writeBatch, absStep)][rowStartBlock * NUM_MATS + j];
                        else            accumulator[j][0] += smembias[rowStartBlock + j / 3];
                    }
                } else {
#pragma unroll
                    for (int j = 0; j < ELE_PER_THREAD_X; j++) {
                        accumulator[j][0] += smemi[getSmemSectionI(writeBatch, absStep)][rowStartBlock * NUM_MATS + j];
                    }
                }
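                // Indexing sketch with illustrative values (ELE_PER_THREAD_X = 6,
                // NUM_MATS = 3): a thread with rowStartBlock = r covers hidden units
                // r and r + 1; smemi holds their gates contiguously as
                //   [unit r: g0 g1 g2][unit r+1: g0 g1 g2]
                // starting at index r * 3, so j walks gate-within-unit first, and
                // smembias[r + j / 3] picks the matching per-unit candidate-gate bias.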