diff --git a/tfjs-backend-cpu/src/utils/pool_utils.ts b/tfjs-backend-cpu/src/utils/pool_utils.ts index f3d497b966..4730026e20 100644 --- a/tfjs-backend-cpu/src/utils/pool_utils.ts +++ b/tfjs-backend-cpu/src/utils/pool_utils.ts @@ -245,8 +245,9 @@ export function pool3d( } } const outputOffset = outputColOffset + channel; - outputVals[outputOffset] = - poolType === 'avg' ? avgValue / count : minMaxValue; + outputVals[outputOffset] = poolType === 'avg' ? + avgValue / Math.max(count, 1) : + minMaxValue; } } } diff --git a/tfjs-backend-webgl/src/pool_gpu.ts b/tfjs-backend-webgl/src/pool_gpu.ts index 4ae4ae45d8..f3ea8e4d76 100644 --- a/tfjs-backend-webgl/src/pool_gpu.ts +++ b/tfjs-backend-webgl/src/pool_gpu.ts @@ -121,7 +121,7 @@ export class Pool2DProgram implements GPGPUProgram { let returnValue = `${poolType}(${poolType}(${poolType}(` + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; if (poolType === 'avg') { - returnValue = `avgValue / count`; + returnValue = `avgValue / max(count, 1.0)`; } const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; @@ -342,7 +342,10 @@ export class Pool3DProgram implements GPGPUProgram { let returnValue = `${poolType}(${poolType}(${poolType}(` + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; if (poolType === 'avg') { - returnValue = `avgValue / count`; + // Use `max(count, 1.0)` instead of `count` in case count === 0.0. + // If count === 0.0, `avgValue` is always 0.0 and we change `count`'s + // value to avoid dividing zero. + returnValue = `avgValue / max(count, 1.0)`; } const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; @@ -448,8 +451,8 @@ export class Pool3DProgram implements GPGPUProgram { ${updateSnippet} } } - setOutput(${returnValue}); } + setOutput(${returnValue}); } `; } diff --git a/tfjs-backend-webgpu/src/pool2d_webgpu.ts b/tfjs-backend-webgpu/src/pool2d_webgpu.ts index b87cf319c3..ff31a01daa 100644 --- a/tfjs-backend-webgpu/src/pool2d_webgpu.ts +++ b/tfjs-backend-webgpu/src/pool2d_webgpu.ts @@ -53,7 +53,7 @@ export class Pool2DProgram implements WebGPUProgram { let returnValue = `resultValue`; if (this.poolType === 'avg') { - returnValue = `resultValue / count`; + returnValue = `resultValue / max(count, 1.0)`; } const userCode = ` diff --git a/tfjs-core/src/ops/avg_pool_3d.ts b/tfjs-core/src/ops/avg_pool_3d.ts index 7c35a09a64..3c79d75d6d 100644 --- a/tfjs-core/src/ops/avg_pool_3d.ts +++ b/tfjs-core/src/ops/avg_pool_3d.ts @@ -24,8 +24,8 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {checkPadOnDimRoundingMode} from './conv_util'; import {cast} from './cast'; +import {checkPadOnDimRoundingMode} from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; @@ -86,6 +86,11 @@ function avgPool3d_( dataFormat === 'NDHWC', () => `Error in avgPool3d: Only NDHWC is currently supported, ` + `but got dataFormat of ${dataFormat}`); + util.assert( + (typeof strides === 'number' && strides > 0) || + (Array.isArray(strides) && strides[0] > 0 && strides[1] > 0 && + strides[2] > 0), + () => `Error in avgPool3d: Stride must be > 0, but got '${strides}'`); checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode); const inputs: AvgPool3DInputs = {x: x5D}; const attrs: diff --git a/tfjs-core/src/ops/avg_pool_3d_test.ts b/tfjs-core/src/ops/avg_pool_3d_test.ts index 308eca2504..7c2c12cbe8 100644 --- a/tfjs-core/src/ops/avg_pool_3d_test.ts +++ b/tfjs-core/src/ops/avg_pool_3d_test.ts @@ -29,6 +29,15 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => { expectArraysClose(await result.data(), [4.5]); }); + it('x=[2,2,2,1] f=[1,2,2] s=1 p=valid', async () => { + const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); + + const result = tf.avgPool3d(x, [1, 2, 2], 1, 'valid'); + + expect(result.shape).toEqual([2, 1, 1, 1]); + expectArraysClose(await result.data(), [2.5, 6.5]); + }); + it('x=[1,1,1,1,1] f=[1,1,1] s=1 [0] => [0]', async () => { const x = tf.tensor5d([0], [1, 1, 1, 1, 1]); @@ -150,6 +159,41 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => { expectArraysClose(await result.data(), expected); }); + it('x=[1,1,1,1,1] f=[1,1,3] s=1 p=valid', async () => { + // Output tensor would have a dimension of zero, if a certain filter's + // dimension is larger than the input's. + const x = tf.tensor5d([1], [1, 1, 1, 1, 1]); + const expected: number[] = []; + const result = tf.avgPool3d(x, [1, 1, 3], 1, 'valid'); + + expect(result.shape).toEqual([1, 1, 1, 0, 1]); + expectArraysClose(await result.data(), expected); + }); + + it('x=[1,1,1,4,1] f=[1,1,1] s=[1,1,2] p=0', async () => { + // Works if the padding is a number. + const x = tf.ones([1, 1, 1, 4, 1]) as tf.Tensor5D; + const expected = [1, 1]; + const result = tf.avgPool3d(x, [1, 1, 1], [1, 1, 2], 0); + + expect(result.shape).toEqual([1, 1, 1, 2, 1]); + expectArraysClose(await result.data(), expected); + }); + + it('x=[1,1,1,1,1] f=[2,2,2] s=1 p=2', async () => { + // Works if the padding is larger than filter size. + const x = tf.ones([1, 1, 1, 1, 1]) as tf.Tensor5D; + const expected = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + ]; + const result = tf.avgPool3d(x, [2, 2, 2], 1, 2); + + expect(result.shape).toEqual([1, 4, 4, 4, 1]); + expectArraysClose(await result.data(), expected); + }); + it('throws when x is not rank 5', async () => { // tslint:disable-next-line:no-any const x: any = tf.tensor1d([1]); diff --git a/tfjs-core/src/ops/conv_util.ts b/tfjs-core/src/ops/conv_util.ts index 3be66a02ca..93ef7efccf 100644 --- a/tfjs-core/src/ops/conv_util.ts +++ b/tfjs-core/src/ops/conv_util.ts @@ -365,24 +365,23 @@ function computeOutputShape2D( } function computeOutputShape4D( - inShape: [number, number, number, number], fieldSize: number, - outChannels: number, stride: number, zeroPad?: number, + inShape: [number, number, number, number], + filterShape: [number, number, number], outChannels: number, + strides: [number, number, number], zeroPad?: number, roundingMode?: 'floor'|'round'|'ceil'): [number, number, number, number] { if (zeroPad == null) { - zeroPad = computeDefaultPad(inShape, fieldSize, stride); + zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]); } - const inputDepth = inShape[0]; - const inputRows = inShape[1]; - const inputCols = inShape[2]; - - const outputDepths = - round((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); - const outputRows = - round((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); - const outputCols = - round((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); - - return [outputDepths, outputRows, outputCols, outChannels]; + const outShape: [number, number, number, number] = [0, 0, 0, outChannels]; + for (let index = 0; index < 3; index++) { + if (inShape[index] + 2 * zeroPad >= filterShape[index]) { + outShape[index] = round( + (inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] + + 1, + roundingMode); + } + } + return outShape; } export function computeDefaultPad( @@ -496,6 +495,10 @@ function get3DPadAndOutInfo( let outHeight: number; let outWidth: number; + if (pad === 'valid') { + pad = 0; + } + if (typeof pad === 'number') { const padType = (pad === 0) ? 'VALID' : 'NUMBER'; padInfo = { @@ -508,8 +511,9 @@ function get3DPadAndOutInfo( type: padType }; const outShape = computeOutputShape4D( - [inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, - roundingMode); + [inDepth, inHeight, inWidth, 1], + [filterDepth, filterHeight, filterWidth], 1, + [strideDepth, strideHeight, strideWidth], pad, roundingMode); outDepth = outShape[0]; outHeight = outShape[1]; outWidth = outShape[2]; @@ -529,19 +533,6 @@ function get3DPadAndOutInfo( const right = padAlongWidth - left; padInfo = {top, bottom, left, right, front, back, type: 'SAME'}; - } else if (pad === 'valid') { - padInfo = { - top: 0, - bottom: 0, - left: 0, - right: 0, - front: 0, - back: 0, - type: 'VALID' - }; - outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth); - outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); - outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); } else { throw Error(`Unknown padding parameter: ${pad}`); } diff --git a/tfjs-node/src/run_tests.ts b/tfjs-node/src/run_tests.ts index b4ca9560d6..42c32b358d 100644 --- a/tfjs-node/src/run_tests.ts +++ b/tfjs-node/src/run_tests.ts @@ -84,6 +84,10 @@ const IGNORE_LIST: string[] = [ 'avgPool test-tensorflow {} gradient x=[3,3,1] f=[3,3] s=1 p=explicit', // tslint:disable-next-line:max-line-length 'avgPool3d test-tensorflow {} x=[1,2,2,2,1] f=[2,2,2] s=1 p=1 roundingMode=floor', + // https://github.com/tensorflow/tensorflow/issues/58758 + 'avgPool3d test-tensorflow {} x=[1,1,1,1,1] f=[1,1,3] s=1 p=valid', + // Node backend which uses TF 2.11.0 doesn't support number padding + 'avgPool3d test-tensorflow {} x=[1,1,1,1,1] f=[2,2,2] s=1 p=2', // Node backend which uses TF 2.4.0 doesn't support explicit padding 'maxPool test-tensorflow {} x=[3,3,1] f=[3,3] s=1 p=explicit', 'maxPoolBackprop test-tensorflow {} gradient x=[3,3,1] f=3 s=1 p=explicit',