Skip to content

Commit

Permalink
webgpu: support selu operator (#7118)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao committed Nov 29, 2022
1 parent 40160ef commit bea721d
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 2 deletions.
30 changes: 30 additions & 0 deletions tfjs-backend-webgpu/src/kernels/Selu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/**
* @license
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, Selu} from '@tensorflow/tfjs-core';

import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';

import {UnaryOpType} from '../unary_op_util';

export const selu = unaryKernelFunc({opType: UnaryOpType.SELU});

export const seluConfig: KernelConfig = {
kernelName: Selu,
backendName: 'webgpu',
kernelFunc: selu
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ import {rsqrtConfig} from './kernels/Rsqrt';
import {scatterNdConfig} from './kernels/ScatterNd';
import {searchSortedConfig} from './kernels/SearchSorted';
import {selectConfig} from './kernels/Select';
import {seluConfig} from './kernels/Selu';
import {sigmoidConfig} from './kernels/Sigmoid';
import {signConfig} from './kernels/Sign';
import {sinConfig} from './kernels/Sin';
Expand Down Expand Up @@ -252,6 +253,7 @@ const kernelConfigs: KernelConfig[] = [
scatterNdConfig,
searchSortedConfig,
selectConfig,
seluConfig,
sigmoidConfig,
signConfig,
sinConfig,
Expand Down
2 changes: 0 additions & 2 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ const TEST_FILTERS: TestFilter[] = [
{
startsWith: 'elu ',
excludes: [
'selu', // Not yet implemented.
'derivative', // gradient function not found.
'gradient' // gradient function not found.
]
Expand Down Expand Up @@ -265,7 +264,6 @@ const TEST_FILTERS: TestFilter[] = [
'raggedRange ',
'raggedTensorToTensor ',
'method otsu', // round
'selu ',
'sparseFillEmptyRows ',
'sparseReshape ',
'sparseSegmentMean ',
Expand Down
12 changes: 12 additions & 0 deletions tfjs-backend-webgpu/src/unary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export enum UnaryOpType {
RECIPROCAL,
ROUND,
RSQRT,
SELU,
SIGMOID,
SIGN,
SIN,
Expand Down Expand Up @@ -166,6 +167,15 @@ const RELU_VEC4 = `
`;
const ROUND = `return round(a);`;
const RSQRT = `return inverseSqrt(a);`;
// Stable and Attracting Fixed Point (0, 1) for Normalized Weights.
// See: https://arxiv.org/abs/1706.02515
const SELU = `
if (a >= 0.0) {
return ${backend_util.SELU_SCALE} * a;
} else {
return ${backend_util.SELU_SCALEALPHA} * (exp(a) - 1.0);
}
`;
const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;
const SIGN = `return sign(a);`;
const SIN = `return sin(a);`;
Expand Down Expand Up @@ -258,6 +268,8 @@ export function getUnaryOpString(type: UnaryOpType, useVec4?: boolean): string {
return ROUND;
case UnaryOpType.RSQRT:
return RSQRT;
case UnaryOpType.SELU:
return SELU;
case UnaryOpType.SIGMOID:
return SIGMOID;
case UnaryOpType.SIGN:
Expand Down

0 comments on commit bea721d

Please sign in to comment.