// ---------------------------------------------------------------------------
// NOTE(review): this span is the tail of a persistent-RNN forward kernel whose
// signature lies above the visible region.  The parameters it reads (x, y, hx,
// hy, cx, cy, c_data, storedResults, T, bias, seqLength, clipopt, nanopt,
// lclip, rclip) and the compile-time constants (BLOCK_WRITE_LENGTH, NUM_MATS,
// RNN_MODE, HIDDEN_SIZE, MINIBATCH, ...) are declared there -- confirm against
// the enclosing definition.
//
// Overall shape (as visible here): each thread block keeps a slice of the
// recurrent weight matrix resident in registers (T_reg) across all timesteps,
// and for every (timestep, batch) pair stages the input-GEMM partial sums and
// the previous hidden vector in shared memory, runs the recurrent GEMM, then
// applies the mode-specific pointwise recurrence (LSTM / GRU / ReLU / tanh).
// ---------------------------------------------------------------------------

// Shared staging buffers.  smemi is rounded up to an even element count when
// BLOCK_WRITE_LENGTH * NUM_MATS is odd -- presumably to keep a conflict-free /
// aligned layout; TODO confirm intent.
const int SMEM_I_SIZE = ((BLOCK_WRITE_LENGTH * NUM_MATS) % 2 == 1) ? BLOCK_WRITE_LENGTH * NUM_MATS + 1 : BLOCK_WRITE_LENGTH * NUM_MATS;
__shared__ T_MATH smemi[1][SMEM_I_SIZE];                                        // per-step input-transform values (one per gate row)
__shared__ T_MATH smemh[1][WARPS_PER_BLOCK_Y * WARP_SIZE_Y * ELE_PER_THREAD_Y]; // previous hidden vector h(t-1) for the recurrent GEMM
// Recurrent cell state: LSTM cell state c(t-1), or for GRU the previous hidden
// state h(t-1); collapsed to 1x1 for the basic RNN modes so it costs nothing.
__shared__ T_MATH smemcx[BASIC_RNN ? 1 : BLOCK_WRITE_LENGTH][BASIC_RNN ? 1 : MINIBATCH];
// GRU only: the extra recurrent bias applied to the candidate-gate row
// (loaded from bias[5 * HIDDEN_SIZE + ...] below).
__shared__ T_MATH smembias[RNN_MODE == CUDNN_GRU ? BLOCK_WRITE_LENGTH : 1];

// Warp/lane decomposition of the (1-D) thread block.
int warpIdBlock = (threadIdx.x) / 32;
int warpIdGlobal = blockIdx.x * WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y + warpIdBlock;
int laneId = (threadIdx.x) % 32;

// rowStartBlock: first output row this thread owns, relative to the block.
// rowStart:      same row in global (whole hidden-state) coordinates.
// colStart:      first column of h(t-1) this thread multiplies against.
int rowStartBlock;
int rowStart;
int colStart;
rowStartBlock = ((warpIdBlock / WARPS_PER_BLOCK_Y) * WARP_SIZE_X + (laneId % WARP_SIZE_X)) * ELE_PER_THREAD_X / NUM_MATS;
rowStart = ((warpIdGlobal / WARPS_PER_BLOCK_Y) * WARP_SIZE_X + (laneId % WARP_SIZE_X)) * ELE_PER_THREAD_X / NUM_MATS;
colStart = (laneId / WARP_SIZE_X) * INNER_UNROLL;
colStart += (warpIdBlock % WARPS_PER_BLOCK_Y) * (VEC_LENGTH / WARPS_PER_BLOCK_Y);

// Gated modes store the NUM_MATS gate matrices HIDDEN_SIZE rows apart;
// basic RNN has a single matrix, so the stride collapses to 1.
const int rowStride = (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) ?
    HIDDEN_SIZE : 1;

// Load this thread's tile of the recurrent weight matrix into registers once;
// it stays resident for the whole sequence (the "persistent" part).
T_MATH T_reg[ELE_PER_THREAD_Y][ELE_PER_THREAD_X];
RNN_persist_loadT (T_reg, T, rowStart, colStart, rowStride);

