Skip to content

Commit

Permalink
Add depth check for dilation2d (tensorflow#7137)
Browse files Browse the repository at this point in the history
BUG
  • Loading branch information
Linchenn committed Jan 9, 2023
1 parent c86a27b commit 8a9c8e7
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tfjs-core/src/ops/dilation2d.ts
Expand Up @@ -31,7 +31,7 @@ import {reshape} from './reshape';
* Computes the grayscale dilation over the input `x`.
*
* @param x The input tensor, rank 3 or rank 4 of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
* `[batch, height, width, depth]`. If rank 3, batch of 1 is assumed.
* @param filter The filter tensor, rank 3, of shape
* `[filterHeight, filterWidth, depth]`.
* @param strides The strides of the sliding window for each dimension of the
Expand Down Expand Up @@ -87,6 +87,11 @@ function dilation2d_<T extends Tensor3D|Tensor4D>(
reshapedTo4D = true;
}

util.assert(
x4D.shape[3] === $filter.shape[2],
() => `Error in dilation2d: input and filter must have the same depth: ${
x4D.shape[3]} vs ${$filter.shape[2]}`);

const inputs: Dilation2DInputs = {x: x4D, filter: $filter};
const attrs: Dilation2DAttrs = {strides, pad, dilations};

Expand Down

0 comments on commit 8a9c8e7

Please sign in to comment.