Skip to content

Commit

Permalink
webgpu: refactor atomicAdd code (tensorflow#7025)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao authored and Linchenn committed Jan 9, 2023
1 parent 794418e commit 0e46551
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 63 deletions.
21 changes: 6 additions & 15 deletions tfjs-backend-webgpu/src/bincount_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,19 @@
* =============================================================================
*/

import {atomicAddSnippet} from './shader_util';
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

const writeSnippet = `
fn bincount_write(index: i32, value: f32) {
var oldValue = atomicLoad(& (result[index]));
var exchanged = false;
for (; !exchanged;) {
let newValueF32 = bitcast<f32>(oldValue) + value;
let newValue = bitcast<i32>(newValueF32);
let res = atomicCompareExchangeWeak(
&(result[index]), oldValue, newValue);
oldValue = res.old_value;
exchanged = res.exchanged;
}
${atomicAddSnippet('&result[index]', 'value', 'float32')}
}
`;

const binaryWriteSnippet = `
fn bincount_write(index: i32, value: f32) {
result[index] = value;
atomicStore(&result[index], bitcast<i32>(value));
}
`;

Expand Down Expand Up @@ -83,9 +75,8 @@ export class BincountProgram implements WebGPUProgram {
let indexVal = i32(getX(index));
if (indexVal < uniforms.binCountSize) {
let value = ${
this.binaryOutput ?
1. :
(this.hasWeights ? 'f32(getW(index))' : '1.')};
this.binaryOutput ? 1. :
(this.hasWeights ? 'getW(index)' : '1.')};
bincount_write(indexVal, value);
}
}` :
Expand All @@ -96,7 +87,7 @@ export class BincountProgram implements WebGPUProgram {
let value = ${
this.binaryOutput ?
1. :
(this.hasWeights ? 'f32(getW(coord[0], coord[1]))' : '1.')};
(this.hasWeights ? 'getW(coord[0], coord[1])' : '1.')};
bincount_write(coord.x * uniforms.binCountSize + indexVal, value);
}
}`}
Expand Down
28 changes: 7 additions & 21 deletions tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';

import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source, matMulReadFnSource} from './matmul_packed_webgpu';
import {atomicAddSnippet} from './shader_util';
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

Expand Down Expand Up @@ -74,26 +75,6 @@ export class MatMulSplitKProgram implements WebGPUProgram {
}

