diff --git a/tfjs-core/src/ops/conv1d.ts b/tfjs-core/src/ops/conv1d.ts index 4a04106870..fb77ea02cf 100644 --- a/tfjs-core/src/ops/conv1d.ts +++ b/tfjs-core/src/ops/conv1d.ts @@ -83,6 +83,12 @@ function conv1d_( conv_util.eitherStridesOrDilationsAreOne(stride, dilation), () => 'Error in conv1D: Either stride or dilation must be 1. ' + `Got stride ${stride} and dilation '${dilation}'`); + util.assert( + conv_util.stridesOrDilationsArePositive(dilation), + () => 'Error in conv1D: Dilated rates should be larger than 0.'); + util.assert( + conv_util.stridesOrDilationsArePositive(stride), + () => 'Error in conv1D: Stride should be larger than 0.'); util.assert( dataFormat === 'NWC', () => `Error in conv1d: got dataFormat of ${ diff --git a/tfjs-core/src/ops/conv1d_test.ts b/tfjs-core/src/ops/conv1d_test.ts index d1eef74dbe..d0205d792f 100644 --- a/tfjs-core/src/ops/conv1d_test.ts +++ b/tfjs-core/src/ops/conv1d_test.ts @@ -149,8 +149,8 @@ describeWithFlags('conv1d', ALL_ENVS, () => { const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]); expect( - () => tf.conv1d( - x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + () => + tf.conv1d(x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) .toThrowError(); }); @@ -169,8 +169,8 @@ describeWithFlags('conv1d', ALL_ENVS, () => { const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]); expect( - () => tf.conv1d( - x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + () => + tf.conv1d(x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) .toThrowError(); }); @@ -203,7 +203,7 @@ describeWithFlags('conv1d', ALL_ENVS, () => { const outputDepth = 1; const fSize = 1; const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as - tf.backend_util.ExplicitPadding; + tf.backend_util.ExplicitPadding; const stride = 1; const dataFormat = 'NWC'; const dilation = 1; @@ -300,6 +300,42 @@ describeWithFlags('conv1d', ALL_ENVS, () => { .toThrowError(); }); + it('throws when stride is less than or equal to 0', () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 1; + const pad = + [[0, 0], [0, 0], [0, 0], [0, 0]] as tf.backend_util.ExplicitPadding; + const stride = 0; + const dataFormat = 'NWC'; + const dilation = 1; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]); + + expect(() => tf.conv1d(x, w, stride, pad, dataFormat, dilation)) + .toThrowError(); + }); + + it('throws when dilation is less than or equal to 0', () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 1; + const pad = + [[0, 0], [0, 0], [0, 0], [0, 0]] as tf.backend_util.ExplicitPadding; + const stride = 1; + const dataFormat = 'NWC'; + const dilation = 0; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]); + + expect(() => tf.conv1d(x, w, stride, pad, dataFormat, dilation)) + .toThrowError(); + }); + it('throws when both stride and dilation are greater than 1', () => { const inputDepth = 1; const inputShape: [number, number, number] = [2, 2, inputDepth]; diff --git a/tfjs-core/src/ops/conv2d.ts b/tfjs-core/src/ops/conv2d.ts index 97e0d293ea..564657482e 100644 --- a/tfjs-core/src/ops/conv2d.ts +++ b/tfjs-core/src/ops/conv2d.ts @@ -94,6 +94,12 @@ function conv2d_( conv_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' + `Got strides ${strides} and dilations '${dilations}'`); + util.assert( + conv_util.stridesOrDilationsArePositive(dilations), + () => 'Error in conv2D: Dilated rates should be larger than 0.'); + util.assert( + conv_util.stridesOrDilationsArePositive(strides), + () => 'Error in conv2D: Strides should be larger than 0.'); const inputs: Conv2DInputs = {x: x4D, filter: $filter}; const attrs: diff --git a/tfjs-core/src/ops/conv2d_test.ts b/tfjs-core/src/ops/conv2d_test.ts index 4bab60a5c7..86a18ae8e7 100644 --- a/tfjs-core/src/ops/conv2d_test.ts +++ b/tfjs-core/src/ops/conv2d_test.ts @@ -747,6 +747,37 @@ describeWithFlags('conv2d', ALL_ENVS, () => { .toThrowError(); }); + it('throws when stride is less than or equal to 0', async () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 1; + const pad = 0; + const stride: [number, number] = [1, 0]; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.tensor4d([2], [fSize, fSize, inputDepth, outputDepth]); + + expect(() => tf.conv2d(x, w, stride, pad)).toThrowError(); + }); + + it('throws when dilation is less than or equal to 0', async () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 1; + const pad = 0; + const stride = 1; + const dataFormat = 'NHWC'; + const dilation: [number, number] = [1, 0]; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.tensor4d([2], [fSize, fSize, inputDepth, outputDepth]); + + expect(() => tf.conv2d(x, w, stride, pad, dataFormat, dilation)) + .toThrowError(); + }); + it('throws when both stride and dilation are greater than 1', () => { const inputDepth = 1; const inputShape: [number, number, number] = [2, 2, inputDepth]; diff --git a/tfjs-core/src/ops/conv3d.ts b/tfjs-core/src/ops/conv3d.ts index 79effa71ad..9df84fd528 100644 --- a/tfjs-core/src/ops/conv3d.ts +++ b/tfjs-core/src/ops/conv3d.ts @@ -23,7 +23,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {eitherStridesOrDilationsAreOne} from './conv_util'; +import {eitherStridesOrDilationsAreOne, stridesOrDilationsArePositive} from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; @@ -93,6 +93,12 @@ function conv3d_( dataFormat === 'NDHWC', () => `Error in conv3d: got dataFormat of ${ dataFormat} but only NDHWC is currently supported.`); + util.assert( + stridesOrDilationsArePositive(dilations), + () => 'Error in conv3D: Dilated rates should be larger than 0.'); + util.assert( + stridesOrDilationsArePositive(strides), + () => 'Error in conv3D: Strides should be larger than 0.'); const inputs: Conv3DInputs = {x: x5D, filter: $filter}; diff --git a/tfjs-core/src/ops/conv3d_test.ts b/tfjs-core/src/ops/conv3d_test.ts index 28764ea6b9..2681796d16 100644 --- a/tfjs-core/src/ops/conv3d_test.ts +++ b/tfjs-core/src/ops/conv3d_test.ts @@ -483,4 +483,36 @@ describeWithFlags('conv3d', ALL_ENVS, () => { expect(() => tf.conv3d(x, w, stride, pad, dataFormat)).toThrowError(); }); + + it('throws when stride is less than or equal to 0', async () => { + const inputDepth = 1; + const outputDepth = 1; + const inputShape: [number, number, number, number] = [2, 2, 1, inputDepth]; + const pad = 'valid'; + const fSize = 1; + const stride = 0; + const dataFormat = 'NDHWC'; + + const x = tf.tensor4d([1, 2, 3, 4], inputShape); + const w = tf.tensor5d([2], [fSize, fSize, fSize, inputDepth, outputDepth]); + + expect(() => tf.conv3d(x, w, stride, pad, dataFormat)).toThrowError(); + }); + + it('throws when dilation is less than or equal to 0', async () => { + const inputDepth = 1; + const outputDepth = 1; + const inputShape: [number, number, number, number] = [2, 2, 1, inputDepth]; + const pad = 'valid'; + const fSize = 1; + const stride = 0; + const dataFormat = 'NDHWC'; + const dilation: [number, number, number] = [1, 1, 0]; + + const x = tf.tensor4d([1, 2, 3, 4], inputShape); + const w = tf.tensor5d([2], [fSize, fSize, fSize, inputDepth, outputDepth]); + + expect(() => tf.conv3d(x, w, stride, pad, dataFormat, dilation)) + .toThrowError(); + }); }); diff --git a/tfjs-core/src/ops/conv_util.ts b/tfjs-core/src/ops/conv_util.ts index efbaebfcbf..3be66a02ca 100644 --- a/tfjs-core/src/ops/conv_util.ts +++ b/tfjs-core/src/ops/conv_util.ts @@ -582,6 +582,11 @@ export function eitherStridesOrDilationsAreOne( return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations); } +export function stridesOrDilationsArePositive(values: number| + number[]): boolean { + return parseTupleParam(values).every(value => value > 0); +} + /** * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to * 'channelsLast'|'channelsFirst' @@ -621,19 +626,20 @@ export function checkPadOnDimRoundingMode( if (dimRoundingMode != null) { if (typeof pad === 'string') { throw Error( - `Error in ${opDesc}: pad must be an integer when using ` + + `Error in ${opDesc}: pad must be an integer when using ` + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); } else if (typeof pad === 'number') { util.assert( - util.isInt(pad), + util.isInt(pad), () => `Error in ${opDesc}: pad must be an integer when using ` + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); } else if (typeof pad === 'object') { - (pad as ExplicitPadding).forEach(p => {p.forEach(v =>{ - util.assert( - util.isInt(v), - () => `Error in ${opDesc}: pad must be an integer when using ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${v}.`); + (pad as ExplicitPadding).forEach(p => { + p.forEach(v => { + util.assert( + util.isInt(v), + () => `Error in ${opDesc}: pad must be an integer when using ` + + `dimRoundingMode ${dimRoundingMode} but got pad ${v}.`); }); }); } else {