Skip to content

Commit

Permalink
webgpu: support logicalOr and logicalXor operators (tensorflow#7046)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao authored and Linchenn committed Jan 9, 2023
1 parent 8c659d7 commit 5b74d0a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 3 deletions.
8 changes: 7 additions & 1 deletion tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export enum BinaryOpType {
LESS,
LESS_EQUAL,
LOGICAL_AND,
LOGICAL_OR,
MAX,
MIN,
MOD,
Expand Down Expand Up @@ -112,9 +113,12 @@ const LESS = 'return f32(a < b);';
const LESS_VEC4 = 'return vec4<f32>(a < b);';
const LESS_EQUAL = 'return f32(a <= b);';
const LESS_EQUAL_VEC4 = 'return vec4<f32>(a <= b);';
const LOGICAL_AND = 'return f32(f32(a) >= 1.0 && f32(b) >= 1.0);';
const LOGICAL_AND = 'return f32(a >= 1.0 && b >= 1.0);';
const LOGICAL_AND_VEC4 = `return (vec4<f32>(a >= vec4<f32>(1.0)) *
vec4<f32>(b >= vec4<f32>(1.0)));`;
const LOGICAL_OR = 'return f32(a >= 1.0 || b >= 1.0);';
const LOGICAL_OR_VEC4 = `return min(vec4<f32>(a >= vec4<f32>(1.0)) +
vec4<f32>(b >= vec4<f32>(1.0)), vec4<f32>(1.0));`;
const MOD = `
${CHECK_NAN_SNIPPET}
if (b == 0.) {
Expand Down Expand Up @@ -263,6 +267,8 @@ export function getBinaryOpString(
return useVec4 ? LESS_EQUAL_VEC4 : LESS_EQUAL;
case BinaryOpType.LOGICAL_AND:
return useVec4 ? LOGICAL_AND_VEC4 : LOGICAL_AND;
case BinaryOpType.LOGICAL_OR:
return useVec4 ? LOGICAL_OR_VEC4 : LOGICAL_OR;
case BinaryOpType.MAX:
return getBinaryWithNanString('max', useVec4);
case BinaryOpType.MIN:
Expand Down
29 changes: 29 additions & 0 deletions tfjs-backend-webgpu/src/kernels/LogicalOr.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, LogicalOr} from '@tensorflow/tfjs-core';

import {BinaryOpType} from '../binary_op_util';
import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';

export const logicalOr = binaryKernelFunc({opType: BinaryOpType.LOGICAL_OR});

export const logicalOrConfig: KernelConfig = {
kernelName: LogicalOr,
backendName: 'webgpu',
kernelFunc: logicalOr
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ import {logConfig} from './kernels/Log';
import {log1pConfig} from './kernels/Log1p';
import {logicalAndConfig} from './kernels/LogicalAnd';
import {logicalNotConfig} from './kernels/LogicalNot';
import {logicalOrConfig} from './kernels/LogicalOr';
import {maxConfig} from './kernels/Max';
import {maximumConfig} from './kernels/Maximum';
import {maxPoolConfig} from './kernels/MaxPool';
Expand Down Expand Up @@ -211,6 +212,7 @@ const kernelConfigs: KernelConfig[] = [
logConfig,
logicalAndConfig,
logicalNotConfig,
logicalOrConfig,
maxConfig,
maximumConfig,
maxPoolConfig,
Expand Down
2 changes: 0 additions & 2 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,6 @@ const TEST_FILTERS: TestFilter[] = [
'linspace ',
'localResponseNormalization ',
'logSigmoid ',
'logicalOr ',
'logicalXor ',
'maxPool3d ',
'maxPool3dBackprop ',
'maxPoolBackprop ',
Expand Down

0 comments on commit 5b74d0a

Please sign in to comment.