// Zero-initialize the shared staging buffers before first use
// (shared memory is uninitialized).
for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH * NUM_MATS; i_ += THREADS_PER_BLOCK) {
    int i = i_ + threadIdx.x;
    if (i < BLOCK_WRITE_LENGTH * NUM_MATS) {
        smemi[0][i] = cuGet(0);
    }
}
for (int i_ = 0; i_ < WARPS_PER_BLOCK_Y * WARP_SIZE_Y * ELE_PER_THREAD_Y; i_ += THREADS_PER_BLOCK) {
    int i = i_ + threadIdx.x;
    if (i < WARPS_PER_BLOCK_Y * WARP_SIZE_Y * ELE_PER_THREAD_Y) {
        smemh[0][i] = cuGet(0);
    }
}

// Seed the recurrent state: LSTM reads the initial cell state cx, GRU reads
// the initial hidden state hx.  A NULL pointer means "start from zero".
if (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) {
    for (int batch = 0; batch < MINIBATCH; batch++) {
        #pragma unroll
        for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH; i_ += WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32) {
            int i = i_ + threadIdx.x;
            // Second clause guards the tail block when HIDDEN_SIZE is not a
            // multiple of BLOCK_WRITE_LENGTH.
            if (i < BLOCK_WRITE_LENGTH && i + BLOCK_WRITE_LENGTH * blockIdx.x < HIDDEN_SIZE) {
                if (RNN_MODE == CUDNN_LSTM) {
                    if (cx != NULL)
                        smemcx[i][batch] = cuGet(cx[i + BLOCK_WRITE_LENGTH * blockIdx.x + batch * BATCH_STRIDE]);
                    else
                        smemcx[i][batch] = cuGet(0);
                } else if (RNN_MODE == CUDNN_GRU) {
                    if (hx != NULL)
                        smemcx[i][batch] = cuGet(hx[i + BLOCK_WRITE_LENGTH * blockIdx.x + batch * BATCH_STRIDE]);
                    else
                        smemcx[i][batch] = cuGet(0);
                }
            }
        }
    }
}

// GRU only: cache the candidate-gate recurrent bias.  The 5 * HIDDEN_SIZE
// offset selects a specific bias vector within the packed bias array --
// verify against the weight-layout code that produced `bias`.
if (RNN_MODE == CUDNN_GRU) {
    #pragma unroll
    for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH; i_ += WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32) {
        int i = i_ + threadIdx.x;
        if (i < BLOCK_WRITE_LENGTH && i + BLOCK_WRITE_LENGTH * blockIdx.x < HIDDEN_SIZE) {
            smembias[i] = cuGet(bias[5 * HIDDEN_SIZE + i + BLOCK_WRITE_LENGTH * blockIdx.x]);
        }
    }
}

// Make the zeroed buffers / seeded state visible block-wide before the
// timestep loop starts reading them.
__syncthreads();

