Benji Visser

Writing a Softmax Kernel From Scratch

Hi! I'm Benji. I'm writing a softmax kernel from scratch because I'm learning about inference engineering. Why? I want to nerdsnipe.

In this post, we'll break down how softmax works, show you how to write your own CUDA kernel for it, and walk through benchmarking each version step by step.

Let's start from the very basics, so no prior knowledge of GPU programming or kernels is required.

First things first: what's a kernel? In GPU programming, a kernel is simply a function that runs on the GPU instead of the CPU. And what about CUDA? CUDA is NVIDIA's toolkit and language for writing these GPU programs. It's low-level and similar to C++ in style and syntax.
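
To make that concrete, here's about the smallest kernel you can write: it just adds 1 to every element of an array. (A throwaway illustration of the syntax, not part of the softmax code later in this post.)

// A kernel is a function marked __global__ that the GPU runs across many threads at once.
__global__ void add_one(float* data, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;  // which element this thread owns
    if (i < n) data[i] += 1.0f;
}

// Launched from the CPU with the <<<blocks, threads>>> syntax, e.g.:
//   add_one<<<(n + 255) / 256, 256>>>(d_data, n);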

Most people don't want to write C++-ish code, and the deep learning ecosystem lives in Python anyway, so we move up an abstraction layer. Python libraries like PyTorch do the heavy lifting for us and ship with a ton of pre-built CUDA kernels.

Alright, now that we've got that covered, let's talk about softmax.

Softmax is one of the most important algorithms in neural networks, especially in the Transformer architecture where it's used to normalize attention scores.

It takes a vector of real numbers and transforms them into probabilities, ensuring the output values are all between 0 and 1 and add up to 1.

Softmax for beginners

softmax(xᵢ) = exp(xᵢ) / Σⱼ exp(xⱼ)

Where:

  • x is the input vector and xᵢ is its i-th element
  • exp(xᵢ) is the exponential of that element
  • Σⱼ exp(xⱼ) is the sum of the exponentials over all elements

Let's assume we have the input logits: [2.0, 1.0, 0.1].

Step 1: exp(xᵢ)

[exp(2.0), exp(1.0), exp(0.1)] = [7.3890561  2.71828183 1.10517092]

Step 2: Σⱼ exp(xⱼ)

exp(2.0) + exp(1.0) + exp(0.1) = 11.2125

Step 3: divide

exp(logits) / Σⱼ exp(xⱼ) = [0.65900114 0.24243297 0.09856589]

Great. But this formula isn't numerically stable. What does that mean? When the inputs are large, or mix very large and very small values, the floating-point result overflows or becomes inaccurate.

Large: [1000. 1001. 1002.]
exp(Large): [inf inf inf]

Or mixed large/small values:

Mixed: [0.  500. 1000.]
exp(Mixed): [1.00000000e+000 1.40359222e+217 inf]

The small values get completely swamped by the large ones, and worse, exp() overflows to infinity, which ruins the whole row. The solution is to subtract the maximum value from all logits before exponentiating: since exp(xᵢ - max) = exp(xᵢ) · exp(-max), the constant factor exp(-max) cancels between the numerator and the denominator, so the output is unchanged, but the largest exponent is now 0 and nothing overflows.

Numerically stable formula:

softmax(xᵢ) = exp(xᵢ - max(x)) / Σⱼ exp(xⱼ - max(x))
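
To sanity-check the math before touching the GPU, here's a tiny CPU reference in plain C++ (my own sketch, separate from the kernels below):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax: subtract the row max before exponentiating.
std::vector<float> softmax_stable(const std::vector<float>& x) {
    float m = *std::max_element(x.begin(), x.end());    // max(x)

    std::vector<float> out(x.size());
    float sum = 0.0f;
    for (size_t i = 0; i < x.size(); i++) {
        out[i] = std::exp(x[i] - m);                     // exp(x_i - max(x))
        sum += out[i];
    }
    for (float& v : out) v /= sum;                       // normalize by the denominator
    return out;
}

int main() {
    // Same logits as the worked example above.
    auto p = softmax_stable({2.0f, 1.0f, 0.1f});
    printf("%f %f %f\n", p[0], p[1], p[2]);              // ~0.659 0.242 0.099

    // Logits that would overflow the unstable formula now work fine.
    auto q = softmax_stable({1000.0f, 1001.0f, 1002.0f});
    printf("%f %f %f\n", q[0], q[1], q[2]);              // ~0.090 0.245 0.665
    return 0;
}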

Great! We're making so much progress—we've learned what Softmax is and how to make it numerically stable.

Now, let's look at how you might implement Softmax on the GPU, and progressively optimize it by writing three different CUDA kernels.

Kernel 1: Naive Softmax

This is the most straightforward approach: three completely separate kernel launches, each pass reading and/or writing the entire array to and from global memory.

  • Kernel 1: Find the maximum in each row, reading all N elements.
  • Kernel 2: Compute the exponentials of each element (after subtracting the max) and store the results, reading and writing all N elements.
  • Kernel 3: Sum all the exponentials (to get the denominator), then divide each by the sum to get the normalized result, reading and writing all N elements.

Although simple to write and reason about, this approach is very inefficient in terms of memory bandwidth:

1: Find max         Read N
2: Compute exp      Read N, Write N
3: Sum and divide   Read N, Write N
Total: 5N memory operations

Kernel 2: Fused Softmax

This approach fuses the computation into a single kernel launch but still needs to read the input twice: once to find the maximum, and once for the exponentials and sum.

  • Pass 1: Each thread scans its row to find the max value (Read N).
  • Pass 2: Each thread loops through the row again, applies exp(xᵢ - max), and accumulates the sum (Read N).
  • Pass 3: Finally, each element is normalized by dividing by the sum, and the output is written back (Write N).

So even though the whole computation is done within one kernel, the input array is still loaded twice from memory.

1: Find max           Read N
2: Compute exp + sum  Read N
3: Normalize          Write N
Total: 3N memory operations

Kernel 3: Online Softmax

This is the most advanced and memory-efficient version, based on the paper Online normalizer calculation for softmax.

  • The key idea is to compute both the maximum and the sum of exponentials in a single loop as you traverse the data, so you never need a second read pass to recompute either of them.
  • After the single read, all that's left is to write back the normalized output.

This algorithm minimizes global memory accesses:

1: Find max AND sum   Read N
2: Normalize          Write N
Total: 2N memory operations

Performance

Before we start writing the kernels, let's dig into some performance metrics.

Modern GPUs have insane computing power compared to their memory bandwidth. A "memory round trip" refers to reading data from memory and then writing the result back—basically, every time we have to fetch or store a number in global GPU memory, it costs precious bandwidth. In the context of softmax, minimizing the number of times we read and write the full input (i.e., the number of passes over the data) directly impacts speed.

Example: For N=4096 elements (typical transformer row size)

  • Each float32 value is 4 bytes → Data = 16 KB per row
  • Compute: ~12,000 FLOPs (comparisons, exponentials, divisions)

On an A100 80GB (SXM):

  • ~2 TB/s global memory bandwidth
  • ~624 TFLOPS (FP16) compute according to the docs

Loading 16KB from global memory:

  • 16KB / 2TB/s = 8 nanoseconds

Computing ~12,000 FLOPs:

  • 12,000 FLOPs / (624 × 10^12 FLOP/s) ≈ 0.019 ns

So, waiting on global memory is 420x slower than doing the math. We are memory-bound.
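
If you want to double-check that back-of-envelope math, it fits in a few lines. Nothing here is measured; it just replays the two spec-sheet figures and the rough operation count from above:

#include <cstdio>

int main() {
    const double bytes      = 4096 * 4.0;    // 16 KB of float32
    const double bandwidth  = 2.0e12;        // ~2 TB/s global memory bandwidth
    const double flops      = 12000.0;       // rough operation count for one row
    const double peak_flops = 624.0e12;      // headline FP16 throughput

    const double mem_ns     = bytes / bandwidth * 1e9;
    const double compute_ns = flops / peak_flops * 1e9;

    printf("memory:  %.2f ns\n", mem_ns);               // ~8 ns
    printf("compute: %.3f ns\n", compute_ns);           // ~0.02 ns
    printf("ratio:   %.0fx\n", mem_ns / compute_ns);    // ~420x
    return 0;
}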

This is why fusing helps: instead of bouncing intermediate results through global memory between separate kernel launches, we keep them in registers and make as few passes over the data as possible. The naive approach touches global memory 5N times, while the fused kernel touches it only 3N times.

The online kernel goes further: it computes the max and the sum in a single read pass, then writes the normalized output, so it touches global memory only 2N times.

Naive Kernel Implementation

The naive approach uses three separate kernel launches. Each kernel launch reads/writes the full data from global memory.

Each thread handles one row:

// Kernel 1: Find the maximum value in each row
__global__ void find_max_kernel(
    const __half* x,
    __half* max_vals,
    const int batch_size,
    const int seq_len
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= batch_size) return;

    const __half* my_row = x + row * seq_len;

    __half max_val = __float2half(-65504.0f);
    for (int i = 0; i < seq_len; i++) {
        max_val = __hmax(max_val, my_row[i]);
    }
    max_vals[row] = max_val;
}

// Kernel 2: Compute exp(x - max) for each element
__global__ void exp_kernel(
    const __half* x,
    const __half* max_vals,
    __half* exp_x,
    const int batch_size,
    const int seq_len
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= batch_size) return;

    __half max_val = max_vals[row];
    const __half* my_row = x + row * seq_len;
    __half* my_exp = exp_x + row * seq_len;

    for (int i = 0; i < seq_len; i++) {
        my_exp[i] = hexp(__hsub(my_row[i], max_val));
    }
}

// Kernel 3: Sum and normalize to get final softmax
__global__ void normalize_kernel(
    const __half* exp_x,
    __half* output,
    const int batch_size,
    const int seq_len
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= batch_size) return;

    const __half* my_exp = exp_x + row * seq_len;
    __half* my_out = output + row * seq_len;

    // Sum all values
    float sum = 0.0f;
    for (int i = 0; i < seq_len; i++) {
        sum += __half2float(my_exp[i]);
    }

    // Normalize
    __half inv_sum = __float2half(1.0f / sum);
    for (int i = 0; i < seq_len; i++) {
        my_out[i] = __hmul(my_exp[i], inv_sum);
    }
}
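
For completeness, here's roughly how the three launches fit together on the host side. The wrapper name and buffer names are mine and error checking is omitted; the point is the one-thread-per-row launch configuration.

// Host-side sketch: d_max holds one value per row, d_exp is a batch_size x seq_len scratch buffer.
void softmax_naive(const __half* d_x, __half* d_out, __half* d_max, __half* d_exp,
                   int batch_size, int seq_len) {
    // One thread per row, so we only need batch_size threads in total.
    const int threads = 256;
    const int blocks  = (batch_size + threads - 1) / threads;

    // Three separate launches; every intermediate result round-trips through global memory.
    find_max_kernel<<<blocks, threads>>>(d_x, d_max, batch_size, seq_len);
    exp_kernel<<<blocks, threads>>>(d_x, d_max, d_exp, batch_size, seq_len);
    normalize_kernel<<<blocks, threads>>>(d_exp, d_out, batch_size, seq_len);
}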

Memory traffic:

  • Kernel 1: Read N (find max)
  • Kernel 2: Read N + Write N (compute exp)
  • Kernel 3: Read N + Write N (normalize)
  • Total: 5N (3 reads + 2 writes)

Fused Kernel Implementation

The naive kernel uses three separate kernel launches. The fused kernel does everything in one kernel launch.

Each thread handles one row (same as naive, but all in one kernel):

__global__ void softmax_fused_kernel(
    const __half* x,
    __half* output,
    const int batch_size,
    const int seq_len
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= batch_size) return;

    const __half* x_row = x + row * seq_len;
    __half* out_row = output + row * seq_len;

    // Pass 1: Find max
    float max_val = -65504.0f;
    for (int i = 0; i < seq_len; i++) {
        max_val = fmaxf(max_val, __half2float(x_row[i]));
    }

    // Pass 2: Compute exp and sum
    float sum = 0.0f;
    for (int i = 0; i < seq_len; i++) {
        sum += expf(__half2float(x_row[i]) - max_val);
    }

    // Pass 3: Normalize and write output
    for (int i = 0; i < seq_len; i++) {
        float result = expf(__half2float(x_row[i]) - max_val) / sum;
        out_row[i] = __float2half(result);
    }
}

Memory traffic:

  • Read N (find max)
  • Read N (compute exp and sum)
  • Write N (normalize)
  • Total: 3N (2 reads + 1 write)

Online Kernel Implementation

The online kernel computes the max AND the sum in a single pass. We can do this because, whenever we find a new max, we rescale our running sum.

The Algorithm

This pseudocode shows the rescaling process.

For each element x_i:
    m_new = max(m, x_i)
    d = d * exp(m - m_new) + exp(x_i - m_new)
    m = m_new

Final: softmax(x_i) = exp(x_i - m) / d

Why does this work? When the max changes, all previous exp terms need adjustment:

exp(x_j - m_new) = exp(x_j - m_old) * exp(m_old - m_new)

So we multiply the entire running sum by exp(m_old - m_new).
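
It's worth convincing yourself that the rescaling really produces the same max and denominator as the two-pass approach. A quick host-side check (plain C++, with my own example numbers):

#include <cmath>
#include <cstdio>

int main() {
    const float x[] = {2.0f, 1.0f, 0.1f, 3.0f};
    const int n = 4;

    // Two-pass reference: find the max first, then sum the exponentials.
    float m_ref = x[0];
    for (int i = 1; i < n; i++) m_ref = std::fmax(m_ref, x[i]);
    float d_ref = 0.0f;
    for (int i = 0; i < n; i++) d_ref += std::exp(x[i] - m_ref);

    // Online version: running max m and running denominator d, updated in one loop.
    float m = -INFINITY, d = 0.0f;
    for (int i = 0; i < n; i++) {
        float m_new = std::fmax(m, x[i]);
        d = d * std::exp(m - m_new) + std::exp(x[i] - m_new);  // rescale old terms, add new one
        m = m_new;
    }

    printf("two-pass: m=%f d=%f\n", m_ref, d_ref);
    printf("online:   m=%f d=%f\n", m, d);   // identical up to float rounding
    return 0;
}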

The Code

__global__ void softmax_online_kernel(
    const __half* x,
    __half* output,
    const int batch_size,
    const int seq_len
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= batch_size) return;

    const __half* x_row = x + row * seq_len;
    __half* out_row = output + row * seq_len;

    float m = -65504.0f;  // Running max
    float d = 0.0f;       // Running sum of exp(x_i - m)

    for (int i = 0; i < seq_len; i++) {
        float xi = __half2float(x_row[i]);

        float m_new = fmaxf(m, xi);

        // Rescale running sum, then add new term
        d = d * expf(m - m_new) + expf(xi - m_new);

        m = m_new;
    }

    // Second pass: normalize and write output
    for (int i = 0; i < seq_len; i++) {
        float xi = __half2float(x_row[i]);
        float result = expf(xi - m) / d;
        out_row[i] = __float2half(result);
    }
}

Memory traffic:

  • Read N (single pass: max AND sum)
  • Write N (normalize)
  • Total: 2N

Benchmarks

I tested the kernels on a g4dn.xlarge (Tesla T4, 16GB). All three kernels were verified for correctness against PyTorch's F.softmax().
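
The full harness isn't worth pasting, but the core of the timing loop looks roughly like this with CUDA events (the warmup and iteration counts are arbitrary choices of mine, and only the online kernel is shown):

// Times one kernel and returns the average milliseconds per launch.
float time_online_kernel(const __half* d_x, __half* d_out, int batch_size, int seq_len) {
    const int threads = 256;
    const int blocks  = (batch_size + threads - 1) / threads;

    // Warm up so one-time launch overhead doesn't skew the measurement.
    for (int i = 0; i < 10; i++)
        softmax_online_kernel<<<blocks, threads>>>(d_x, d_out, batch_size, seq_len);
    cudaDeviceSynchronize();

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    const int iters = 100;
    cudaEventRecord(start);
    for (int i = 0; i < iters; i++)
        softmax_online_kernel<<<blocks, threads>>>(d_x, d_out, batch_size, seq_len);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);  // total time across all iterations
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms / iters;
}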

Results

The online kernel is roughly 2x faster than the naive one!

 Batch  SeqLen |        Naive        Fused       Online |  Speedup vs Naive
----------------------------------------------------------------------
    32     512 |     0.1324ms     0.0776ms     0.0682ms |  1.94x (Online)
    32    1024 |     0.2670ms     0.1623ms     0.1408ms |  1.90x (Online)
    32    2048 |     0.5335ms     0.3368ms     0.2845ms |  1.88x (Online)
    32    4096 |     1.0633ms     0.6715ms     0.5666ms |  1.88x (Online)
    32    8192 |     2.0770ms     1.3246ms     1.1147ms |  1.86x (Online)
    64     512 |     0.2007ms     0.1234ms     0.1024ms |  1.96x (Online)
    64    1024 |     0.3986ms     0.2468ms     0.2033ms |  1.96x (Online)
    64    2048 |     0.7899ms     0.4933ms     0.4007ms |  1.97x (Online)
    64    4096 |     1.5726ms     0.9834ms     0.8058ms |  1.95x (Online)
    64    8192 |     3.1366ms     1.9745ms     1.6136ms |  1.94x (Online)

Why PyTorch is Faster

Our kernels use only one thread per row, because I'm still learning and wanted to keep the code simple. PyTorch is faster because it spreads each row across an entire thread block and uses warp reductions, so many threads cooperate on the max and the sum instead of one thread grinding through the row alone.
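
To give a flavor of what a warp reduction is: the 32 threads of a warp can cooperatively find a max by shuffling values between registers. A rough sketch (not PyTorch's actual code):

// Reduces a per-thread partial max across the 32 threads of a warp using register shuffles.
__device__ float warp_reduce_max(float val) {
    // Each step combines values from lanes `offset` apart, halving the span every time.
    for (int offset = 16; offset > 0; offset /= 2)
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    return val;  // lane 0 ends up holding the max of all 32 inputs
}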

The point of this project was to demonstrate how reducing memory passes improves performance, and to hone my skills with CUDA. Reducing memory passes is the same principle that makes FlashAttention fast. Look out for my next project... reimplementing FlashAttention.