Eliminate explicit Concat operations in Attention #20556

Merged · 13 commits · May 24, 2024
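At a high level, this PR removes the standalone Concat dispatches that previously built present_key and present_value from the KV cache. Instead, the two attention matmul shaders read from past_key/past_value or from the freshly projected key/value depending on the column index, and write the merged row to the present output as they go. A minimal WGSL sketch of the read path for K (names mirror the shader in the diff below, but this is a simplified illustration, not the exact generated code):

// Simplified sketch of the virtual-concat load in the QK^T shader.
// col indexes the (virtual) concatenation [past_key ; key] along the sequence axis.
let col = n + local_id.y;
var k_val: f32;
if (col < uniforms.past_sequence_length) {
  // Cached keys live in past_key.
  k_val = past_key[pastKeyOffset + col * uniforms.K + w + local_id.x];
} else {
  // New keys live in key, shifted back by past_sequence_length.
  k_val = key[kOffset + (col - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
}
// The merged row is also written to the present_key output, so the KV cache is
// updated as a side effect of the matmul -- no separate Concat pass is needed.
present_key[presentKeyOffset + col * uniforms.K + w + local_id.x] = k_val;

The value path in createVxAttentionScoreProgramInfo applies the same pattern to past_value/v, so both Concat programs (and their intermediate buffers) drop out of the graph.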
204 changes: 138 additions & 66 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -6,7 +6,6 @@ import {TensorView} from '../../tensor-view';
import {ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common';
-import {createConcatProgramInfo} from './concat';

export const enum AttentionQkvFormat {
unknown, // enum value not set, or depends on qkv projection implementation details
@@ -336,10 +335,15 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
};

const createAttentionProbsProgramInfo =
-(_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined,
- parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => {
+(context: ComputeContext, q: TensorView, key: TensorView, pastKey: TensorView|undefined,
+ relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs,
+ pastSequenceLength: number) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1;
const presentKeyShape = presentKey ?
[parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] :
undefined;

// TODO: handle mask

@@ -355,34 +359,51 @@ const createAttentionProbsProgramInfo =
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
{type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
-{type: DataType.float, data: alpha}
+{type: DataType.float, data: alpha}, {type: DataType.uint32, data: pastSequenceLength},
+{type: DataType.uint32, data: parameters.kvSequenceLength}
];

-const inputDependencies: ProgramInputTensorInfoDependency[] =
-    relativePositionBias ? ['type', 'type', 'type'] : ['type', 'type'];

+const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
+if (pastKey) {
+  inputDependencies.push('type');
+}
+if (relativePositionBias) {
+  inputDependencies.push('type');
+}
const outputs = [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}];
if (presentKey) {
outputs.push({dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default});
}
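// Output slot 1 of this program is present_key: the merged past+new keys are
// emitted by the same dispatch that computes the attention probs.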
const getShaderSource = (shaderHelper: ShaderHelper) => {
const qInput = inputVariable('q', q.dataType, q.dims, components);
const kInput = inputVariable('key', key.dataType, key.dims, components);
const inputVars = [qInput, kInput];
if (pastKey) {
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
inputVars.push(pastKeyInput);
}
if (relativePositionBias) {
inputVars.push(
inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims));
}
const output = outputVariable('output', q.dataType, probsShape);
// const dataType = tensorTypeToWsglStorageType(q.dataType);
const outputVars = [output];
if (presentKey) {
outputVars.push(outputVariable('present_key', q.dataType, presentKeyShape!, components));
}
const f32Type = tensorTypeToWsglValueType(DataType.float, components);

const uniforms: UniformsArrayType = [
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
-{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType}
+{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType},
+{name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'}
];
return `
const TILE_SIZE = ${TILE_SIZE}u;
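// One 12x12 workgroup computes a TILE_SIZE x TILE_SIZE block of Q*K^T, staging
// tiles of Q and K in the workgroup arrays below.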

var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
-${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
+${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart([
TILE_SIZE, TILE_SIZE, 1
])}
@@ -391,15 +412,41 @@ const createAttentionProbsProgramInfo =
let m = workgroup_id.y * TILE_SIZE;
let n = workgroup_id.x * TILE_SIZE;
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
-let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;

${(() => {
if (pastKey && presentKey) {
return `
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
} else {
return `
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`;
}
})()}
${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''}
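// The tile loads below read from past_key for columns before past_sequence_length
// and from key afterwards, writing the merged row to present_key; the concatenated
// K tensor is never materialized.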
var value = ${f32Type}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
}
if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
-tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];
var idx = TILE_SIZE * local_id.y + local_id.x;
${(() => {
if (pastKey && presentKey) {
return `
if (n + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
} else {
tileK[idx] =
key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
}`;
} else {
return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];';
}
})()}
${
presentKey ?
'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' :
''}
}
workgroupBarrier();

@@ -432,23 +479,25 @@ const createAttentionProbsProgramInfo =
};
return {
name: 'AttentionProbs',
-shaderCache: {hint: `${components}`, inputDependencies},
-getRunData: () => ({
-  outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
-  dispatchGroup: dispatch,
-  programUniforms
-}),
+shaderCache: {
+  hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`,
+  inputDependencies
+},
+getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}),
getShaderSource,
};
};


const createVxAttentionScoreProgramInfo =
-(_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
- pastSequenceLength: number) => {
+(context: ComputeContext, probs: TensorView, v: TensorView, pastValue: TensorView|undefined,
+ params: AttentionParameters, pastSequenceLength: number) => {
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
const repeatedVHiddenSize = params.vHiddenSize * nReps;
const presentValue = params.kvNumHeads == null && context.outputCount > 1;
const presentValueShape =
presentValue ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] : undefined;
const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
const TILE_SIZE = 12;
const dispatch = {
@@ -460,23 +509,37 @@ const createVxAttentionScoreProgramInfo =
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
-{type: DataType.uint32, data: repeatedVHiddenSize}
+{type: DataType.uint32, data: repeatedVHiddenSize}, {type: DataType.uint32, data: pastSequenceLength},
+{type: DataType.uint32, data: params.kvSequenceLength}
];

-const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
+const inputDependencies: ProgramInputTensorInfoDependency[] =
+    pastValue ? ['type', 'type', 'type'] : ['type', 'type'];
const outputs = [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}];
if (presentValue) {
outputs.push({dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default});
}
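// Output slot 1 of this program is present_value, produced by the same dispatch
// that computes probs x V.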
const getShaderSource = (shaderHelper: ShaderHelper) => {
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
const vHelper = inputVariable('v', v.dataType, v.dims);
const inputVars = [probsHelper, vHelper];
if (pastValue) {
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
}
const output = outputVariable('output', probs.dataType, outputShape);
const outputVars = [output];
if (presentValue) {
outputVars.push(outputVariable('present_value', probs.dataType, presentValueShape!));
}
const uniforms: UniformsArrayType = [
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
-{name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'}
+{name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'},
+{name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'}
];
return `
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
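// Here tileQ stages a tile of probs and tileK a tile of V (or past_value); the
// main loop computes probs x V one tile at a time.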
-${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)}
+${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart([
TILE_SIZE, TILE_SIZE, 1
])}
@@ -485,16 +548,43 @@ const createVxAttentionScoreProgramInfo =
let n = global_id.x;

let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
-let offsetB = headIdx * (uniforms.N * uniforms.K) + n;

${(() => {
if (pastValue && presentValue) {
return `
let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
`;
} else {
return `
let offsetB = headIdx * uniforms.N * uniforms.K + n;
`;
}
})()}
${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''}
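// Mirrors the K path: rows of the virtual [past_value ; v] concatenation below
// past_sequence_length load from past_value, the rest from v, and the merged
// column is also written to present_value.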
var value = ${probsHelper.type.storage}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
-if (m < uniforms.M && w + local_id.x < uniforms.K) {
-  tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
-}
-if (n < uniforms.N && w + local_id.y < uniforms.K) {
-  tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N];
-}
if (m < uniforms.M && w + local_id.x < uniforms.K) {
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
}
if (n < uniforms.N && w + local_id.y < uniforms.K) {
var idx = TILE_SIZE * local_id.y + local_id.x;
${(() => {
if (pastValue && presentValue) {
return `
if (w + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
} else {
tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];
}
`;
} else {
return `
tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];
`;
}
})()}
${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''}
}
workgroupBarrier();
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
@@ -515,12 +605,8 @@ const createVxAttentionScoreProgramInfo =

return {
name: 'AttentionScore',
-shaderCache: {inputDependencies},
-getRunData: () => ({
-  outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}],
-  dispatchGroup: dispatch,
-  programUniforms
-}),
+shaderCache: {hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies},
+getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}),
getShaderSource,
};
};
Expand All @@ -529,38 +615,22 @@ export const applyAttention =
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
_past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
-const outputPresentKey = context.outputCount > 1;
-const outputPresentValue = context.outputCount > 2;
+const outputCount = context.outputCount;
 const pastSequenceLength =
-    parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
+    parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
-// Concatenate pastKey and K to produce presentKey.
-const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
-const concatKeyInputs = pastKey ? [pastKey, k] : [k];
-const key = parameters.kvNumHeads == null && outputPresentKey ?
-    context.compute(
-        createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
-        {inputs: concatKeyInputs, outputs: [1]})[0] :
-    k;
-
-// Concatenate pastValue and V to produce presentValue.
-const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
-const concatValueInputs = pastValue ? [pastValue, v] : [v];
-const value = parameters.kvNumHeads == null && outputPresentValue ?
-    context.compute(
-        createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
-        {inputs: concatValueInputs, outputs: [2]})[0] :
-    v;
-const inputsK = [q, key];

+const inputsK = (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey) ? [q, k, pastKey] : [q, k];
if (relativePositionBias) {
inputsK.push(relativePositionBias);
}

// Run AttentionProbs
const probs = context.compute(
createAttentionProbsProgramInfo(
-        context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength),
-    {inputs: inputsK, outputs: [-1]})[0];
+        context, q, k, outputCount > 1 ? pastKey : undefined, relativePositionBias, parameters, attributes,
+        pastSequenceLength),
+    {inputs: inputsK, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [-1, 1] : [-1]})[0];

// Run Softmax
context.compute(
@@ -570,10 +640,12 @@ export const applyAttention =
{inputs: [probs], outputs: []});

// Run AttentionScore
-const inputsV = [probs, value];
+const inputsV =
+    (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue) ? [probs, v, pastValue] : [probs, v];
context.compute(
-    createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength),
-    {inputs: inputsV, outputs: [0]});
+    createVxAttentionScoreProgramInfo(
+        context, probs, v, outputCount > 1 && pastValue ? pastValue : undefined, parameters, pastSequenceLength),
+    {inputs: inputsV, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [0, 2] : [0]});
};

const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
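A wiring note on the compute calls above: in JSEP, each entry of the outputs array passed to context.compute maps a program output to a kernel output slot, and -1 (by convention) requests a temporary tensor. Under that convention, the new mappings read roughly as:

// AttentionProbs:  outputs: [-1, 1]
//   program output 0 -> temporary probs tensor
//   program output 1 -> kernel output 1 (present_key)
// AttentionScore:  outputs: [0, 2]
//   program output 0 -> kernel output 0 (attention output)
//   program output 1 -> kernel output 2 (present_value)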
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -71,7 +71,7 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
return codeLines.join('\n');
};

-export const createConcatProgramInfo =
+const createConcatProgramInfo =
(inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);