getUserCode(): string {
// atomicAdd only supports uint/int type. For float, we use
// atomicCompareExchangeWeak to simulate.
const atomicAddSnippet = (component: number) => {
return `
for (var i = 0; i < ${component}; i = i + 1)
{
var oldValue = atomicLoad(&(result[flatIndex + i]));
var exchanged = false;
for (; !exchanged;) {
let newValueF32 = bitcast<f32>(oldValue) + ${
component > 1 ? 'value[i]' : 'value'};
let newValue = bitcast<i32>(newValueF32);
let res = atomicCompareExchangeWeak(&(result[flatIndex + i]), oldValue, newValue);
oldValue = res.old_value;
exchanged = res.exchanged;
}
}
`;
};

const component = this.isVec4 ? 4 : 1;
const userCode = `
${
Expand All @@ -107,7 +88,12 @@ export class MatMulSplitKProgram implements WebGPUProgram {
let flatIndex = getOutputIndexFromCoords(coords);
// The problem is that we should initialize output to zero before using.
// Otherwise, the original value will be added to the result.
${atomicAddSnippet(component)}
for (var i = 0; i < ${component}; i = i + 1) {
${
atomicAddSnippet(
'&result[flatIndex + i]', `${component > 1 ? 'value[i]' : 'value'}`,
'float32')}
}
}
}
${
Expand Down
33 changes: 7 additions & 26 deletions tfjs-backend-webgpu/src/scatter_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

import {DataType} from '@tensorflow/tfjs-core';
import {atomicAddSnippet} from './shader_util';
import {getCoordsDataType, getMainHeaderString as main, mapToWgslTypes, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

Expand Down Expand Up @@ -95,33 +96,8 @@ export class ScatterProgram implements WebGPUProgram {
Array.from({length: this.updatesRank}, (_, idx) => `coords[${idx}]`);
const updatesSnippet = `getUpdates(${updatesString.join(', ')})`;

const atomicRMW = (ptr: string, val: string) => {
let atomicAddSnippet = `atomicAdd(${ptr}, bitcast<i32>(${val}))`;
if (this.type === 'float32') {
atomicAddSnippet = `
{
var oldBits = 0;
var newBits = bitcast<i32>(${val});
loop {
let info = atomicCompareExchangeWeak(${ptr}, oldBits, newBits);
if (info.exchanged) {
break;
}
oldBits = info.old_value;
let oldValue = bitcast<f32>(oldBits);
let newValue = oldValue + (${val});
newBits = bitcast<i32>(newValue);
}
}
`;
}
const atomicStoreSnippet = `atomicStore(${ptr}, bitcast<i32>(${val}));`;
return this.sumDupeIndices ? atomicAddSnippet : atomicStoreSnippet;
};

const userCode = `
${getUpdatesCoordsFromFlatIndex}
${main('index')} {
if (index < uniforms.updatesSize) {
let coords = getUpdatesCoordsFromFlatIndex(index);
Expand All @@ -134,7 +110,12 @@ export class ScatterProgram implements WebGPUProgram {
${mapToWgslTypes(this.type, false)}(${updatesSnippet});
let flatIndex = getOutputIndexFromCoords(${outCoordsString});
${atomicRMW('&result[flatIndex]', 'updateValue')};
${
this.sumDupeIndices ?
atomicAddSnippet(
'&result[flatIndex]', 'updateValue',
this.type as 'float32' | 'int32') :
`atomicStore(&result[flatIndex], bitcast<i32>(updateValue));`}
}
}`;
return userCode;
Expand Down
25 changes: 24 additions & 1 deletion tfjs-backend-webgpu/src/shader_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

// Generates GLSL that computes strides.
// Generates WGSL that computes strides.
export function symbolicallyComputeStrides(
indicesArr: number[], variableName: string): string[] {
if (Math.max(...indicesArr) > 3) {
Expand All @@ -32,3 +32,26 @@ export function symbolicallyComputeStrides(

return strides;
}

/**
 * Builds a WGSL statement (or statement block) that atomically adds `v` to the
 * storage location addressed by `ptr`.
 *
 * @param ptr WGSL pointer expression that is spliced verbatim into the atomic
 *     builtin calls, e.g. `&result[index]`.
 * @param v WGSL expression for the value to add.
 * @param type Element type of the addition. `'int32'` emits a single
 *     `atomicAdd` call; anything else takes the float path.
 * @returns The generated WGSL source text.
 */
export const atomicAddSnippet =
    (ptr: string, v: string, type: 'int32'|'float32'): string => {
      // Integer case: the builtin atomicAdd is emitted directly.
      if (type === 'int32') {
        return 'atomicAdd(' + ptr + ', bitcast<i32>(' + v + '));';
      }

      // atomicAdd only supports uint/int type, so the float addition is
      // simulated: reinterpret the bits as f32, add, and retry
      // atomicCompareExchangeWeak until the exchange succeeds. The first
      // attempt guesses that the stored bits are 0.
      const wgslLines = [
        '',
        '{',
        'var oldValue = 0;',
        'loop {',
        `let newValueF32 = bitcast<f32>(oldValue) + (${v});`,
        'let newValue = bitcast<i32>(newValueF32);',
        `let res = atomicCompareExchangeWeak(${ptr}, oldValue, newValue);`,
        'if res.exchanged {',
        'break;',
        '}',
        'oldValue = res.old_value;',
        '}',
        '}',
      ];
      return wgslLines.join('\n');
    };

0 comments on commit 0e46551

Please sign in to comment.