From cf948c1eb1866513fb784779e111db0615a1428d Mon Sep 17 00:00:00 2001 From: Linchenn <40653845+Linchenn@users.noreply.github.com> Date: Mon, 21 Nov 2022 13:04:40 -0800 Subject: [PATCH] Fix batch matching for batch mat mul (#7062) BUG * fix * lint * fix --- tfjs-backend-cpu/src/kernels/BatchMatMul.ts | 10 ++++++---- tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc | 16 +++++++--------- tfjs-backend-webgl/src/mulmat_packed_gpu.ts | 4 ++-- tfjs-core/src/ops/mat_mul_test.ts | 9 +++++++++ 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/BatchMatMul.ts b/tfjs-backend-cpu/src/kernels/BatchMatMul.ts index b933729750..14c9bb4916 100644 --- a/tfjs-backend-cpu/src/kernels/BatchMatMul.ts +++ b/tfjs-backend-cpu/src/kernels/BatchMatMul.ts @@ -106,12 +106,14 @@ export function batchMatMul(args: { let sum = 0.0; for (let k = k0; k < kBlock; k++) { - const batchOffsetA = Math.min(bi, batchDimA - 1) * aBatch; - const batchOffsetB = Math.min(bi, batchDimB - 1) * bBatch; + const batchIndexA = bi % batchDimA; + const batchIndexB = bi % batchDimB; const aVal = - a3dValues[batchOffsetA + i * aOuterStep + k * aInnerStep]; + // tslint:disable-next-line: max-line-length + a3dValues[batchIndexA * aBatch + i * aOuterStep + k * aInnerStep]; const bVal = - b3dValues[k * bInnerStep + j * bOuterStep + batchOffsetB]; + // tslint:disable-next-line: max-line-length + b3dValues[k * bInnerStep + j * bOuterStep + batchIndexB * bBatch]; sum += aVal * bVal; } resVals[bi * size + (i * rightDim + j)] += sum; diff --git a/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc b/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc index 8ca7902bb7..941c21b6b7 100644 --- a/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc +++ b/tfjs-backend-wasm/src/cc/batch_mat_mul_impl.cc @@ -174,7 +174,7 @@ void slow_batch_matmul(const size_t a_id, const size_t* a_shape_ptr, const size_t shared_dim = transpose_a ? a_shape_ptr[1] : a_shape_ptr[2]; const size_t left_dim = transpose_a ? a_shape_ptr[2] : a_shape_ptr[1]; const size_t right_dim = transpose_b ? b_shape_ptr[1] : b_shape_ptr[2]; - const size_t batch_dim = a_shape_ptr[0]; + const size_t batch_dim = std::max(a_shape_ptr[0], b_shape_ptr[0]); std::vector a_shape(a_shape_ptr, a_shape_ptr + a_shape_len); std::vector b_shape(b_shape_ptr, b_shape_ptr + b_shape_len); @@ -235,14 +235,12 @@ void slow_batch_matmul(const size_t a_id, const size_t* a_shape_ptr, float sum = 0.0; for (size_t k = k0; k < k_block; ++k) { - const size_t batch_offset_a = - std::min(b, a_shape[0] - 1) * a_batch; - const size_t batch_offset_b = - std::min(b, b_shape[0] - 1) * b_batch; - sum += - a_buf[batch_offset_a + i * a_outer_step + - k * a_inner_step] * - b_buf[k * b_inner_step + j * b_outer_step + batch_offset_b]; + const size_t batch_index_a = b % a_shape[0]; + const size_t batch_index_b = b % b_shape[0]; + sum += a_buf[batch_index_a * a_batch + i * a_outer_step + + k * a_inner_step] * + b_buf[k * b_inner_step + j * b_outer_step + + batch_index_b * b_batch]; } size_t innermost_dim = i * right_dim + j; size_t out_buf_index = b * size + innermost_dim; diff --git a/tfjs-backend-webgl/src/mulmat_packed_gpu.ts b/tfjs-backend-webgl/src/mulmat_packed_gpu.ts index e6cdd4efcf..cafa97eac9 100644 --- a/tfjs-backend-webgl/src/mulmat_packed_gpu.ts +++ b/tfjs-backend-webgl/src/mulmat_packed_gpu.ts @@ -78,9 +78,9 @@ export class MatMulPackedProgram implements GPGPUProgram { let batchASnippet = 'rc.x'; let batchBSnippet = 'rc.x'; if (aShape[0] < bShape[0]) { - batchASnippet = `int(min(float(rc.x), ${aShape[0] - 1}.))`; + batchASnippet = `imod(rc.x, ${aShape[0]})`; } else if (bShape[0] < aShape[0]) { - batchBSnippet = `int(min(float(rc.x), ${bShape[0] - 1}.))`; + batchBSnippet = `imod(rc.x, ${bShape[0]})`; } this.userCode = ` diff --git a/tfjs-core/src/ops/mat_mul_test.ts b/tfjs-core/src/ops/mat_mul_test.ts index 43226f49ba..636ef14e9c 100644 --- a/tfjs-core/src/ops/mat_mul_test.ts +++ b/tfjs-core/src/ops/mat_mul_test.ts @@ -894,6 +894,15 @@ describeWithFlags('matmulBatch', ALL_ENVS, () => { ]); }); + it('A has more batch dimensions than B', async () => { + const a = tf.tensor4d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [2, 2, 2, 2]); + const b = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); + + const c = tf.matMul(a, b); + expectArraysClose(await c.data(), [5, 11, 39, 53, 29, 35, 95, 109]); + }); + it('batch dimensions do not match', () => { const a = tf.tensor3d( [