if (RNN_MODE == CUDNN_LSTM) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / 4; j++) { T_MATH in_gate = sigmoid(accumulator[j * NUM_MATS + 0][0]); T_MATH forget_gate = sigmoid(accumulator[j * NUM_MATS + 1][0]); T_MATH in_gate2 = _tanh (accumulator[j * NUM_MATS + 2][0]); T_MATH out_gate = sigmoid(accumulator[j * NUM_MATS + 3][0]); T_MATH val = (forget_gate * clip(cuGet(smemcx[rowStartBlock + j][writeBatch]),clipopt,nanopt,lclip,rclip)) + (in_gate * in_gate2); smemcx[rowStartBlock + j][writeBatch] = val; if (TRAINING) { int row = rowStart + j; if (row < HIDDEN_SIZE) { int baseIndex = row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE) * 4; storedResults[baseIndex + 0 * HIDDEN_SIZE] = cuGet(in_gate); storedResults[baseIndex + 1 * HIDDEN_SIZE] = cuGet(forget_gate); storedResults[baseIndex + 2 * HIDDEN_SIZE] = cuGet(in_gate2); storedResults[baseIndex + 3 * HIDDEN_SIZE] = cuGet(out_gate); c_data[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE)] = cuGet(val); } } val = out_gate * _tanh(clip(val,clipopt,nanopt,lclip,rclip)); accumulator[j * NUM_MATS][0] = val; } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / 4; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + writeBatch * BATCH_OUTPUT_STRIDE)] = getSafeOutput(accumulator[j * NUM_MATS][0]); } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / 3; j++) { T_MATH r = sigmoid(accumulator[j * NUM_MATS + 0][0]); T_MATH z = sigmoid(accumulator[j * NUM_MATS + 1][0]); if (TRAINING) { int row = rowStart + j; if (row < HIDDEN_SIZE) { int baseIndex = row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE) * 6; storedResults[baseIndex + 0 * HIDDEN_SIZE] = cuGet(r); storedResults[baseIndex + 1 * HIDDEN_SIZE] = cuGet(z); storedResults[baseIndex + 2 * HIDDEN_SIZE] = cuGet(smemi[getSmemSectionI(writeBatch, absStep)][rowStartBlock * NUM_MATS + j * NUM_MATS + 2]); storedResults[baseIndex + 3 * HIDDEN_SIZE] = cuGet(accumulator[j * NUM_MATS + 2][0]); } } T_MATH h_ = _tanh(r * accumulator[j * NUM_MATS + 2][0] + smemi[getSmemSectionI(writeBatch, absStep)][rowStartBlock * NUM_MATS + j * NUM_MATS + 2]); T_MATH val; val = (cuGet(1) - z) * h_ + z * smemcx[rowStartBlock + j][writeBatch]; smemcx[rowStartBlock + j][writeBatch] = val; accumulator[j * NUM_MATS][0] = val; } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / 3; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + writeBatch * BATCH_OUTPUT_STRIDE)] = getSafeOutput(accumulator[j * NUM_MATS][0]); } } else { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X; j++) { if (RNN_MODE == CUDNN_RNN_RELU) { if (TRAINING) { int row = rowStart + j; if (row < HIDDEN_SIZE) { storedResults[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE)] = cuGet(accumulator[j][0]); } } accumulator[j][0] = relu(accumulator[j][0]); } else if (RNN_MODE == CUDNN_RNN_TANH) { accumulator[j][0] = _tanh(accumulator[j][0]); if (TRAINING) { int row = rowStart + j; if (row < HIDDEN_SIZE) { storedResults[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + writeBatch * BATCH_STRIDE)] = cuGet(accumulator[j][0]); } } } } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + writeBatch * BATCH_OUTPUT_STRIDE)] = getSafeOutput(accumulator[j][0]); } } } else if (colStart == 0) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X; j++) { smemBatchGroup[GROUP_BATCH_SIZE - 2 - (batch % GROUP_BATCH_SIZE)][rowStartBlock][j] = accumulator[j][0]; } } if (absStep == seqLength - 1) { int writeBatch = batch - colStart; if (hy != NULL) { if (colStart < GROUP_BATCH_SIZE && batch % GROUP_BATCH_SIZE == GROUP_BATCH_SIZE - 1) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / NUM_MATS; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; hy[row + writeBatch * BATCH_STRIDE] = cuGet(accumulator[j * NUM_MATS][0]); } } } if (RNN_MODE == CUDNN_LSTM) { __syncthreads(); if (cy != NULL) { if (colStart < GROUP_BATCH_SIZE && batch % GROUP_BATCH_SIZE == GROUP_BATCH_SIZE - 1) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / NUM_MATS; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; cy[row + writeBatch * BATCH_STRIDE] = cuGet(smemcx[rowStartBlock + j][writeBatch]); } } } } } __syncthreads(); #else if (colStart == 0) { if (RNN_MODE == CUDNN_LSTM) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X; j++) { accumulator[j][0] += smemi[getSmemSectionI(batch, absStep)][rowStartBlock * NUM_MATS + j]; } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / 4; j++) { T_MATH in_gate = sigmoid(accumulator[j * NUM_MATS + 0][0]); T_MATH forget_gate = sigmoid(accumulator[j * NUM_MATS + 1][0]); T_MATH in_gate2 = _tanh (accumulator[j * NUM_MATS + 2][0]); T_MATH out_gate = sigmoid(accumulator[j * NUM_MATS + 3][0]); T_MATH val = (forget_gate * clip(cuGet(smemcx[rowStartBlock + j][batch]),clipopt,nanopt,lclip,rclip)) + (in_gate * in_gate2); smemcx[rowStartBlock + j][batch] = val; if (TRAINING) { int row = rowStart + j; if (row < HIDDEN_SIZE) { int baseIndex = row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 4; storedResults[baseIndex + 0 * HIDDEN_SIZE] = cuGet(in_gate); storedResults[baseIndex + 1 * HIDDEN_SIZE] = cuGet(forget_gate); storedResults[baseIndex + 2 * HIDDEN_SIZE] = cuGet(in_gate2); storedResults[baseIndex + 3 * HIDDEN_SIZE] = cuGet(out_gate); c_data[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE)] = cuGet(val); } } val = out_gate * _tanh(clip(val,clipopt,nanopt,lclip,rclip)); accumulator[j * NUM_MATS][0] = val; } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / 4; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE)] = getSafeOutput(accumulator[j * NUM_MATS][0]); } } else if (RNN_MODE == CUDNN_GRU) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X; j++) { if (j % 3 != 2) accumulator[j][0] += smemi[getSmemSectionI(batch, absStep)][rowStartBlock * NUM_MATS + j]; else accumulator[j][0] += smembias[rowStartBlock + j / 3]; } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / 3; j++) { T_MATH r = sigmoid(accumulator[j * NUM_MATS + 0][0]); T_MATH z = sigmoid(accumulator[j * NUM_MATS + 1][0]); if (TRAINING) { int row = rowStart + j; if (row < HIDDEN_SIZE) { int baseIndex = row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE) * 6; storedResults[baseIndex + 0 * HIDDEN_SIZE] = cuGet(r); storedResults[baseIndex + 1 * HIDDEN_SIZE] = cuGet(z); storedResults[baseIndex + 2 * HIDDEN_SIZE] = cuGet(smemi[getSmemSectionI(batch, absStep)][rowStartBlock * NUM_MATS + j * NUM_MATS + 2]); storedResults[baseIndex + 3 * HIDDEN_SIZE] = cuGet(accumulator[j * NUM_MATS + 2][0]); } } T_MATH h_ = _tanh(r * accumulator[j * NUM_MATS + 2][0] + smemi[getSmemSectionI(batch, absStep)][rowStartBlock * NUM_MATS + j * NUM_MATS + 2]); T_MATH val; val = (cuGet(1) - z) * h_ + z * smemcx[rowStartBlock + j][batch]; smemcx[rowStartBlock + j][batch] = val; accumulator[j * NUM_MATS][0] = val; } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / 3; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE)] = getSafeOutput(accumulator[j * NUM_MATS][0]); } } else { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X; j++) { accumulator[j][0] += smemi[getSmemSectionI(batch, absStep)][rowStartBlock + j]; } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X; j++) { if (RNN_MODE == CUDNN_RNN_RELU) { if (TRAINING) { int row = rowStart + j; if (row < HIDDEN_SIZE) { storedResults[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE)] = cuGet(accumulator[j][0]); } } accumulator[j][0] = relu(accumulator[j][0]); } else if (RNN_MODE == CUDNN_RNN_TANH) { accumulator[j][0] = _tanh(accumulator[j][0]); if (TRAINING) { int row = rowStart + j; if (row < HIDDEN_SIZE) { storedResults[row + (step * TOTAL_MINIBATCH * HIDDEN_SIZE + batch * BATCH_STRIDE)] = cuGet(accumulator[j][0]); } } } } #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; y[row + (step * TOTAL_MINIBATCH * OUTPUT_STRIDE + batch * BATCH_OUTPUT_STRIDE)] = getSafeOutput(accumulator[j][0]); } } } if (absStep == seqLength - 1) { if (colStart == 0 && hy != NULL) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / NUM_MATS; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; hy[row + batch * BATCH_STRIDE] = cuGet(accumulator[j * NUM_MATS][0]); } } if (RNN_MODE == CUDNN_LSTM) { __syncthreads(); if (colStart == 0 && cy != NULL) { #pragma unroll for (int j = 0; j < ELE_PER_THREAD_X / NUM_MATS; j++) { int row = rowStart + j; if (row >= HIDDEN_SIZE) break; cy[row + batch * BATCH_STRIDE] = cuGet(smemcx[rowStartBlock + j][batch]); } } } } __syncthreads(); #endif } } }