// Main recurrence: iterate over timesteps (reversed when DIRECTION == 1,
// i.e. the backward-direction pass of a bidirectional layer), then over the
// minibatch.
for (int absStep = 0; absStep < seqLength; absStep++) {
    int step;
    if (DIRECTION == 1)
        step = seqLength - absStep - 1;
    else
        step = absStep;

    for (int batch = 0; batch < MINIBATCH; batch++) {
        // Per-thread register staging for the global loads of this step.
        T_MATH readBufferh[(VEC_LENGTH + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK];
        T_MATH readBufferi[(BLOCK_WRITE_LENGTH * NUM_MATS + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK];
        // NOTE(review): plain `0` here vs. cuGet(0) elsewhere -- fine if
        // T_MATH is a scalar float type, suspicious otherwise; confirm.
        #pragma unroll
        for (int i = 0; i < (VEC_LENGTH + THREADS_PER_BLOCK - 1) /
             THREADS_PER_BLOCK; i++) {
            readBufferh[i] = 0;
        }
        #pragma unroll
        for (int i = 0; i < (BLOCK_WRITE_LENGTH * NUM_MATS + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; i++) {
            readBufferi[i] = 0;
        }

        // Load the precomputed input transform for this (step, batch).
        // `x` holds NUM_MATS gate blocks HIDDEN_SIZE apart; the interleaved
        // index (i / NUM_MATS, i % NUM_MATS) regroups them per output row.
        // NOTE(review): the loop stride and the buffer-index divisor assume
        // THREADS_PER_BLOCK == WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32;
        // the buffers above are sized with THREADS_PER_BLOCK -- confirm the
        // two constants are defined equal.
        #pragma unroll
        for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH * NUM_MATS; i_ += WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32) {
            int i = i_ + threadIdx.x;
            if (i < BLOCK_WRITE_LENGTH * NUM_MATS) {
                int blockOffset = BLOCK_WRITE_LENGTH * blockIdx.x;
                int stepOffset = (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * NUM_MATS;
                int j = (i % NUM_MATS);
                if ((i / NUM_MATS) + j * HIDDEN_SIZE + blockOffset < HIDDEN_SIZE * NUM_MATS) {
                    readBufferi[i_ / (WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32)] = cuGet(x[(i / NUM_MATS) + j * HIDDEN_SIZE + blockOffset + stepOffset]);
                }
            }
        }

        // Load h(t-1): the user-provided initial state on the first step,
        // otherwise the previous step's output slice of y.
        #pragma unroll
        for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
            int i = i_ + threadIdx.x;
            // Only the final (partial) iteration can run past HIDDEN_SIZE.
            if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= HIDDEN_SIZE)
                break;
            if (absStep == 0) {
                if (hx != NULL)
                    readBufferh[i_ / (THREADS_PER_BLOCK)] = cuGet(hx[i + batch * BATCH_STRIDE]);
                else
                    readBufferh[i_ / (THREADS_PER_BLOCK)] = cuGet(0);
            } else {
                // (step -/+ 1) selects the previously-computed timestep for
                // this direction.
                readBufferh[i_ / (THREADS_PER_BLOCK)] = cuGet(y[((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE) + i]);
            }
        }

        // Inter-block pipelining: negative zero appears to be used as a
        // "not yet written" sentinel in y (see getSafeOutput at the stores
        // below, which presumably never emits -0.0 -- TODO confirm).  If any
        // loaded element is -0.0, spin on volatile reloads until the
        // producing block has published real data.
        if (absStep > 0) {
            int reloadH = false;
            #pragma unroll
            for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                if (isNegativeZero(readBufferh[i_ / THREADS_PER_BLOCK])) {
                    reloadH = true;
                }
            }
            while (reloadH) {
                #pragma unroll
                for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                    int i = i_ + threadIdx.x;
                    if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= HIDDEN_SIZE)
                        break;
                    int index = ((step - (DIRECTION == 1 ?
                        -1 : 1)) * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE) + i;
                    // Volatile load bypasses caches so we observe the other
                    // block's store.
                    readBufferh[i_ / (THREADS_PER_BLOCK)] = cuGet(loadVolatile(y, index));
                }
                reloadH = false;
                // Re-check; `== 0 && signbit` is an open-coded -0.0 test
                // (equivalent to isNegativeZero above -- confirm).
                #pragma unroll
                for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) {
                    if (readBufferh[i_ / THREADS_PER_BLOCK] == cuGet(0) && signbit(readBufferh[i_ / THREADS_PER_BLOCK])) {
                        reloadH = true;
                        break;
                    }
                }
            }
        }

        // Publish the staged loads to shared memory for block-wide use.
        #pragma unroll
        for (int i_ = 0; i_ < BLOCK_WRITE_LENGTH * NUM_MATS; i_ += WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32) {
            int i = i_ + threadIdx.x;
            if (i < BLOCK_WRITE_LENGTH * NUM_MATS) {
                smemi[0][i] = readBufferi[i_ / (WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32)];
            }
        }
        #pragma unroll
        for (int i_ = 0; i_ < VEC_LENGTH; i_ += WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32) {
            int i = i_ + threadIdx.x;
            if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH && i >= VEC_LENGTH)
                break;
            smemh[0][i] = readBufferh[i_ / (WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32)];
        }
        // Barrier: all of smemi/smemh must be written before the GEMM reads.
        __syncthreads();

        // Recurrent GEMM: accumulator = T_reg * h(t-1) for this thread's rows.
        T_MATH accumulator[ELE_PER_THREAD_X][INNER_UNROLL];
        RNN_persist_GEMM (T_reg, accumulator, smemh[0], rowStartBlock, colStart);

        // Only column-0 threads hold fully reduced sums; they apply the
        // pointwise recurrence and write the outputs.
        if (colStart == 0) {
            if (RNN_MODE == CUDNN_LSTM) {
                // Add the input transform; rows are interleaved as NUM_MATS
                // (= 4) consecutive gate values per output element.
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X; j++) {
                    accumulator[j][0] += smemi[0][rowStartBlock * NUM_MATS + j];
                }
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X / 4; j++) {
                    // Standard LSTM cell: i, f, g (candidate), o gates.
                    T_MATH in_gate = sigmoid(accumulator[j * NUM_MATS + 0][0]);
                    T_MATH forget_gate = sigmoid(accumulator[j * NUM_MATS + 1][0]);
                    T_MATH in_gate2 = _tanh (accumulator[j * NUM_MATS + 2][0]);
                    T_MATH out_gate = sigmoid(accumulator[j * NUM_MATS + 3][0]);
                    // c(t) = f * clip(c(t-1)) + i * g
                    T_MATH val = (forget_gate * clip(cuGet(smemcx[rowStartBlock + j][batch]),clipopt,nanopt,lclip,rclip)) + (in_gate * in_gate2);
                    smemcx[rowStartBlock + j][batch] = val;
                    if (TRAINING) {
                        // Save gate activations and the cell state for the
                        // backward pass (4 gates per element).
                        int row = rowStart + j;
                        if (row < HIDDEN_SIZE) {
                            int baseIndex = row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 4;
                            storedResults[baseIndex + 0 * HIDDEN_SIZE] = cuGet(in_gate);
                            storedResults[baseIndex + 1 * HIDDEN_SIZE] = cuGet(forget_gate);
                            storedResults[baseIndex + 2 * HIDDEN_SIZE] = cuGet(in_gate2);
                            storedResults[baseIndex + 3 * HIDDEN_SIZE] = cuGet(out_gate);
                            c_data[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE)] = cuGet(val);
                        }
                    }
                    // h(t) = o * tanh(clip(c(t)))
                    val = out_gate * _tanh(clip(val,clipopt,nanopt,lclip,rclip));
                    accumulator[j * NUM_MATS][0] = val;
                }
                // Publish h(t) to global y for consumers (next step / layer).
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X / 4; j++) {
                    int row = rowStart + j;
                    if (row >= HIDDEN_SIZE)
                        break;
                    y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE) ] = getSafeOutput(accumulator[j * NUM_MATS][0]);
                }
            } else if (RNN_MODE == CUDNN_GRU) {
                // GRU has 3 mats per element: r, z, candidate.  The candidate
                // row gets the cached recurrent bias instead of the input term
                // here (the input term is added separately inside h_ below).
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X; j++) {
                    if (j % 3 != 2)
                        accumulator[j][0] += smemi[0][rowStartBlock * NUM_MATS + j];
                    else
                        accumulator[j][0] += smembias[rowStartBlock + j / 3];
                }
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X / 3; j++) {
                    T_MATH r = sigmoid(accumulator[j * NUM_MATS + 0][0]);
                    T_MATH z = sigmoid(accumulator[j * NUM_MATS + 1][0]);
                    if (TRAINING) {
                        // Save r, z, the candidate's input term, and the
                        // candidate's recurrent term (6-slot stride layout --
                        // slots 4/5 presumably filled elsewhere; confirm
                        // against the backward pass).
                        int row = rowStart + j;
                        if (row < HIDDEN_SIZE) {
                            int baseIndex = row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 6;
                            storedResults[baseIndex + 0 * HIDDEN_SIZE] = cuGet(r);
                            storedResults[baseIndex + 1 * HIDDEN_SIZE] = cuGet(z);
                            storedResults[baseIndex + 2 * HIDDEN_SIZE] = cuGet(smemi[0][rowStartBlock * NUM_MATS + j * NUM_MATS + 2]);
                            storedResults[baseIndex + 3 * HIDDEN_SIZE] = cuGet(accumulator[j * NUM_MATS + 2][0]);
                        }
                    }
                    // h~ = tanh(r .* (recurrent term + bias) + input term)
                    T_MATH h_ = _tanh(r * accumulator[j * NUM_MATS + 2][0] + smemi[0][rowStartBlock * NUM_MATS + j * NUM_MATS + 2]);
                    T_MATH val;
                    // h(t) = (1 - z) .* h~ + z .* h(t-1)
                    val = (cuGet(1.f) - z) * h_ + z * smemcx[rowStartBlock + j][batch];
                    smemcx[rowStartBlock + j][batch] = val;
                    accumulator[j * NUM_MATS][0] = val;
                }
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X / 3; j++) {
                    int row = rowStart + j;
                    if (row >= HIDDEN_SIZE)
                        break;
                    y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE)] = getSafeOutput(accumulator[j * NUM_MATS][0]);
                }
            } else {
                // Basic RNN (RELU / TANH): single matrix, no gate interleave.
                #pragma unroll
                for (int j = 0; j <
                     ELE_PER_THREAD_X; j++) {
                    accumulator[j][0] += smemi[0][rowStartBlock + j];
                }
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X; j++) {
                    if (RNN_MODE == CUDNN_RNN_RELU) {
                        // ReLU mode saves the PRE-activation value ...
                        if (TRAINING) {
                            int row = rowStart + j;
                            if (row < HIDDEN_SIZE) {
                                storedResults[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE)] = cuGet(accumulator[j][0]);
                            }
                        }
                        accumulator[j][0] = relu(accumulator[j][0]);
                    } else if (RNN_MODE == CUDNN_RNN_TANH) {
                        // ... while tanh mode saves the POST-activation value
                        // (tanh' is expressible from its output; relu' needs
                        // the input sign -- presumably why they differ).
                        accumulator[j][0] = _tanh(accumulator[j][0]);
                        if (TRAINING) {
                            int row = rowStart + j;
                            if (row < HIDDEN_SIZE) {
                                storedResults[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE)] = cuGet(accumulator[j][0]);
                            }
                        }
                    }
                }
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X; j++) {
                    int row = rowStart + j;
                    if (row >= HIDDEN_SIZE)
                        break;
                    y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE)] = getSafeOutput(accumulator[j][0]);
                }
            }
        }

        // After the last step, emit the final hidden state (hy) and, for
        // LSTM, the final cell state (cy) if the caller asked for them.
        if (absStep == seqLength - 1) {
            if (colStart == 0 && hy != NULL) {
                #pragma unroll
                for (int j = 0; j < ELE_PER_THREAD_X / NUM_MATS; j++) {
                    int row = rowStart + j;
                    if (row >= HIDDEN_SIZE)
                        break;
                    hy[row + batch * BATCH_STRIDE] = cuGet(accumulator[j * NUM_MATS][0]);
                }
            }
            if (RNN_MODE == CUDNN_LSTM) {
                // Barrier before reading smemcx, which column-0 threads just
                // updated above.
                __syncthreads();
                if (colStart == 0 && cy != NULL) {
                    #pragma unroll
                    for (int j = 0; j < ELE_PER_THREAD_X / NUM_MATS; j++) {
                        int row = rowStart + j;
                        if (row >= HIDDEN_SIZE)
                            break;
                        cy[row + batch * BATCH_STRIDE] = cuGet(smemcx[rowStartBlock + j][batch]);
                    }
                }
            }
        }
        // End-of-iteration barrier: smemi/smemh/smemcx are reused next
        // (batch, step) iteration; all readers must finish first.
        __syncthreads();
    }
}
}

