15th October 2025

Fastest BatchNorm Kernel on LeetGPU

Optimizing CUDA code to perform Batch Norm from scratch

In this article, I present a CUDA kernel implementation for computing the Batch Normalization (Batch Norm) operation.

The resulting kernel is 2.54x faster than the previous best solution on LeetGPU.

Batch Normalization is a widely used technique in deep neural networks that helps accelerate convergence during training and improve numerical stability. It operates channel-wise across a batch of data and uses two learnable parameters per channel: γ (gamma) and β (beta).

For each channel c, we compute the mean and standard deviation over the m samples in the batch (m is N in the kernel below):

$$\mu_c = \frac{1}{m} \sum_{i=1}^{m} x_{i,c}, \quad \sigma_c = \sqrt{\frac{1}{m} \sum_{i=1}^{m} (x_{i,c} - \mu_c)^2 + \epsilon}$$

The input is then normalized and scaled as follows:

$$\hat{x}_{i,c} = \frac{x_{i,c} - \mu_c}{\sigma_c}, \qquad y_{i,c} = \gamma_c \, \hat{x}_{i,c} + \beta_c$$

This normalization first gives x̂ zero mean and unit variance (up to the small ε), and then lets the network learn the optimal rescaling and shift through the parameters γ and β.
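Before optimizing anything, it helps to have a slow but obviously correct version of these formulas to compare GPU output against. The sketch below is plain C++ (no GPU needed) and follows the equations literally: a two-pass mean and biased variance per channel over a row-major [N, C] array. The name `batch_norm_reference` is hypothetical, and accumulating in `double` is my choice to make the reference more accurate than the float kernel it checks.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the formulas above.
// input/output are row-major [N, C]; gamma and beta have length C.
void batch_norm_reference(const std::vector<float>& input, const std::vector<float>& gamma,
                          const std::vector<float>& beta, std::vector<float>& output,
                          int N, int C, float eps) {
    assert(N > 0 && C > 0);
    assert(input.size() == static_cast<std::size_t>(N) * C);
    assert(gamma.size() == static_cast<std::size_t>(C) && beta.size() == static_cast<std::size_t>(C));
    output.resize(input.size());
    auto at = [C](int i, int c) { return static_cast<std::size_t>(i) * C + c; };

    for (int c = 0; c < C; ++c) {
        // Pass 1: mean over the batch for this channel.
        double mean = 0.0;
        for (int i = 0; i < N; ++i) mean += input[at(i, c)];
        mean /= N;

        // Pass 2: biased variance (divide by m = N, as in the formula).
        double var = 0.0;
        for (int i = 0; i < N; ++i) {
            double d = input[at(i, c)] - mean;
            var += d * d;
        }
        var /= N;
        double sigma = std::sqrt(var + eps);

        // Normalize, then scale by gamma and shift by beta.
        for (int i = 0; i < N; ++i) {
            double x_hat = (input[at(i, c)] - mean) / sigma;
            output[at(i, c)] = static_cast<float>(gamma[c] * x_hat + beta[c]);
        }
    }
}
```

For example, with two samples {1, 10} and {3, 30}, gamma = {1, 2}, beta = {0, 5} and eps = 0, each channel normalizes to -1 and +1, so the output is {-1, 3, 1, 7}.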

Understanding the Algorithm

The key question in this problem is which dimension to parallelize.
Do you parallelize across samples or across channels?
Parallelizing only across samples is awkward, because each channel's mean and variance require a reduction across the whole batch. You could use atomicAdd to keep accumulating in global memory, but that is slow and effectively sequential, since many threads contend for the same few addresses.
So you might parallelize across channels instead, for example by assigning one block to each channel, with its threads walking over the samples. This is also suboptimal. The input has shape [N, C], but CUDA arrays are flat, so the data is stored row-major like this:
[Sample1 Channel1, Sample1 Channel2, ..., Sample1 ChannelC, Sample2 Channel1, ...]
Now picture the reads: adjacent threads access addresses C elements apart, so the memory accesses cannot be coalesced. Coalescing matters a lot: when the threads of a warp touch adjacent addresses, the hardware combines them into a few wide memory transactions instead of issuing one per thread.
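To make the cost concrete, the sketch below counts how many 128-byte memory segments one 32-thread warp touches under each access pattern. This is a deliberately simplified model of coalescing (real transaction sizes and caching vary by GPU architecture), written in plain C++ so it runs on the CPU; all function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

constexpr std::size_t kWarpSize = 32;
constexpr std::size_t kSegmentBytes = 128;  // simplified transaction size (assumption)

// Row-major [N, C]: element (sample, channel) lives at sample * C + channel.
std::size_t flat_index(std::size_t sample, std::size_t channel, std::size_t C) {
    return sample * C + channel;
}

// Number of distinct 128-byte segments that one warp's float loads fall into.
std::size_t segments_touched(const std::vector<std::size_t>& float_indices) {
    std::set<std::size_t> segments;
    for (std::size_t idx : float_indices) segments.insert(idx * sizeof(float) / kSegmentBytes);
    return segments.size();
}

// One block per channel: lane t of the warp reads sample t of a fixed channel.
std::vector<std::size_t> warp_indices_one_block_per_channel(std::size_t channel, std::size_t C) {
    assert(channel < C);
    std::vector<std::size_t> idx;
    for (std::size_t lane = 0; lane < kWarpSize; ++lane) idx.push_back(flat_index(lane, channel, C));
    return idx;
}

// 2D tile: lane t of the warp reads channel (first_channel + t) of a fixed sample.
std::vector<std::size_t> warp_indices_2d_tile(std::size_t sample, std::size_t first_channel,
                                              std::size_t C) {
    assert(first_channel + kWarpSize <= C);
    std::vector<std::size_t> idx;
    for (std::size_t lane = 0; lane < kWarpSize; ++lane)
        idx.push_back(flat_index(sample, first_channel + lane, C));
    return idx;
}
```

With C = 1024, the one-block-per-channel warp touches 32 separate segments (one per thread), while a warp reading 32 consecutive channels of one sample fits in a single segment.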

The trick is to use both kinds of parallelism, with 2D blocks instead of 1D blocks. Along the x-axis, threads vary by channel; along the y-axis, by sample. Adjacent threads along x now load [Sample1 Channel1, Sample1 Channel2, ..., Sample1 ChannelWarpSize], which is coalesced. Each row of threads (fixed threadIdx.y) accumulates a strided subset of the samples, and a reduction pattern then combines those partial results along the y-axis, i.e., over samples. As a result, each thread performs at most ceil(N / blockDim.y) global reads instead of N.
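The first phase of this scheme, for a single channel, can be modeled on the CPU: "thread row" `ty` sums samples `ty`, `ty + tile_y`, `ty + 2*tile_y`, and so on, mirroring the strided loop in the kernel. This plain C++ sketch uses hypothetical names (`row_partial_sums`, `loads_for_row`) and only illustrates the work split; it is not the kernel itself.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU model of phase 1 for ONE channel of a row-major [N, C] input.
// Returns one partial sum per thread row ty (the kernel stores these in shared memory).
std::vector<float> row_partial_sums(const std::vector<float>& input, int N, int C,
                                    int channel, int tile_y) {
    assert(tile_y > 0 && channel >= 0 && channel < C);
    assert(input.size() == static_cast<std::size_t>(N) * C);
    std::vector<float> partial(tile_y, 0.0f);
    for (int ty = 0; ty < tile_y; ++ty)
        for (int i = ty; i < N; i += tile_y)  // same stride as `i += blockDim.y`
            partial[ty] += input[static_cast<std::size_t>(i) * C + channel];
    return partial;
}

// Global loads issued by thread row ty: at most ceil(N / tile_y).
int loads_for_row(int ty, int N, int tile_y) {
    int count = 0;
    for (int i = ty; i < N; i += tile_y) ++count;
    return count;
}
```

With N = 100 and tile_y = 32, row 0 reads samples 0, 32, 64 and 96 (4 loads, which is ceil(100 / 32)) and row 31 reads 31, 63 and 95 (3 loads); adding up all 32 partial sums gives the channel total.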

Note: A reduction pattern can be applied to any commutative and associative operator. If + is the operator, we rewrite (a + (b + (c + d))) as ((a + b) + (c + d)), so independent pairs can be added in parallel and n values are combined in log2(n) steps. A naive version uses interleaved addressing: the even threads add their neighbors, then the threads whose index is divisible by 4 add the element 2 positions away, and so on. In a 1D reduction that spans several warps, this is not optimal: the threads that remain active are scattered across every warp, so almost every warp keeps being scheduled even though most of its threads sit idle. So people write it in reverse (sequential addressing): each thread first adds the element blockDim.y/2 positions away, then blockDim.y/4 positions away, and so on. The active threads stay contiguous, so whole warps drop out early, and at the end the thread with threadIdx.y == 0 holds the result.
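Both orders can be simulated serially to check that they agree and to count how many warps still have work at a given step. The plain C++ sketch below does that for a 1D reduction with one element per thread and a power-of-two size; the function names are hypothetical, and "busy warps" is a rough proxy for scheduling cost, not a performance measurement.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

bool is_power_of_two(std::size_t n) { return n > 0 && (n & (n - 1)) == 0; }

// Interleaved addressing: at stride s, "thread" t with t % (2s) == 0 adds v[t + s].
float reduce_interleaved(std::vector<float> v) {
    assert(is_power_of_two(v.size()));
    for (std::size_t s = 1; s < v.size(); s *= 2)
        for (std::size_t t = 0; t + s < v.size(); t += 2 * s) v[t] += v[t + s];
    return v[0];
}

// Sequential addressing (the order the kernel uses): at offset j, threads t < j add v[t + j].
float reduce_sequential(std::vector<float> v) {
    assert(is_power_of_two(v.size()));
    for (std::size_t j = v.size() / 2; j > 0; j >>= 1)
        for (std::size_t t = 0; t < j; ++t) v[t] += v[t + j];
    return v[0];
}

// Warps (groups of 32 consecutive threads) with at least one active thread
// during the interleaved step with stride s, for n threads.
std::size_t busy_warps_interleaved(std::size_t n, std::size_t s) {
    std::set<std::size_t> warps;
    for (std::size_t t = 0; t + s < n; t += 2 * s) warps.insert(t / 32);
    return warps.size();
}

// Same count for the sequential step with offset j: active threads are exactly 0..j-1.
std::size_t busy_warps_sequential(std::size_t j) { return (j + 31) / 32; }
```

For 1024 threads, the first interleaved step still touches all 32 warps, and it keeps doing so through stride 16; the first sequential step (j = 512) needs only 16 warps, and by j = 32 a single warp is left.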

Once we have each channel's sum and sum of squares, the rest is simple: the mean is sum / N, the variance is (sum of squares) / N - mean² (clamped at zero to absorb rounding error), and we just apply the formula!
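Here is that finalize step on the host, mirroring the kernel's `mean`/`inv_std` lines. It is a plain C++ sketch with hypothetical names (`ChannelStats`, `finalize_channel`, `apply_batch_norm`), and `1.0f / std::sqrt(...)` stands in for CUDA's `rsqrtf`. One caveat worth knowing: the one-pass identity Var = E[x²] - E[x]² can lose precision in float when |mean| is much larger than the standard deviation, which is why a two-pass formula or Welford's algorithm is sometimes preferred.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

struct ChannelStats {
    float mean;
    float inv_std;  // 1 / sqrt(var + eps)
};

// Host-side mirror of the kernel's finalize step for one channel.
ChannelStats finalize_channel(float sum, float sum_sq, int N, float eps) {
    assert(N > 0);
    const float mean = sum / static_cast<float>(N);
    // One-pass variance, clamped at 0 because rounding can push it slightly negative.
    const float var = std::max(sum_sq / static_cast<float>(N) - mean * mean, 0.0f);
    return {mean, 1.0f / std::sqrt(var + eps)};
}

// Normalize, scale and shift one value, as in the kernel's write-back loop.
float apply_batch_norm(float x, ChannelStats s, float g, float b) {
    return g * ((x - s.mean) * s.inv_std) + b;
}
```

For the values {1, 2, 3, 4}, the sum is 10 and the sum of squares is 30, so `finalize_channel(10, 30, 4, 0)` gives mean 2.5 and variance 7.5 - 6.25 = 1.25.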

```cuda
#define TILEX 32 // warp width: threadIdx.x indexes channels
#define TILEY 32 // must be a power of two for the reduction below

// Expects blockDim = (TILEX, TILEY) and gridDim = (ceil(C / TILEX), 1).
__global__ void batch_norm(const float* input, const float* gamma, const float* beta,
                           float* output, int N, int C, float eps) {
    __shared__ float Ms[TILEY][TILEX]; // partial sums of x
    __shared__ float Ss[TILEY][TILEX]; // partial sums of x^2
    int sample = threadIdx.y; // only 1 block along y, so blockIdx.y = 0
    int channel = threadIdx.x + blockDim.x * blockIdx.x;

    // Every thread sums the samples that have (sample % blockDim.y) == threadIdx.y
    float m_sum = 0.0f;
    float s_sum = 0.0f;
    for (int i = sample; i < N; i += blockDim.y) {
        if (channel < C) {
            // adjacent threadIdx.x read adjacent channels: coalesced across C
            float v = input[i * C + channel];
            m_sum += v;
            s_sum += v * v;
        }
    }
    Ms[threadIdx.y][threadIdx.x] = m_sum;
    Ss[threadIdx.y][threadIdx.x] = s_sum;
    __syncthreads();

    // Reduction pattern over the y (sample) dimension
    for (int j = blockDim.y / 2; j > 0; j >>= 1) {
        if (threadIdx.y < j) { // only rows < j accumulate; row 0 ends up with the sum
            Ms[threadIdx.y][threadIdx.x] += Ms[threadIdx.y + j][threadIdx.x];
            Ss[threadIdx.y][threadIdx.x] += Ss[threadIdx.y + j][threadIdx.x];
        }
        __syncthreads();
    }

    // Load repeatedly used params into registers
    float mean = Ms[0][threadIdx.x] / static_cast<float>(N);
    // var = E[x^2] - mean^2, clamped at 0 to absorb rounding error
    float inv_std = fmaxf(Ss[0][threadIdx.x] / static_cast<float>(N) - mean * mean, 0.0f);
    inv_std = rsqrtf(inv_std + eps);
    float g = (channel < C) ? gamma[channel] : 0.0f;
    float b = (channel < C) ? beta[channel] : 0.0f;

    // Write back out for every sample
    for (int i = sample; i < N; i += blockDim.y) {
        if (channel < C) {
            output[i * C + channel] = g * ((input[i * C + channel] - mean) * inv_std) + b;
        }
    }
}
```
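The article does not show the host-side launch, but the kernel implicitly assumes blockDim = (TILEX, TILEY), a single block row (it uses threadIdx.y directly as the sample index), a power-of-two blockDim.y, and at most 1024 threads per block (CUDA's per-block limit). The plain C++ sketch below computes such a configuration; `LaunchDims`, `batch_norm_launch_dims`, and the device pointers in the comment are hypothetical, and the harness LeetGPU actually uses may differ.

```cpp
#include <cassert>

constexpr int kTileX = 32;  // mirrors TILEX
constexpr int kTileY = 32;  // mirrors TILEY

struct LaunchDims {
    int grid_x, grid_y;
    int block_x, block_y;
};

// One block per 32 channels and a single block row, matching the kernel's assumptions.
LaunchDims batch_norm_launch_dims(int C) {
    static_assert((kTileY & (kTileY - 1)) == 0, "reduction needs a power-of-two blockDim.y");
    static_assert(kTileX * kTileY <= 1024, "CUDA allows at most 1024 threads per block");
    assert(C > 0);
    return {(C + kTileX - 1) / kTileX, 1, kTileX, kTileY};
}

// In CUDA host code this would correspond to (sketch, device buffers assumed to exist):
//   dim3 block(TILEX, TILEY);
//   dim3 grid((C + TILEX - 1) / TILEX, 1);
//   batch_norm<<<grid, block>>>(d_input, d_gamma, d_beta, d_output, N, C, eps);
```

For C = 1000 channels this yields a 32 x 1 grid of 32 x 32 blocks; the last block column covers channels 992 to 1023, and its threads past channel 999 are masked off by the `channel < C` checks.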

Results

In the end, my kernel achieved the fastest time on LeetGPU's BatchNorm benchmark: 2.54× faster than the next entry, written by the platform's top-ranked user. This also raised my overall LeetGPU ranking to #21 globally.

More Stuff

If you found this CUDA work interesting, please check out my GitHub for more:
CUDA Practice Repo

I hope you enjoyed reading the article!

Any feedback is greatly appreciated! (The feedback form is broken, so please message me on LinkedIn.)
