Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add positive dilation and strides check #7063

Merged
merged 2 commits into from Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions tfjs-core/src/ops/conv1d.ts
Expand Up @@ -83,6 +83,12 @@ function conv1d_<T extends Tensor2D|Tensor3D>(
conv_util.eitherStridesOrDilationsAreOne(stride, dilation),
() => 'Error in conv1D: Either stride or dilation must be 1. ' +
`Got stride ${stride} and dilation '${dilation}'`);
util.assert(
conv_util.stridesOrDilationsArePositive(dilation),
() => 'Error in conv1D: Dilated rates should be larger than 0.');
util.assert(
conv_util.stridesOrDilationsArePositive(stride),
() => 'Error in conv1D: Stride should be larger than 0.');
util.assert(
dataFormat === 'NWC',
() => `Error in conv1d: got dataFormat of ${
Expand Down
46 changes: 41 additions & 5 deletions tfjs-core/src/ops/conv1d_test.ts
Expand Up @@ -149,8 +149,8 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
() =>
tf.conv1d(x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

Expand All @@ -169,8 +169,8 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
() =>
tf.conv1d(x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

Expand Down Expand Up @@ -203,7 +203,7 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
const outputDepth = 1;
const fSize = 1;
const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
tf.backend_util.ExplicitPadding;
tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
Expand Down Expand Up @@ -300,6 +300,42 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
.toThrowError();
});

it('throws when stride is less than or equal to 0', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad =
[[0, 0], [0, 0], [0, 0], [0, 0]] as tf.backend_util.ExplicitPadding;
const stride = 0;
const dataFormat = 'NWC';
const dilation = 1;

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(() => tf.conv1d(x, w, stride, pad, dataFormat, dilation))
.toThrowError();
});

it('throws when dilation is less than or equal to 0', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad =
[[0, 0], [0, 0], [0, 0], [0, 0]] as tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 0;

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(() => tf.conv1d(x, w, stride, pad, dataFormat, dilation))
.toThrowError();
});

it('throws when both stride and dilation are greater than 1', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
Expand Down
6 changes: 6 additions & 0 deletions tfjs-core/src/ops/conv2d.ts
Expand Up @@ -94,6 +94,12 @@ function conv2d_<T extends Tensor3D|Tensor4D>(
conv_util.eitherStridesOrDilationsAreOne(strides, dilations),
() => 'Error in conv2D: Either strides or dilations must be 1. ' +
`Got strides ${strides} and dilations '${dilations}'`);
util.assert(
conv_util.stridesOrDilationsArePositive(dilations),
() => 'Error in conv2D: Dilated rates should be larger than 0.');
util.assert(
conv_util.stridesOrDilationsArePositive(strides),
() => 'Error in conv2D: Strides should be larger than 0.');

const inputs: Conv2DInputs = {x: x4D, filter: $filter};
const attrs:
Expand Down
31 changes: 31 additions & 0 deletions tfjs-core/src/ops/conv2d_test.ts
Expand Up @@ -747,6 +747,37 @@ describeWithFlags('conv2d', ALL_ENVS, () => {
.toThrowError();
});

it('throws when stride is less than or equal to 0', async () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 0;
const stride: [number, number] = [1, 0];

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor4d([2], [fSize, fSize, inputDepth, outputDepth]);

expect(() => tf.conv2d(x, w, stride, pad)).toThrowError();
});

it('throws when dilation is less than or equal to 0', async () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 0;
const stride = 1;
const dataFormat = 'NHWC';
const dilation: [number, number] = [1, 0];

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor4d([2], [fSize, fSize, inputDepth, outputDepth]);

expect(() => tf.conv2d(x, w, stride, pad, dataFormat, dilation))
.toThrowError();
});

it('throws when both stride and dilation are greater than 1', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
Expand Down
8 changes: 7 additions & 1 deletion tfjs-core/src/ops/conv3d.ts
Expand Up @@ -23,7 +23,7 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {eitherStridesOrDilationsAreOne} from './conv_util';
import {eitherStridesOrDilationsAreOne, stridesOrDilationsArePositive} from './conv_util';
import {op} from './operation';
import {reshape} from './reshape';

Expand Down Expand Up @@ -93,6 +93,12 @@ function conv3d_<T extends Tensor4D|Tensor5D>(
dataFormat === 'NDHWC',
() => `Error in conv3d: got dataFormat of ${
dataFormat} but only NDHWC is currently supported.`);
util.assert(
stridesOrDilationsArePositive(dilations),
() => 'Error in conv3D: Dilated rates should be larger than 0.');
util.assert(
stridesOrDilationsArePositive(strides),
() => 'Error in conv3D: Strides should be larger than 0.');

const inputs: Conv3DInputs = {x: x5D, filter: $filter};

Expand Down
32 changes: 32 additions & 0 deletions tfjs-core/src/ops/conv3d_test.ts
Expand Up @@ -483,4 +483,36 @@ describeWithFlags('conv3d', ALL_ENVS, () => {

expect(() => tf.conv3d(x, w, stride, pad, dataFormat)).toThrowError();
});

it('throws when stride is less than or equal to 0', async () => {
const inputDepth = 1;
const outputDepth = 1;
const inputShape: [number, number, number, number] = [2, 2, 1, inputDepth];
const pad = 'valid';
const fSize = 1;
const stride = 0;
const dataFormat = 'NDHWC';

const x = tf.tensor4d([1, 2, 3, 4], inputShape);
const w = tf.tensor5d([2], [fSize, fSize, fSize, inputDepth, outputDepth]);

expect(() => tf.conv3d(x, w, stride, pad, dataFormat)).toThrowError();
});

it('throws when dilation is less than or equal to 0', async () => {
const inputDepth = 1;
const outputDepth = 1;
const inputShape: [number, number, number, number] = [2, 2, 1, inputDepth];
const pad = 'valid';
const fSize = 1;
const stride = 0;
const dataFormat = 'NDHWC';
const dilation: [number, number, number] = [1, 1, 0];

const x = tf.tensor4d([1, 2, 3, 4], inputShape);
const w = tf.tensor5d([2], [fSize, fSize, fSize, inputDepth, outputDepth]);

expect(() => tf.conv3d(x, w, stride, pad, dataFormat, dilation))
.toThrowError();
});
});
20 changes: 13 additions & 7 deletions tfjs-core/src/ops/conv_util.ts
Expand Up @@ -582,6 +582,11 @@ export function eitherStridesOrDilationsAreOne(
return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
}

export function stridesOrDilationsArePositive(values: number|
number[]): boolean {
return parseTupleParam(values).every(value => value > 0);
}

/**
* Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
* 'channelsLast'|'channelsFirst'
Expand Down Expand Up @@ -621,19 +626,20 @@ export function checkPadOnDimRoundingMode(
if (dimRoundingMode != null) {
if (typeof pad === 'string') {
throw Error(
`Error in ${opDesc}: pad must be an integer when using ` +
`Error in ${opDesc}: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
} else if (typeof pad === 'number') {
util.assert(
util.isInt(pad),
util.isInt(pad),
() => `Error in ${opDesc}: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
} else if (typeof pad === 'object') {
(pad as ExplicitPadding).forEach(p => {p.forEach(v =>{
util.assert(
util.isInt(v),
() => `Error in ${opDesc}: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${v}.`);
(pad as ExplicitPadding).forEach(p => {
p.forEach(v => {
util.assert(
util.isInt(v),
() => `Error in ${opDesc}: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${v}.`);
});
});
} else {
Expand Down