Optimizing a CNN in CUDA
The main reason for this project was to gain a better understanding of how to optimize CUDA kernels. To achieve that, I implemented multiple versions of different functions in CUDA and compared their performance. In order to not just compare random or idealized functions, I chose to implement a convolutional neural network (CNN) in CUDA instead.
Please note: The main goal of this project is to gain an understanding of the methods used to optimize CUDA code.
All optimizations have been tested to produce the same results as the original code. However, the original code presented here was not tested for correctness, as it only serves as a reference for the optimizations. Find the full code here: GitHub.
Basics of Performance Optimization in CUDA
In order to do a meaningful analysis of the performance of the different versions, it is important to lay out a few ground rules about what we are actually trying to optimize and how to measure it.
With GPU kernels, measuring the execution time is not trivial because kernels are launched asynchronously.
Here I used the software "NVIDIA Nsight Compute" to collect data on the execution time and behavior of the kernels.
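For completeness, timing a kernel by hand is usually done with CUDA events, which add the required synchronization. A minimal sketch (this is not how the numbers in this article were produced; they all come from Nsight Compute):
#include <cstdio>

__global__ void dummy_kernel() { /* stand-in for the kernel being measured */ }

int main() {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    dummy_kernel<<<1, 32>>>();      // the launch returns immediately ...
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);     // ... so we have to wait for the stop event

    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("kernel time: %f ms\n", ms);
    return 0;
}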
All measurements were done on an NVIDIA GeForce RTX 2080 Ti, with the programs being executed using ncu -f -o OUT_NAME IN_NAME.
For more information please look at the source code on GitHub.
When looking at the assembly code of the CUDA kernels we can, for simplicity's sake, assume that all FFMA instructions are the actual calculations and that all other instructions are simply overhead that wastes time. The left image clearly shows a recurring pattern of FFMA instructions interrupted by other instructions that compute memory addresses, while the right image shows a block of FFMA instructions that runs uninterrupted.
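As a rough illustration (this toy kernel is not part of the project), the multiply-add inside a loop like the following is what typically ends up as FFMA instructions, while the index and address arithmetic around it produces the other instructions in between:
// Toy example for illustration only, not taken from the repository.
__global__ void toy_dot(const float *a, const float *b, float *out, int n) {
    float val = 0.f;
    for (int i = 0; i < n; i++)
        val += a[i] * b[i];   // fused multiply-add -> typically an FFMA instruction
    *out = val;               // computing &a[i], &b[i] is the surrounding "overhead"
}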
Gradient Weights Reduction
The gradient weights calculation is the most complicated and interesting part of the backward pass. It is needed to calculate the gradient values for the weights and biases that the optimizer uses to update the network. What makes it so interesting is that it features a global reduction of the gradients, which is not trivial to implement efficiently in a highly parallel environment like a GPU.
I used a simple reduction algorithm that uses shared memory to store the values and then reduces them.
Each thread computes its part of the reduction and stores the result in shared memory.
Then half of the threads add their values to the values of the other half.
This is repeated until only one value is left.
Now every block has one value that is the sum of all values in the block.
To reduce the per-block values to one final value, I used atomicAdd() to add each block's result to a global memory location.
template<int filter_size>
__global__
void baseline(Tensor<float, 4> input, Tensor<float, 4> error, Tensor<float, 4> gradient_weights) {
const short block_y = blockIdx.y;
const short block_x = blockIdx.x;
const short thread_y = threadIdx.y;
const short thread_x = threadIdx.x;
// get our position in the grid. #blocks*blocksize + position in block
const short x = block_x * blockDim.x + thread_x;
const short y = block_y * blockDim.y + thread_y;
const short z = blockIdx.z * blockDim.z;
const short batch_size = input.getDim(0),
input_channels = input.getDim(1),
output_channels = error.getDim(1),
height = error.getDim(2),
width = error.getDim(3);
// get the channels we are working on
const short c_in = z % input_channels;
const short c_out = z / input_channels;
// local id for shared memory
const short tid = thread_y * blockDim.x + thread_x;
const short threads_in_block = blockDim.x * blockDim.y;
assert(c_in < input_channels && c_out < output_channels);
// start of actual calculation
extern __shared__ float sm[];
for (short j = (-filter_size+1)/2; j <= filter_size/2; j++) {
for (short i = (-filter_size+1)/2; i <= filter_size/2; i++) {
float val = 0.;
for (short b = 0; b < batch_size; b++) {
float input_val =0;
if (x + i >= 0 && x + i < width && y + j >= 0 && y + j < height) {
// for each image go through the same filter for one single channel
input_val=input(b, c_in, y + j, x + i);
}
val += input_val * error(b, c_out, y, x);
}
sm[tid] = val;
__syncthreads();
// intra block reduction
for (int n=threads_in_block/2;n > 0;n>>=1) {
if (tid < n)
sm[tid] += sm[tid + n];
__syncthreads();
}
// final "reduction"
if(tid==0)
atomicAdd(&gradient_weights(c_out, c_in, j+(filter_size-1)/2, i+(filter_size-1)/2), sm[tid]);
}
}
}
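Since the kernel declares sm[] as extern __shared__, the size of the shared-memory buffer has to be passed as the third launch parameter. A hedged sketch of a possible launch configuration (block shape and variable names are assumptions; see the repository for the actual setup):
// Hypothetical launch (block shape and variable names are assumptions):
// width/height are the spatial dimensions of `error`, the z-dimension covers
// every (c_in, c_out) pair, and sm[] needs one float per thread.
dim3 block(16, 16, 1);                              // threads_in_block = 256
dim3 grid((width + block.x - 1) / block.x,
          (height + block.y - 1) / block.y,
          input_channels * output_channels);
size_t smem = block.x * block.y * sizeof(float);
baseline<3><<<grid, block, smem>>>(input, error, gradient_weights);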
Optimizations
In order to improve code clarity the following code example will be simplified. To see the original working code please look at the GitHub repository.
Use shfl for reduction
The function __shfl_down_sync()
is a shuffle instruction that allows us to exchange values between threads in a warp.
This can be used to reduce the values of the threads in a warp.
As the warp size is 32, this function can only reduce 32 values (one warp) at a time.
For simplicity we will only use this function for the last reduction step.
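In isolation, the warp-level part of the reduction looks like this (a minimal sketch; the kernel below simply inlines it):
// After the loop, lane 0 holds the sum of the 32 values the warp started with.
__inline__ __device__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset /= 2)
        val += __shfl_down_sync(0xffffffff, val, offset);
    return val;
}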
template<int filter_size>
__global__
void w_shfl(Tensor<float, 4> input, Tensor<float, 4> error, Tensor<float, 4> gradient_weights) {
...
extern __shared__ float sm[];
for (short j = (-filter_size+1)/2; j <= filter_size/2; j++) {
for (short i = (-filter_size+1)/2; i <= filter_size/2; i++) {
float val = 0.;
for (short b = 0; b < batch_size; b++) {
float input_val =0;
if (x + i >= 0 && x + i < width && y + j >= 0 && y + j < height) {
// for each image go through the same filter for one single channel
input_val=input(b, c_in, y + j, x + i);
}
val += input_val * error(b, c_out, y, x);
}
sm[tid] = val;
__syncthreads();
// intra block reduction
int n = threads_in_block;
while (n > 32) {
n = (n + 1) / 2;
if (tid < n){
sm[tid] += sm[tid + n];
sm[tid + n] = 0;
}
__syncthreads();
}
if (tid < 32) {
double val = sm[tid];
for (int offset = 16; offset > 0; offset /= 2)
val += __shfl_down_sync(0xffffffff, val, offset);
if (tid == 0)
atomicAdd(&gradient_weights(c_out, c_in, j+(filter_size-1)/2, i+(filter_size-1)/2), val);
}
}
}
}
Separate Calculation and Reduction
By separating the calculation and the reduction we can reorder the loops of the calculation. The new inner loop over the filter size allows us to load values into registers and reuse them for multiple calculations.
template<int filter_size>
__global__
void w_shfl_registers(Tensor<float, 4> input, Tensor<float, 4> error, Tensor<float, 4> gradient_weights) {
extern __shared__ float sm[];
float val[filter_size][filter_size];
// init val
for(int j=0;j<filter_size;j++){
for(int i=0;i<filter_size;i++){
val[j][i]=0;
}
}
for (short b = 0; b < batch_size; b++) {
#pragma unroll
for (short j = (-filter_size+1)/2, jj =0; j <= filter_size/2; j++, jj++) {
for (short i = (-filter_size+1)/2,ii =0; i <= filter_size/2; i++, ii++) {
float input_val =0;
if (x + i >= 0 && x + i < width && y + j >= 0 && y + j < height) {
// for each image go through the same filter for one single channel
input_val=input(b, c_in, y + j, x + i);
}
val[jj][ii] += input_val * error(b, c_out, y, x);
}
}
}
for (short j = 0; j < filter_size; j++) {
for (short i =0; i < filter_size; i++) {
sm[tid] = val[j][i];
__syncthreads();
int n = threads_in_block;
while (n > 32) {
n = (n + 1) / 2;
if (tid < n){
sm[tid] += sm[tid + n];
sm[tid + n] = 0;
}
__syncthreads();
}
if (tid < 32) {
double val = sm[tid];
for (int offset = 16; offset > 0; offset /= 2)
val += __shfl_down_sync(0xffffffff, val, offset);
if (tid == 0)
atomicAdd(&gradient_weights(c_out, c_in, j, i), val);
}
}
}
}
More work per thread and padding
By using a padded input tensor we no longer have to check whether an index is in bounds, and by giving each thread more work we reduce the overhead of launching additional threads. As the final values have to be reduced anyway, we can let a single thread perform several of these calculations. This reduces the total number of required reductions and can increase performance.
/*
Expects error.getDim(2) (height of the image) to be divisible by yBatch
*/
template<int filter_size, int yBatch>
__global__
void w_shfl_registers_reordered_padded_work(Tensor<float, 4> inputPadded, Tensor<float, 4> error, Tensor<float, 4> gradient_weights) {
extern __shared__ float sm[];
float val[filter_size][filter_size];
// init val
for(int j=0;j<filter_size;j++){
for(int i=0;i<filter_size;i++){
val[j][i]=0;
}
}
for (int b = 0; b < batch_size; b++) {
#pragma unroll
for ( int yy = 0; yy < yBatch; yy++){
for (int j = 0; j < filter_size; j++) {
for (int i =0; i < filter_size; i++) {
val[j][i] += inputPadded(b, c_in, y + yy + j, x + i) *error(b, c_out, y + yy, x);
}
}
}
}
sm[tid] = 0;
__syncthreads();
for (int j = 0; j < filter_size; j++) {
for (int i =0; i < filter_size; i++) {
double vval = val[j][i];
for (int offset = 16; offset > 0; offset /= 2)
vval += __shfl_down_sync(0xffffffff, vval, offset);
if (tid % 32 == 0) {
sm[tid/32] = vval;
}
__syncthreads();
// int n = threads_in_block;
if (tid < 32) { // only works if threads_in_block/32 <= 32 or threads_in_block <= 1024
double val = sm[tid];
for (int offset = 16; offset > 0; offset /= 2)
val += __shfl_down_sync(0xffffffff, val, offset);
if (tid == 0)
atomicAdd(&gradient_weights(c_out, c_in, j, i), val);
}
}
}
}
It is important to note that the results of the optimizations are not fixed and change over time. For example, when I wrote this code a few years ago the speedup reached around 40x. Since then both the compiler and the hardware have changed and improved.
Results
This optimization project is multiple years old. Therefore the effectiveness of the shown optimizations has changed over time. However, the general idea behind the optimizations should still be valid.
Convolution forward
In the forward case the parallelization on a GPU is simple. We iterate over the output by giving each thread a unique x and y position and an output channel of the output image. Each thread then iterates over all entries of the batch and applies the weighted stencil (here 3x3) to the input in order to calculate the output at its position.
This is only a simple example of a forward convolution. In a real CNN one would also have to consider stride, padding and dilation. The code presented here only handles a stride of 1 with no padding or dilation. Furthermore, it only supports square stencils with an odd size.
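A hedged sketch of how this thread mapping might translate into a launch configuration (the block shape is an assumption; see the repository for the actual setup):
// Hypothetical launch: one thread per (x, y) output position, one grid
// z-slice per output channel (width/height are the output dimensions).
dim3 block(16, 16, 1);
dim3 grid((width + block.x - 1) / block.x,
          (height + block.y - 1) / block.y,
          output_channels);
cf_baseline<3><<<grid, block>>>(input, output, weights, bias);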
template<int filter_size>
__global__
void cf_baseline(Tensor<float, 4> input, Tensor<float, 4> output, Tensor<float, 4> weights, Tensor<float, 1> bias) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int c_out = blockIdx.z * blockDim.z;
int batch_size = input.getDim(0),
input_channels = input.getDim(1),
output_channels = output.getDim(1),
width = output.getDim(3),
height = output.getDim(2);
if (x >= width || y >= height || c_out >= output_channels)
return;
float channel_bias = bias(c_out);
for (int b = 0; b < batch_size; b++) {
float val = channel_bias;
for (int c_in = 0; c_in < input_channels; c_in++) {
for (int j = (-filter_size+1)/2; j <= (filter_size)/2; j++) {
for (int i = (-filter_size+1)/2; i <= (filter_size)/2; i++) {
float input_val = (x+i>=0&&x+i<width&&y+j>=0&&y+j<height)?input(b,c_in,y+j,x+i):0;
val += input_val * weights(c_out, c_in, j + (filter_size-1) / 2, i+(filter_size-1)/2);
}
}
}
output(b, c_out, y, x) = fmaxf(0.,val);
}
}
Code simplification
In order to make the changes to the code easier to follow, I will present the code in a more compact way. This means removing the initialization of the variables and the bounds checks. To see the original working code please look at the GitHub repository.
The code presented above will now look like this:
template<int filter_size>
__global__
void cf_baseline(Tensor<float, 4> input, Tensor<float, 4> output, Tensor<float, 4> weights, Tensor<float, 1> bias) {
float channel_bias = bias(c_out);
for (int b = 0; b < batch_size; b++) {
float val = channel_bias;
for (int c_in = 0; c_in < input_channels; c_in++) {
for (int j = (-filter_size+1)/2; j <= (filter_size)/2; j++) {
for (int i = (-filter_size+1)/2; i <= (filter_size)/2; i++) {
float input_val = (x+i>=0&&x+i<width&&y+j>=0&&y+j<height)?input(b,c_in,y+j,x+i):0;
val += input_val * weights(c_out, c_in, j + (filter_size-1) / 2, i+(filter_size-1)/2);
}
}
}
output(b, c_out, y, x) = fmaxf(0.,val);
}
}
Padding
One of the easiest ways to optimize the code is to add padding to the input tensor. This way we can avoid the bounds checks and simply access the index directly. As if-else statements are expensive on the GPU, this can lead to a significant speedup.
template<int filter_size>
__global__
void cf_padded(Tensor<float, 4> inputPadded, Tensor<float, 4> output, Tensor<float, 4> weights, Tensor<float, 1> bias) {
float channel_bias = bias(c_out);
for (int b = 0; b < batch_size; b++) {
float val = channel_bias;
for (int c_in = 0; c_in < input_channels; c_in++) {
for (int j = 0; j < filter_size; j++) {
for (int i = 0; i < filter_size; i++) {
// here we can simply access the index without any checks
float input_val = inputPadded(b,c_in,y+j,x+i);
val += input_val * weights(c_out, c_in, j, i);
}
}
}
output(b, c_out, y, x) = fmaxf(0.,val);
}
}
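The padded tensor itself has to be created beforehand. A minimal sketch of a hypothetical helper kernel (assuming inputPadded is zero-initialized before the launch and that Tensor<float, 4> behaves as in the kernels above; the repository may handle this differently):
// Hypothetical helper: copies `input` into the interior of a zero-initialized
// tensor that is `pad` elements larger on each side of the spatial dimensions.
__global__ void pad_input(Tensor<float, 4> input, Tensor<float, 4> inputPadded, int pad) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    int c = blockIdx.z;
    if (x >= input.getDim(3) || y >= input.getDim(2) || c >= input.getDim(1))
        return;
    // copy every batch entry; the zero border stays untouched
    for (int b = 0; b < input.getDim(0); b++)
        inputPadded(b, c, y + pad, x + pad) = input(b, c, y, x);
}
For the 3x3 filters used here, pad would be 1, so inputPadded is two elements larger in each spatial dimension.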
More work per thread
Another way to optimize the code is to make each thread do more work. Instead of calculating only one output pixel per thread, we can calculate several. This reduces the overhead of launching additional threads and can increase performance.
All calculations performed on the same channels use the same weights and bias values. This allows us to load these values once and use them for all calculations.
template<int filter_size, int yBatch>
__global__
void cf_padded_work(Tensor<float, 4> inputPadded, Tensor<float, 4> output, Tensor<float, 4> weights, Tensor<float, 1> bias) {
// save the outputs in a register to reduce the number of memory accesses
float reg_outputs[yBatch];
float channel_bias = bias(c_out);
for (int b = 0; b < batch_size; b++) {
#pragma unroll
for ( int yy = 0; yy < yBatch; yy++) reg_outputs[yy] = channel_bias;
for ( int c_in = 0; c_in < input_channels; c_in++) {
for ( int yy = 0; yy < yBatch; yy++){
for (int j = 0; j < filter_size; j++) {
for (int i = 0; i < filter_size; i++) {
float input_val = inputPadded(b,c_in,y+yy+j,x+i);
reg_outputs[yy] += input_val * weights(c_out, c_in, j, i);
}
}
}
}
#pragma unroll
for ( int yy = 0; yy < yBatch; yy++){
output(b, c_out, y + yy, x) = fmaxf(0.,reg_outputs[yy]);
}
}
}
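Since each thread now writes yBatch consecutive rows, the grid presumably shrinks by that factor in y. A hedged sketch of the launch this implies (an assumption, not taken from the repository):
// Hypothetical launch with yBatch = 4 rows per thread: the grid shrinks by a
// factor of yBatch in y, and inside the kernel the starting row would be
// y = (blockIdx.y * blockDim.y + threadIdx.y) * yBatch.
constexpr int yBatch = 4;
dim3 block(16, 16, 1);
dim3 grid((width + block.x - 1) / block.x,
          (height + block.y * yBatch - 1) / (block.y * yBatch),
          output_channels);
cf_padded_work<3, yBatch><<<grid, block>>>(inputPadded, output, weights, bias);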
Use registers as cache
Using the previous optimization as a base, we can see that the weights and bias values rarely change. In fact, they only depend on the output channel and the input channel. This means we can load them into registers and reuse them for multiple calculations. Additionally, we can also load the input values into registers, as each input value is used in several calculations. If structured correctly, only 3 of the 9 input values have to be loaded from memory for each new output position.
Instead of explicitly loading the values into registers we can let the compiler do this for us. If enough registers are available and all array indices can be resolved at compile time, the compiler will keep the array values in registers instead of spilling them to local memory.
template<int filter_size, int yBatch>
__global__
void cf_padded_work_save(Tensor<float, 4> inputPadded, Tensor<float, 4> output, Tensor<float, 4> weights, Tensor<float, 1> bias) {
float weights_reg[filter_size][filter_size];
float input_reg[filter_size][filter_size];
float channel_bias = bias(c_out);
for (int b = 0; b < batch_size; b++) {
float reg_outputs[yBatch];
#pragma unroll
for ( int yy = 0; yy < yBatch; yy++){
reg_outputs[yy] = channel_bias;
}
for (int c_in = 0; c_in < input_channels; c_in++) {
// init weights_reg and error_reg
for(int j=0;j<filter_size;j++){
for(int i=0;i<filter_size;i++){
weights_reg[j][i]=weights(c_out, c_in, j, i);
input_reg[j][i]=inputPadded(b, c_in,y+j, x+i);
}
}
for ( int yy = 0; yy < yBatch; yy++){
for (int j = 0; j < filter_size; j++) {
for (int i = 0; i < filter_size ; i++) {
reg_outputs[yy] += input_reg[j][i] * weights_reg[j][i];
}
}
for(int j=0;j<filter_size-1;j++){
for(int i=0;i<filter_size;i++){
input_reg[j][i]=input_reg[j+1][i];
}
}
for (int i = 0; i < filter_size; i++) {
input_reg[filter_size-1][i]= inputPadded(b, c_in, y + yy + filter_size, x + i);
}
}
}
#pragma unroll
for ( int yy = 0; yy < yBatch; yy++){
output(b, c_out, y + yy, x) = fmaxf(0.,reg_outputs[yy]);
}
}
}
Loop reordering
As the weights only depend on the output channel and the input channel, and not on the batch position, we can move the loop over the batch further inward. This way the weights only have to be loaded once for all batch positions. The problem with this approach is that the size of the accumulator array now depends on the batch size. Since array sizes have to be known at compile time, the batch size has to be passed as a template parameter. This also means that larger batch sizes (roughly above 10) will need more registers than are available and will drastically limit the kernel's occupancy.
template<int filter_size, int batch_size, int yBatch>
__global__
void cf_padded_work_reorder(Tensor<float, 4> inputPadded, Tensor<float, 4> output, Tensor<float, 4> weights, Tensor<float, 1> bias) {
float reg_outputs[batch_size][yBatch];
float reg_weights[filter_size][filter_size];
float channel_bias = bias(c_out);
for (int b = 0; b < batch_size; b++) {
for ( int yy = 0; yy < yBatch; yy++){
reg_outputs[b][yy] = channel_bias;
}
}
for (int c_in = 0; c_in < input_channels; c_in++) {
// load data into reg_weights
#pragma unroll
for (int j = 0; j < filter_size; j++) {
for (int i = 0; i < filter_size; i++) {
reg_weights[j][i] = weights(c_out, c_in, j, i);
}
}
// batch_size, yBatch and filter_size are known at compile time. This allows the compiler to optimize the code further
#pragma unroll
for (int b = 0; b < batch_size; b++) {
for ( int yy = 0; yy < yBatch; yy++){
for (int j = 0; j < filter_size; j++) {
for (int i = 0; i < filter_size; i++) {
float input_val = inputPadded(b,c_in,y+yy+j,x+i);
reg_outputs[b][yy] += input_val * reg_weights[j][i];
}
}
}
}
}
#pragma unroll
for (int b = 0; b < batch_size; b++) {
for ( int yy = 0; yy < yBatch; yy++){
output(b, c_out, y+yy, x) = fmaxf(0.,reg_outputs[b][yy]);
}
}
}
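Because batch_size is now a template parameter, the host code has to select a matching instantiation at runtime. A minimal hedged sketch of such a dispatch (grid, block and the yBatch value of 4 are placeholders; the repository may solve this differently):
// Hypothetical dispatch from the runtime batch size to a compiled instantiation.
switch (batch_size) {
    case 4:  cf_padded_work_reorder<3, 4, 4><<<grid, block>>>(inputPadded, output, weights, bias); break;
    case 8:  cf_padded_work_reorder<3, 8, 4><<<grid, block>>>(inputPadded, output, weights, bias); break;
    default: /* fall back to a variant without a compile-time batch size */ break;
}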
Results
Error calculation
In the error calculation we compute the error that is propagated back to the previous layer. Each thread is assigned a position in the next_error tensor, iterates over all output channels and applies the (flipped) weights to the incoming error in order to calculate the error value at its position.
template<int filter_size>
__global__
void cb_baseline(Tensor<float, 4> error, Tensor<float, 4> next_error, Tensor<float,4> output,Tensor<float, 4> weights) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int c_in = blockIdx.z * blockDim.z;
int batch_size = error.getDim(0),
input_channels = next_error.getDim(1),
output_channels = error.getDim(1),
height = next_error.getDim(2),
width = next_error.getDim(3);
assert(error.getDim(2)==width);
if (x >= width || y >= height || c_in >= input_channels)
return;
for (int b = 0; b < batch_size; b++) {
float val = 0.;
for (int c_out = 0; c_out < output_channels; c_out++) {
for (int j = (-filter_size+1)/2; j <= (filter_size)/2; j++) {
for (int i = (-filter_size+1)/2; i <= (filter_size)/2; i++) {
float error_val =0;
if(x + i >= 0 && x + i < width && y + j >= 0 && y + j < height) {
error_val=(output(b, c_out, y + j, x + i) != 0.) ? error(b, c_out, y + j, x + i) : 0.;
}
val += error_val * weights(c_out, c_in, filter_size/2 - j , filter_size/2 - i );
}
}
}
next_error(b, c_in, y, x) = val;
}
}
Once again, the code shown from here on will be presented in a simplified form.
Padding
As in the forward case we can add padding, this time to the error and output tensors. This removes a lot of if-else statements and can lead to a significant speedup.
template<int filter_size>
__global__
void cb_padded(Tensor<float, 4> errorPadded, Tensor<float, 4> next_error, Tensor<float,4> outputPadded,Tensor<float, 4> weights) {
for (int b = 0; b < batch_size; b++) {
float val = 0.;
for (int c_out = 0; c_out < output_channels; c_out++) {
for (int j = 0; j < filter_size; j++) {
for (int i = 0; i < filter_size; i++) {
float error_val =0;
error_val=(outputPadded(b, c_out, y + j, x + i) != 0.) ? errorPadded(b, c_out, y + j, x + i) : 0.;
val += error_val * weights(c_out, c_in, filter_size - j - 1, filter_size - i -1);
}
}
}
next_error(b, c_in, y, x) = val;
}
}
Remove ReLU
It is possible to remove the ReLU function from the error calculation and instead apply it in a separate kernel. Interestingly this leads to a significant speedup.
__global__
static void ReLU_backwards(Tensor<float, 4> error, Tensor<float, 4> output) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int c = blockIdx.z * blockDim.z;
int batch_size = error.getDim(0),
channels = error.getDim(1),
width = output.getDim(3),
height = output.getDim(2);
if (x >= width || y >= height || c >= channels)
return;
for (int b = 0; b < batch_size; b++) {
error(b, c, y, x) = (output(b, c, y, x) != 0.)?error(b, c, y, x):0.;
}
}
In one basic test on an NVIDIA GeForce RTX 2080 Ti GPU, the new ReLU kernel only needed 0.04 seconds according to the NVIDIA Nsight Compute profiler, while the baseline backward kernel took over 2.6 seconds. This means that the time consumption of the ReLU kernel is negligible compared to the baseline kernel.
template<int filter_size>
__global__
void cb_padded_norelu(Tensor<float, 4> errorPadded, Tensor<float, 4> next_error, Tensor<float,4> outputPadded,Tensor<float, 4> weights) {
for (int b = 0; b < batch_size; b++) {
float val = 0.;
for (int c_out = 0; c_out < output_channels; c_out++) {
#pragma unroll
for (int j = 0; j < filter_size; j++) {
for (int i = 0; i < filter_size; i++) {
float error_val = errorPadded(b, c_out, y + j, x + i);
val += error_val * weights(c_out, c_in, filter_size - j - 1, filter_size - i -1);
}
}
}
next_error(b, c_in, y, x) = val;
}
}
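On the host side the ReLU derivative is now applied in its own launch before the convolution backward kernel. A hedged sketch of the resulting call sequence (launch configurations and the padding step are placeholders):
// Hypothetical call sequence after splitting off the ReLU derivative.
ReLU_backwards<<<grid_relu, block_relu>>>(error, output);   // zero the error where the output was zero
// ... build errorPadded from the masked error, as described in the padding section ...
cb_padded_norelu<3><<<grid_conv, block_conv>>>(errorPadded, next_error, outputPadded, weights);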
More Work per Thread
We can again give each thread more work to do. Once the output channel is known, the weights do not change, so we can simply give each thread a set of points that all need the same weights. As all loop bounds are known at compile time (filter_size, yBatch), the compiler can optimize the code further by unrolling the loops and keeping the values in registers. This only works if yBatch is small enough; otherwise the kernel will not have enough registers to hold all the values.
template<int filter_size, int yBatch>
__global__
void cb_padded_norelu_work(Tensor<float, 4> errorPadded, Tensor<float, 4> next_error, Tensor<float,4> outputPadded,Tensor<float, 4> weights) {
float reg_next_errors[yBatch];
for (int b = 0; b < batch_size; b++) {
#pragma unroll
for ( int yy = 0; yy < yBatch; yy++){
reg_next_errors[yy] = 0.;
}
for (int c_out = 0; c_out < output_channels; c_out++) {
#pragma unroll
for ( int yy = 0; yy < yBatch; yy++){
for (int j = 0; j < filter_size; j++) {
for (int i = 0; i < filter_size; i++) {
float error_val = errorPadded(b, c_out, y + yy + j, x + i);
reg_next_errors[yy] += error_val * weights(c_out, c_in, filter_size - j - 1, filter_size - i -1);
}
}
}
}
#pragma unroll
for ( int yy = 0; yy < yBatch; yy++){
next_error(b, c_in, y + yy, x) = reg_next_errors[yy];
}
}
}