From 723898b93478c03328da5af84e014cea1b758351 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Thu, 1 Dec 2022 16:59:56 -0800 Subject: [PATCH] Update dilation2d.ts --- tfjs-core/src/ops/dilation2d.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/ops/dilation2d.ts b/tfjs-core/src/ops/dilation2d.ts index 7ea1c7a191..5f83ce4bf3 100644 --- a/tfjs-core/src/ops/dilation2d.ts +++ b/tfjs-core/src/ops/dilation2d.ts @@ -31,7 +31,7 @@ import {reshape} from './reshape'; * Computes the grayscale dilation over the input `x`. * * @param x The input tensor, rank 3 or rank 4 of shape - * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * `[batch, height, width, depth]`. If rank 3, batch of 1 is assumed. * @param filter The filter tensor, rank 3, of shape * `[filterHeight, filterWidth, depth]`. * @param strides The strides of the sliding window for each dimension of the @@ -87,6 +87,11 @@ function dilation2d_( reshapedTo4D = true; } + util.assert( + x4D.shape[3] === $filter.shape[2], + () => `Error in dilation2d: input and filter must have the same depth: ${ + x4D.shape[3]} vs ${$filter.shape[2]}`); + const inputs: Dilation2DInputs = {x: x4D, filter: $filter}; const attrs: Dilation2DAttrs = {strides, pad, dilations};