// ---------------------------------------------------------------------------
// RNN_persist_fp: public entry point of the persistent-RNN forward pass.
// Capped at 256 threads per block with at least 1 block per SM
// (__launch_bounds__(256, 1)); extern "C" to give the launcher an unmangled
// symbol.  Body (on the following line) dispatches by architecture and by
// shouldPipelineFP() to the streamed or simple implementation.
// ---------------------------------------------------------------------------
extern "C" __launch_bounds__(256, 1) __global__ void RNN_persist_fp(
        const T_ELEM * __restrict__ x, T_GEMM_IN * __restrict__ y,
        const T_ELEM * __restrict__ hx, T_ELEM * __restrict__ hy,
        const T_ELEM * __restrict__ cx, T_ELEM * __restrict__ cy,
        T_ELEM * __restrict__ c_data, T_ELEM * __restrict__ tmp_h,
        T_ELEM * __restrict__ storedResults, const T_GEMM_IN * __restrict__ T,
        const T_ELEM * __restrict__ bias, const int seqLength,
        cudnnRNNClipMode_t clipopt, cudnnNanPropagation_t nanopt,
        float lclip, float rclip) {
#if (__CUDA_ARCH__ < 600)
    // Pre-Pascal (SM < 60) builds compile this kernel to a no-op -- the
    // persistent implementations below are only selected/used on SM60+.
    return;
#else
    // Runtime dispatch: shouldPipelineFP() chooses the stream-pipelined
    // variant (overlaps work across layers/chunks) over the simple
    // one-shot variant.  Both take the identical argument list.
    if (shouldPipelineFP()) {
        RNN_persist_fp_stream(x, y, hx, hy, cx, cy, c_data, tmp_h, storedResults, T, bias, seqLength,clipopt,nanopt,lclip,rclip);
    } else {
        RNN_persist_fp_simple(x, y, hx, hy, cx, cy, c_data, tmp_h, storedResults, T, bias, seqLength,clipopt,nanopt,lclip,rclip);
    }
#endif
}