Compute shader wave intrinsics tricks

Angry Tomato!
7 min read · Apr 3, 2024

While diving deep into optimizing compute shaders, I ran into many problems that could be solved efficiently with wave intrinsics. These techniques not only solved the problems but also noticeably improved performance by reducing memory traffic. The main idea is to exchange data between the lanes of a wave directly through registers, keeping values that are the same for the whole wave in SGPRs (Scalar General Purpose Registers) instead of VGPRs (Vector General Purpose Registers), and replacing per-lane memory operations such as atomics with a single operation per wave.

If terms like GPU waves, SGPRs, or VGPRs are unfamiliar to you, I highly recommend reading this primer first: link to primer on GPU scalarization. It will provide you with the necessary background to follow along with the examples discussed here.

We will start with simple examples and then move on to more complex ones. Let's start.

1. Branch optimization

In the realm of GPU programming, when different lanes within the same wave diverge in their branching decisions, the GPU is compelled to execute both paths.

Let’s illustrate this with an example:

if(canBeOptimized)
{
// Optimal way
}
else
{
// Non optimal way
}

In this scenario, our intention is to take the optimal path whenever possible. However, due to the nature of GPU execution, in the average case, we might experience degraded performance. This happens because if even one lane within a wave needs to take the non-optimal path, the entire wave must execute both the optimal and non-optimal paths.

To mitigate this issue, we can use the wave intrinsic WaveActiveAllTrue, which returns true only if the condition holds for all active lanes in the wave. Because the result is the same for every lane, the branch becomes wave-uniform and the wave executes exactly one of the two paths. Here's how we can refactor our code so that the optimal path is taken only when the entire wave can take it:

bool optimalWay = WaveActiveAllTrue(canBeOptimized);
if(optimalWay)
{
// Optimal way
}
else
{
// Non optimal way
}
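
To see why this helps, here is a small CPU model in C++ of a single wave (not HLSL; WaveActiveAllTrueEmu, BodiesExecuted and RefactoredBranch are hypothetical helpers). It counts how many branch bodies a wave has to run when it branches on a per-lane condition, and compares that with branching on the wave-wide WaveActiveAllTrue result. It assumes every lane of the wave is active.

```cpp
#include <cassert>
#include <vector>

// Hypothetical CPU model of one wave: element i is lane i's value of a
// per-lane (VGPR) variable. All lanes are assumed to be active.
using LaneBools = std::vector<bool>;

// Emulates WaveActiveAllTrue: one scalar answer shared by the whole wave.
bool WaveActiveAllTrueEmu(const LaneBools& cond)
{
    for (bool c : cond)
    {
        if (!c)
        {
            return false;
        }
    }
    return true;
}

// How many of the two branch bodies the wave executes when it branches on a
// per-lane condition: 2 if the lanes disagree, 1 if they all agree.
int BodiesExecuted(const LaneBools& perLaneBranch)
{
    bool anyTrue = false;
    bool anyFalse = false;
    for (bool c : perLaneBranch)
    {
        anyTrue = anyTrue || c;
        anyFalse = anyFalse || !c;
    }
    return (anyTrue ? 1 : 0) + (anyFalse ? 1 : 0);
}

// The per-lane branch value after the refactor: every lane sees the same result.
LaneBools RefactoredBranch(const LaneBools& canBeOptimized)
{
    return LaneBools(canBeOptimized.size(), WaveActiveAllTrueEmu(canBeOptimized));
}
```

With one non-optimal lane, the divergent version runs both bodies, while the refactored version runs only the non-optimal one. That is the trade-off: a single lane that cannot take the optimal path pulls the whole wave onto the slower path, but the wave never pays for both.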

2. Calculate on one lane, read on all

In certain scenarios, we may need to perform an expensive operation only on a single lane within a wave. This is where wave intrinsics like WaveIsFirstLane and WaveReadLaneFirst come in handy. Let's explore an example where we reduce the number of InterlockedAdd calls from one per lane to one per wave.

Consider the following example:

// Whether this lane should be counted in the addition
bool shouldAdd = ....;

// Count the number of lanes that need to add to the buffer
uint addCount = WaveActiveCountBits(shouldAdd);

// InterlockedAdd returns the value stored BEFORE the addition
uint valueBeforeAddition = 0;
if (WaveIsFirstLane())
{
// One atomic for the whole wave instead of one per lane
SomeDataBuffer.InterlockedAdd(bufferOffset, addCount, valueBeforeAddition);
}
// Broadcast the first lane's result to all lanes
valueBeforeAddition = WaveReadLaneFirst(valueBeforeAddition);
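
The same idea as a CPU model in C++ (a sketch, not the shader itself): Counter stands in for the buffer, each AtomicAdd call stands for one InterlockedAdd instruction, and AddOncePerWave is a hypothetical helper playing the roles of WaveActiveCountBits, WaveIsFirstLane and WaveReadLaneFirst for one fully active wave.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of a counter that lives in a GPU buffer.
struct Counter
{
    uint32_t value = 0;
    int atomicCalls = 0;

    // Like InterlockedAdd: returns the value stored before the addition.
    uint32_t AtomicAdd(uint32_t amount)
    {
        ++atomicCalls;
        const uint32_t original = value;
        value += amount;
        return original;
    }
};

// One atomic per wave: count the lanes that want to add (WaveActiveCountBits),
// let the first lane do the atomic (WaveIsFirstLane), then broadcast the
// returned value to every lane (WaveReadLaneFirst).
std::vector<uint32_t> AddOncePerWave(Counter& counter, const std::vector<bool>& shouldAdd)
{
    uint32_t addCount = 0;
    for (bool b : shouldAdd)
    {
        addCount += b ? 1u : 0u;
    }

    // Executed by lane 0 only.
    const uint32_t firstLaneResult = counter.AtomicAdd(addCount);

    // Every lane receives lane 0's value.
    return std::vector<uint32_t>(shouldAdd.size(), firstLaneResult);
}
```

Three of the four lanes add, so the counter goes from 10 to 13 with a single atomic, and every lane sees the old value 10.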

3. Serialization of Writing Data

In scenarios where we need to serialize writing data to a buffer, so that multiple lanes never write to the same location, we can use a specific pattern. It relies on two more wave intrinsics: WavePrefixCountBits and WavePrefixSum.

These intrinsics are closely related: for the n-th lane, WavePrefixSum returns the sum of the values passed by the active lanes with a lower index (lanes 0..n-1, so lane 0 gets 0), while WavePrefixCountBits(condition) is equivalent to WavePrefixSum(condition ? 1 : 0).
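
To make the exclusive nature of these prefix operations concrete, here is a small C++ model of one fully active wave run on the CPU. WavePrefixSumEmu and WavePrefixCountBitsEmu are hypothetical stand-ins for the HLSL intrinsics, not a real API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Emulates WavePrefixSum for a wave with all lanes active: lane n receives the
// sum of lanes 0..n-1 (an exclusive prefix sum, so lane 0 receives 0).
std::vector<uint32_t> WavePrefixSumEmu(const std::vector<uint32_t>& laneValues)
{
    std::vector<uint32_t> result(laneValues.size());
    uint32_t running = 0;
    for (std::size_t lane = 0; lane < laneValues.size(); ++lane)
    {
        result[lane] = running;
        running += laneValues[lane];
    }
    return result;
}

// Emulates WavePrefixCountBits(cond) as WavePrefixSum(cond ? 1 : 0).
std::vector<uint32_t> WavePrefixCountBitsEmu(const std::vector<bool>& cond)
{
    std::vector<uint32_t> asInts(cond.size());
    for (std::size_t lane = 0; lane < cond.size(); ++lane)
    {
        asInts[lane] = cond[lane] ? 1u : 0u;
    }
    return WavePrefixSumEmu(asInts);
}
```

With every lane passing true, WavePrefixCountBits simply returns the lane's position (0, 1, 2, ...), which is why the one-item case below can use it directly as a write offset.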

In the simple case where each lane writes only 1 item:

uint baseWriteOffset = ....;
uint writeOffset = baseWriteOffset + WavePrefixCountBits(true);
SomeBuffer[writeOffset] = someValue;

The more complex case, where each lane has a different number of items to write:

// Number of items to write
uint numItems = ....;

// Offset for this lane relative to the first lane
uint localOffset = WavePrefixSum(numItems);

// Offset for the whole wave
uint baseWriteOffset = 0;

// Similar to trick 2, but on the last active lane instead of the first:
// its prefix sum plus its own count is the total for the wave.
// Using the last ACTIVE lane (not WaveGetLaneCount() - 1) keeps this
// correct when the wave is only partially filled.
const uint lastLane = WaveActiveMax(WaveGetLaneIndex());
if (WaveGetLaneIndex() == lastLane)
{
uint totalItemCount = localOffset + numItems;
BufferItemCounter.InterlockedAdd(0, totalItemCount, baseWriteOffset);
}
baseWriteOffset = WaveReadLaneAt(baseWriteOffset, lastLane);

// Write the data to the buffer
for (uint i = 0; i < numItems; i++)
{
SomeBuffer[baseWriteOffset + localOffset + i] = someValue;
}
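
Here is a CPU sketch in C++ of this allocation for one wave whose lanes may be partially inactive. AllocateForWave and Allocation are hypothetical names; the counter argument plays BufferItemCounter, and inactive lanes are skipped the way wave intrinsics skip them. It takes the wave total from the last active lane, which also covers a partially filled wave where lane WaveGetLaneCount() - 1 may not be active.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Allocation
{
    std::vector<uint32_t> laneWriteOffset; // baseWriteOffset + localOffset per lane
    uint32_t counterAfter;                 // BufferItemCounter after the wave's atomic
};

Allocation AllocateForWave(uint32_t counter,
                           const std::vector<uint32_t>& numItems,
                           const std::vector<bool>& active)
{
    const std::size_t laneCount = numItems.size();
    std::vector<uint32_t> localOffset(laneCount, 0);
    uint32_t running = 0;
    bool anyActive = false;
    std::size_t lastLane = 0;
    for (std::size_t lane = 0; lane < laneCount; ++lane)
    {
        if (!active[lane])
        {
            continue; // inactive lanes do not take part in wave intrinsics
        }
        localOffset[lane] = running; // WavePrefixSum(numItems)
        running += numItems[lane];
        lastLane = lane;             // WaveActiveMax(WaveGetLaneIndex())
        anyActive = true;
    }

    Allocation out{std::vector<uint32_t>(laneCount, 0), counter};
    if (!anyActive)
    {
        return out;
    }

    // The last active lane knows the wave total: its prefix plus its own count.
    const uint32_t totalItemCount = localOffset[lastLane] + numItems[lastLane];
    const uint32_t baseWriteOffset = counter; // InterlockedAdd returns the old value
    out.counterAfter = counter + totalItemCount;

    // WaveReadLaneAt(baseWriteOffset, lastLane), then add each lane's offset.
    for (std::size_t lane = 0; lane < laneCount; ++lane)
    {
        if (active[lane])
        {
            out.laneWriteOffset[lane] = baseWriteOffset + localOffset[lane];
        }
    }
    return out;
}
```

In a wave where lane 3 is inactive and lanes 0 to 2 want 2, 0 and 3 items, the counter moves from 100 to 105 and the lanes write to [100, 102), nothing, and [102, 105), so the ranges never overlap.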

4. Scalarization

Scalarization is a technique used to reduce VGPR (Vector General Purpose Register) usage, especially in code with heavy branching. Instead of letting every lane execute its own branch at the same time, we choose which lanes run in each iteration so that only one branch is active at a time. The value that selects the branch is then uniform across the wave, so the compiler can keep it, and data derived from it, in SGPRs.

Let’s consider an example with lights where scalarization can be beneficial, particularly when calculating multiple light types within the same shader. The following code might seem complex initially, but understanding it will enhance your grasp of wave concepts:

uint lightType = LightData[threadID.x].LightType;

// Execute until all lanes have processed their light type.
// Lanes leave through the break below, so the loop ends once every lane is done.
[loop]
while (true)
{
// Using read lane first so it is stored in SGPR
uint currentLightType = WaveReadLaneFirst(lightType);

// Only the lanes that share the same lightType as the first lane will go in
if (lightType == currentLightType)
{
// Branch on the scalar copy so the whole wave takes the same path
if (currentLightType == LIGHT_TYPE_SPOT)
CalculateSpot(...);
else if (currentLightType == LIGHT_TYPE_DIR)
CalculateDir(...);
...

// After this line all of the lanes with currentLightType will be inactive,
// so WaveReadLaneFirst in the next iteration will give a different value,
// since WaveReadLaneFirst reads the first ACTIVE lane
break;
}
}
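
The loop's behavior can be traced with a CPU model in C++. ScalarizeLightTypes is a hypothetical helper: each iteration reads the type of the first active lane (the WaveReadLaneFirst step), lets every lane with that type do its work, and retires those lanes (the break). It returns the sequence of wave-uniform types processed, so the number of iterations equals the number of distinct light types in the wave.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<uint32_t> ScalarizeLightTypes(const std::vector<uint32_t>& lightType)
{
    std::vector<bool> active(lightType.size(), true);
    std::vector<uint32_t> processedTypes;
    for (;;)
    {
        // Find the first active lane; stop once every lane has left the loop.
        std::size_t first = 0;
        while (first < active.size() && !active[first])
        {
            ++first;
        }
        if (first == active.size())
        {
            break;
        }

        // Scalar (SGPR) value shared by the whole wave in this iteration.
        const uint32_t currentLightType = lightType[first];
        processedTypes.push_back(currentLightType);

        // Matching lanes do their work and execute `break`, becoming inactive.
        for (std::size_t lane = 0; lane < lightType.size(); ++lane)
        {
            if (active[lane] && lightType[lane] == currentLightType)
            {
                active[lane] = false;
            }
        }
    }
    return processedTypes;
}
```

For lane types {2, 0, 2, 1, 0, 2}, the wave runs three iterations in the order 2, 0, 1; a wave where every lane has the same type finishes in one.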

We can combine this pattern with pattern 1, so that if all lanes in the wave have the same light type, we take the faster route without the loop:

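uint lightType = LightData[threadID.x].LightType;
// WaveActiveBallot returns a uint4 bitmask, so compare all four components
// (WaveActiveAllEqual(lightType) is an equivalent one-call check)
uint4 laneMask = WaveActiveBallot(lightType == WaveReadLaneFirst(lightType));
if (all(laneMask == WaveActiveBallot(true))) // Fast way without scalarization
{
// lightType is the same for the whole wave here
if (lightType == LIGHT_TYPE_SPOT)
.....
}
else // Slower way with scalarization but fewer VGPRs
{
// Scalarization loop from above
.....
}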
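
The ballot comparison can be checked on the CPU as well. This C++ sketch models WaveActiveBallot with a 32-bit mask (the real intrinsic returns a uint4 covering up to 128 lanes); WaveActiveBallotEmu and LightTypeIsWaveUniform are hypothetical names. Comparing against the ballot of true, rather than against an all-ones mask, is what keeps inactive lanes from breaking the check.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Bit i is set when lane i is active and its condition is true (up to 32 lanes).
uint32_t WaveActiveBallotEmu(const std::vector<bool>& cond, const std::vector<bool>& active)
{
    uint32_t mask = 0;
    for (std::size_t lane = 0; lane < cond.size() && lane < 32; ++lane)
    {
        if (active[lane] && cond[lane])
        {
            mask |= 1u << lane;
        }
    }
    return mask;
}

// True when every active lane has the same light type as the first active lane.
bool LightTypeIsWaveUniform(const std::vector<uint32_t>& lightType, const std::vector<bool>& active)
{
    std::size_t first = 0;
    while (first < active.size() && !active[first])
    {
        ++first;
    }
    if (first == active.size())
    {
        return true; // no active lanes
    }

    std::vector<bool> sameAsFirst(lightType.size());
    for (std::size_t lane = 0; lane < lightType.size(); ++lane)
    {
        sameAsFirst[lane] = lightType[lane] == lightType[first];
    }
    const std::vector<bool> allTrue(lightType.size(), true);
    return WaveActiveBallotEmu(sameAsFirst, active) == WaveActiveBallotEmu(allTrue, active);
}
```

A lane with a different type forces the slow path only while that lane is active.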

5. Multiple wave parallelization

All of the techniques so far, and the ones that follow, are designed to work when a thread group contains multiple waves, so the thread group size is independent of the wave size. We achieve this by splitting a thread group's work between its waves and using groupshared memory to pass data between the waves in the group.

This is easier to explain with an example:

groupshared float g_GroupCache[CacheSize];

void FetchCache()
{
// There goes code for calculating cache
// This can be divided by waves or by lanes
......

// At the end we must call this to stop faster waves from
// starting work before the whole cache is written
GroupMemoryBarrierWithGroupSync();
}

float DoWorkForLane(SomeInputStruct inputData)
{
const uint laneIndex = WaveGetLaneIndex();
// Do work for lane
....
}

void ProcessInput(uint inputIndex)
{
// WaveReadLaneFirst accepts scalar/vector types, not structs, so make the
// index wave-uniform (SGPR); the load then fetches one value for the whole wave
const uint uniformInputIndex = WaveReadLaneFirst(inputIndex);
const SomeInputStruct input = InputData[uniformInputIndex];

const float result = DoWorkForLane(input);

const float totalResult = WaveActiveSum(result);
if (WaveIsFirstLane())
{
// InterlockedAdd only accepts int/uint, so the float is converted to
// fixed point first (FIXED_POINT_SCALE is a hypothetical constant)
InterlockedAdd(SomeOutputBuffer[0], (int)(totalResult * FIXED_POINT_SCALE));
}
}

// numthreads 128 => for wave size 32 we will have 4 waves per group,
// or 2 waves per group if it is 64
[numthreads(128, 1, 1)]
void CS(uint3 threadID : SV_DispatchThreadID)
{
FetchCache();

// Every wave will process one input
const uint inputIndex = threadID.x / WaveGetLaneCount();

// Process input index
// It is important to note that within one thread group this value is not
// the same for all threads, but it is the same for all lanes in a wave
// (assuming a wave covers consecutive thread IDs).
// With wave size 32 we get 4 different inputIndex values in the group,
// processed in parallel by 4 different waves (2 with wave size 64)
ProcessInput(inputIndex);
}
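
A quick CPU check in C++ of the index mapping used above: InputsInGroup is a hypothetical helper that lists which input indices threadID.x / WaveGetLaneCount() produces inside one thread group. It assumes a 1D dispatch, that each wave covers consecutive thread IDs, and that the group size is a multiple of the wave size.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// Distinct input indices processed by one thread group.
std::set<uint32_t> InputsInGroup(uint32_t groupIndex, uint32_t groupSize, uint32_t waveSize)
{
    std::set<uint32_t> inputs;
    for (uint32_t local = 0; local < groupSize; ++local)
    {
        // SV_DispatchThreadID.x for a 1D dispatch
        const uint32_t dispatchThreadID = groupIndex * groupSize + local;
        inputs.insert(dispatchThreadID / waveSize);
    }
    return inputs;
}
```

For numthreads(128, 1, 1) this gives 4 inputs per group with 32-lane waves and 2 with 64-lane waves, and group 1 starts at input 2 when the wave size is 64, so consecutive groups never process the same input.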

6. Indirect dispatch thread group count calculation

Sometimes a compute shader has to prepare the arguments for the next, indirectly dispatched compute shader, because the amount of data is not known ahead of time.

The challenge is increasing the thread group count correctly: it has to grow by one for every TARGET_NUM_THREADS items, where TARGET_NUM_THREADS is the thread group size of the target shader.

Here is an example of how this can be done:

// We are preparing for a shader that will have numthreads(256, 1, 1)
// Log2(256) = 8
// This should be passed from CPU
#define TARGET_NUM_THREADS_LOG2 8

[numthreads(128, 1, 1)]
void CS(uint3 threadID : SV_DispatchThreadID)
{
const uint numberOfItems = ....; // Number of items this lane wants to process
const uint totalItems = WaveActiveSum(numberOfItems);

// Waves with nothing to add must not touch the arguments: with totalItems == 0,
// several waves could all see baseIndex == 0 and increment Y and Z more than once
if (WaveIsFirstLane() && totalItems > 0)
{
uint baseIndex;
InterlockedAdd(SomeCountBuffer[0], totalItems, baseIndex);

// Y and Z must be 1; if we initialize the buffer on CPU as (0, 1, 1) we can delete this
if (baseIndex == 0)
{
InterlockedAdd(IndirectArguments[0].ThreadGroupCountY, 1);
InterlockedAdd(IndirectArguments[0].ThreadGroupCountZ, 1);
}

// Groups needed before and after this wave's addition, rounded up:
// ceil(count / 2^TARGET_NUM_THREADS_LOG2). Summed over all waves the
// differences telescope to ceil(total / 2^TARGET_NUM_THREADS_LOG2)
const uint roundUp = (1u << TARGET_NUM_THREADS_LOG2) - 1;
const uint minGroupRequired = (baseIndex + roundUp) >> TARGET_NUM_THREADS_LOG2;
const uint maxGroupRequired = (baseIndex + totalItems + roundUp) >> TARGET_NUM_THREADS_LOG2;
if (maxGroupRequired > minGroupRequired)
{
InterlockedAdd(IndirectArguments[0].ThreadGroupCountX, maxGroupRequired - minGroupRequired);
}
}
}
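
Why the per-wave differences add up to the right group count can be checked with a CPU sketch in C++. Each entry of waveItemTotals is one wave's WaveActiveSum(numberOfItems), applied in the order the atomics happen to serialize them; GroupsFor and AccumulateGroupCountX are hypothetical helpers. The sketch uses rounded-up shifts, so the sum of all the differences telescopes to ceil(N / 256) for N total items, and it skips waves that add nothing.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr uint32_t kTargetLog2 = 8; // TARGET_NUM_THREADS_LOG2 (256 threads)

// ceil(items / 256) using only an add and a shift.
uint32_t GroupsFor(uint32_t items)
{
    const uint32_t roundUp = (1u << kTargetLog2) - 1u;
    return (items + roundUp) >> kTargetLog2;
}

uint32_t AccumulateGroupCountX(const std::vector<uint32_t>& waveItemTotals)
{
    uint32_t itemCounter = 0; // plays SomeCountBuffer[0]
    uint32_t groupCountX = 0; // plays IndirectArguments[0].ThreadGroupCountX
    for (uint32_t totalItems : waveItemTotals)
    {
        if (totalItems == 0)
        {
            continue; // the wave has nothing to add
        }
        const uint32_t baseIndex = itemCounter; // InterlockedAdd's original value
        itemCounter += totalItems;
        const uint32_t minGroupRequired = GroupsFor(baseIndex);
        const uint32_t maxGroupRequired = GroupsFor(baseIndex + totalItems);
        groupCountX += maxGroupRequired - minGroupRequired;
    }
    return groupCountX;
}
```

Exactly 256 items across three waves produce one group, 257 items produce two, and no items produce none, regardless of how the items are split between waves.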

7. Dividing the work between lanes

When the amount of work per lane is very uneven, we can keep every lane of the wave busy by redistributing the work evenly among the lanes.

I will show you one way to do this. Keep in mind that in some situations this approach can cause an SALU (scalar ALU) bottleneck, since the outer loop runs once for every lane of the wave on the scalar unit.

[numthreads(128, 1, 1)]
void CS(uint3 groupThreadID : SV_GroupThreadID)
{
// We are assuming that inputs of one lane are one after another
const uint firstDataIndex = ....;

// Number of inputs we want to add (this varies a lot between lanes)
const uint dataCount = ....;

const uint laneIndex = WaveGetLaneIndex();
const uint laneCount = WaveGetLaneCount();

// See how many inputs we will insert in this wave
const uint totalInputs = WaveActiveSum(dataCount);

// First we add the total amount of inputs to the DataCountBuffer to allocate space
uint writeOffset = 0;
if (WaveIsFirstLane())
{
InterlockedAdd(DataCountBuffer[0], totalInputs, writeOffset);
}
writeOffset = WaveReadLaneFirst(writeOffset);

// Read the data of each lane in turn and process it with the whole wave
// (assumes all lanes of the wave are active, since WaveReadLaneAt reads every lane)
uint currentOffset = 0;
for (uint currentLane = 0; currentLane < laneCount; currentLane++)
{
// This is the data of the currentLane-th lane, read into SGPRs
const uint laneFirstDataIndex = WaveReadLaneAt(firstDataIndex, currentLane);
const uint laneFirstDataOffset = writeOffset + currentOffset;
const uint laneDataCount = WaveReadLaneAt(dataCount, currentLane);

// Divide that lane's items among all lanes of the wave:
// lane k copies items k, k + laneCount, k + 2 * laneCount, ...
for (uint i = 0; i < laneDataCount; i += laneCount)
{
const uint item = i + laneIndex;
if (item < laneDataCount)
{
Data[laneFirstDataOffset + item] = OtherDataBuffer[laneFirstDataIndex + item];
}
}
currentOffset += laneDataCount;
}
}
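
To see how the redistribution keeps lanes busy, here is a CPU model in C++ of one fully active wave. Redistribute and RedistributeResult are hypothetical names: for each owner lane in turn, all lanes copy a strided share of that lane's items, and the model counts how many lane slots did useful work.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct RedistributeResult
{
    std::vector<uint32_t> data;  // the Data buffer after the copy
    uint32_t busyLaneSlots = 0;  // lane-iterations that copied an item
    uint32_t totalLaneSlots = 0; // lane-iterations executed
};

// Lane L owns source items [firstDataIndex[L], firstDataIndex[L] + dataCount[L]).
RedistributeResult Redistribute(const std::vector<uint32_t>& source,
                                const std::vector<uint32_t>& firstDataIndex,
                                const std::vector<uint32_t>& dataCount,
                                uint32_t writeOffset)
{
    const uint32_t laneCount = static_cast<uint32_t>(dataCount.size());
    RedistributeResult r;
    uint32_t total = 0;
    for (uint32_t c : dataCount)
    {
        total += c; // WaveActiveSum(dataCount)
    }
    r.data.assign(writeOffset + total, 0);

    uint32_t currentOffset = 0;
    for (uint32_t currentLane = 0; currentLane < laneCount; ++currentLane)
    {
        // WaveReadLaneAt: scalar copies of the owner lane's values
        const uint32_t ownerFirst = firstDataIndex[currentLane];
        const uint32_t ownerCount = dataCount[currentLane];
        const uint32_t ownerOffset = writeOffset + currentOffset;
        for (uint32_t i = 0; i < ownerCount; i += laneCount)
        {
            // On a GPU these lanes run in parallel.
            for (uint32_t lane = 0; lane < laneCount; ++lane)
            {
                ++r.totalLaneSlots;
                const uint32_t item = i + lane;
                if (item < ownerCount)
                {
                    r.data[ownerOffset + item] = source[ownerFirst + item];
                    ++r.busyLaneSlots;
                }
            }
        }
        currentOffset += ownerCount;
    }
    return r;
}
```

With 4 lanes owning 2, 0, 7 and 1 items, the 10 items are copied in 16 lane slots. A plain per-lane loop would run max(dataCount) = 7 iterations, or 28 lane slots for the same 10 items; the price is the outer per-owner loop, which is the SALU cost mentioned above.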
