1) * TOTAL_MINIBATCH) * HIDDEN_SIZE; for (int batch = 0; batch < MINIBATCH; batch++) { int batchOffset = batch * BATCH_STRIDE; #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH && i + BLOCK_DI_LENGTH * blockIdx.x < HIDDEN_SIZE) { if (RNN_MODE == CUDNN_LSTM) { smemdcx[batch][i] = cuGet(c_data[i + blockOffset + batchOffset + stepOffset]); } else { smemdcx[batch][i] = cuGet(tmp_dh[i + blockOffset + batchOffset]); } } } } } __syncthreads(); for (int absStep = seqLength - 1; absStep >= 0; absStep--) { int step; if (DIRECTION == 1) step = seqLength - absStep - 1; else step = absStep; for (int batch = 0; batch < MINIBATCH; batch++) { T_MATH readBufferGates1[(VEC_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferGates2[(BLOCK_DI_LENGTH * NUM_LINS + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferi[(BLOCK_DI_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; T_MATH readBufferc[(BLOCK_DI_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK)]; #pragma unroll for (int i = 0; i < (VEC_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferGates1[i] = 0; } #pragma unroll for (int i = 0; i < (BLOCK_DI_LENGTH * NUM_LINS + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferGates2[i] = 0; } #pragma unroll for (int i = 0; i < (BLOCK_DI_LENGTH + THREADS_PER_BLOCK - 1) / (THREADS_PER_BLOCK); i++) { readBufferi[i] = 0; if (RNN_MODE == CUDNN_LSTM || RNN_MODE == CUDNN_GRU) readBufferc[i] = 0; } #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH / NUM_MATS && i >= HIDDEN_SIZE * NUM_MATS) break; int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; int index = (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * NUM_LINS + typeOffset + i; readBufferGates1[i_ / (THREADS_PER_BLOCK)] = cuGet(storedResults[index]); } if (absStep < seqLength - 1) { int reloadGates = false; #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { if (cuGet(readBufferGates1[i_ / THREADS_PER_BLOCK]) == cuGet(0) && signbit(cuGet(readBufferGates1[i_ / THREADS_PER_BLOCK]))) { reloadGates = true; break; } } while (reloadGates) { #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH / NUM_MATS && i >= HIDDEN_SIZE * NUM_MATS) break; int typeOffset = RNN_MODE == CUDNN_GRU ? 3 * HIDDEN_SIZE : 0; int index = (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * NUM_LINS + typeOffset + i; readBufferGates1[i_ / THREADS_PER_BLOCK] = cuGet(loadVolatile(storedResults, index)); } reloadGates = false; #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { if (cuGet(readBufferGates1[i_ / THREADS_PER_BLOCK]) == cuGet(0) && signbit(cuGet(readBufferGates1[i_ / THREADS_PER_BLOCK]))) { reloadGates = true; break; } } } } #pragma unroll for (int i_ = 0; i_ < VEC_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i_ + THREADS_PER_BLOCK >= VEC_LENGTH / NUM_MATS && i >= VEC_LENGTH) break; smemGates1[0][i] = readBufferGates1[i_ / THREADS_PER_BLOCK]; } if (absStep > 0) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH * NUM_LINS; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH * NUM_LINS) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int stepOffset = ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * NUM_LINS; int j = (i % NUM_LINS); if ((i / NUM_LINS) + j * HIDDEN_SIZE + blockOffset < HIDDEN_SIZE * NUM_LINS) { readBufferGates2[i_ / (THREADS_PER_BLOCK)] = cuGet(tmp_gates[(i / NUM_LINS) + j * HIDDEN_SIZE + blockOffset + stepOffset]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int stepOffset = (step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE; if (i + blockOffset < HIDDEN_SIZE) { readBufferi[i_ / (THREADS_PER_BLOCK)] = cuGet(dy[i + blockOffset + stepOffset]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH * NUM_LINS; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH * NUM_LINS) { smemGates2[0][i] = readBufferGates2[i_ / (THREADS_PER_BLOCK)]; } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemi[0][i] = readBufferi[i_ / (THREADS_PER_BLOCK)]; } } if (RNN_MODE == CUDNN_LSTM) { if (absStep == seqLength - 1) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int stepOffset = (step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE; if (i + blockOffset < HIDDEN_SIZE) { readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(c_data[i + blockOffset + stepOffset]); } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcr[batch][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } else { for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcr[batch][i] = smemcl[batch][i]; } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; int batchOffset = batch * BATCH_STRIDE; int stepOffset = ((step - (DIRECTION == 1 ? -2 : 2)) * TOTAL_MINIBATCH) * HIDDEN_SIZE; if (i + blockOffset < HIDDEN_SIZE) { if (absStep > 1) { readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(c_data[i + blockOffset + batchOffset + stepOffset]); } else { readBufferc[i_ / (THREADS_PER_BLOCK)] = cx == NULL ? cuGet(0) : cuGet(cx[i + blockOffset + batchOffset]); } } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcl[batch][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { int blockOffset = BLOCK_DI_LENGTH * blockIdx.x; if (i + blockOffset < HIDDEN_SIZE) { if (absStep > 1) { int stepOffset = ((step - (DIRECTION == 1 ? -2 : 2)) * TOTAL_MINIBATCH) * OUTPUT_STRIDE; int batchOffset = batch * BATCH_OUTPUT_STRIDE; readBufferc[i_ / (THREADS_PER_BLOCK)] = cuGet(h_data[i + blockOffset + batchOffset + stepOffset]); } else { int batchOffset = batch * BATCH_STRIDE; readBufferc[i_ / (THREADS_PER_BLOCK)] = cx == NULL ? cuGet(0) : cuGet(cx[i + blockOffset + batchOffset]); } } } } #pragma unroll for (int i_ = 0; i_ < BLOCK_DI_LENGTH; i_ += THREADS_PER_BLOCK) { int i = i_ + threadIdx.x; if (i < BLOCK_DI_LENGTH) { smemcl[batch][i] = readBufferc[i_ / (THREADS_PER_BLOCK)]; } } } } __syncthreads(); T_MATH accumulator[M_PER_THREAD][INNER_UNROLL]; RNN_persist_GEMM( T_reg, accumulator, smemGates1[0], rowStartBlock, colStart); if (absStep > 0) { if (colStart == 0) { if (RNN_MODE == CUDNN_LSTM) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { accumulator[j][0] += smemi[0][rowStartBlock + j]; } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 4; T_MATH linear_in_gate = smemGates2[0][(rowStartBlock + j) * NUM_MATS + 0]; T_MATH linear_forget_gate = smemGates2[0][(rowStartBlock + j) * NUM_MATS + 1]; T_MATH linear_in_gate2 = smemGates2[0][(rowStartBlock + j) * NUM_MATS + 2]; T_MATH linear_out_gate = smemGates2[0][(rowStartBlock + j) * NUM_MATS + 3]; T_MATH deltaY = accumulator[j][0]; T_MATH _cy = cuGet(smemcr[batch][rowStartBlock + j]); T_MATH dclip_r = dclip(_cy, clipopt, nanopt, lclip, rclip); T_MATH dclip_l = dclip(smemcl[batch][rowStartBlock + j], clipopt, nanopt, lclip, rclip); T_MATH deltaC = deltaY * (linear_out_gate)*dtanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dclip_r + smemdcx[batch][rowStartBlock + j]; T_MATH out0 = deltaC * (linear_in_gate2)*dsigmoid_2(linear_in_gate); T_MATH out1 = deltaC * clip(smemcl[batch][rowStartBlock + j], clipopt, nanopt, lclip, rclip) * dsigmoid_2(linear_forget_gate); T_MATH out2 = deltaC * (linear_in_gate)*dtanh_2(linear_in_gate2); T_MATH out3 = deltaY * _tanh(clip(_cy, clipopt, nanopt, lclip, rclip)) * dsigmoid_2(linear_out_gate); storedResults[linGatesBaseIndex + 0 * HIDDEN_SIZE] = getSafeOutput(out0); storedResults[linGatesBaseIndex + 1 * HIDDEN_SIZE] = getSafeOutput(out1); storedResults[linGatesBaseIndex + 2 * HIDDEN_SIZE] = getSafeOutput(out2); storedResults[linGatesBaseIndex + 3 * HIDDEN_SIZE] = getSafeOutput(out3); smemdcx[batch][rowStartBlock + j] = deltaC * (linear_forget_gate)*dclip_l; } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { accumulator[j][0] += smemi[0][rowStartBlock + j]; } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int linGatesBaseIndex = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 6; T_MATH reset_gate = smemGates2[0][(rowStartBlock + j) * NUM_LINS + 0]; T_MATH update_gate = smemGates2[0][(rowStartBlock + j) * NUM_LINS + 1]; T_MATH new_gate_i = smemGates2[0][(rowStartBlock + j) * NUM_LINS + 2]; T_MATH new_gate_h = smemGates2[0][(rowStartBlock + j) * NUM_LINS + 3]; T_MATH deltaY = accumulator[j][0] + smemdcx[batch][rowStartBlock + j]; smemdcx[batch][rowStartBlock + j] = deltaY * (update_gate); T_MATH lin_output_new_gate = (reset_gate)*new_gate_h + new_gate_i; T_MATH output_new_gate = _tanh(lin_output_new_gate); T_MATH delta_not = deltaY * output_new_gate; T_MATH hl = cuGet(smemcl[batch][rowStartBlock + j]); T_MATH delta_update_gate = (deltaY * hl - delta_not) * dsigmoid_2(update_gate); T_MATH dh_ = deltaY * (cuGet(1) - (update_gate)); T_MATH dtanh_new_gate = dh_ * dtanh(lin_output_new_gate); T_MATH delta_new_gate_i = dtanh_new_gate; T_MATH delta_new_gate_h = dtanh_new_gate * (reset_gate); T_MATH delta_reset_gate = dtanh_new_gate * new_gate_h * dsigmoid_2(reset_gate); storedResults[linGatesBaseIndex + 0 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 1 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 2 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_i); storedResults[linGatesBaseIndex + 3 * HIDDEN_SIZE] = getSafeOutput(delta_reset_gate); storedResults[linGatesBaseIndex + 4 * HIDDEN_SIZE] = getSafeOutput(delta_update_gate); storedResults[linGatesBaseIndex + 5 * HIDDEN_SIZE] = getSafeOutput(delta_new_gate_h); } } else { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { accumulator[j][0] += smemi[0][rowStartBlock + j]; } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { if (RNN_MODE == CUDNN_RNN_RELU) { accumulator[j][0] = smemGates2[0][rowStartBlock + j] < cuGet(0) ? cuGet(0) : accumulator[j][0]; } else if (RNN_MODE == CUDNN_RNN_TANH) { accumulator[j][0] *= dtanh_2(smemGates2[0][rowStartBlock + j]); } } #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; int index = row + ((step - (DIRECTION == 1 ? -1 : 1)) * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE); storedResults[index] = getSafeOutput(accumulator[j][0]); } } } } else { if (colStart == 0 && dhx != NULL) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; if (RNN_MODE == CUDNN_GRU) { dhx[row + batch * BATCH_STRIDE] = cuGet(accumulator[j][0] + smemdcx[batch][rowStartBlock + j]); } else { dhx[row + batch * BATCH_STRIDE] = cuGet(accumulator[j][0]); } } } if (RNN_MODE == CUDNN_LSTM) { if (seqLength > 1) { __syncthreads(); if (colStart == 0 && dcx != NULL) { #pragma unroll for (int j = 0; j < M_PER_THREAD; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; dcx[row + batch * BATCH_STRIDE] = cuGet(smemdcx[batch][rowStartBlock + j]); } } } } } __syncthreads(); } } }