Skip to content

Commit

Permalink
Resizing Layer (#6879)
Browse files Browse the repository at this point in the history
Implement the resizing layer.

The branch and commit histories have been cleaned up for the PR.

IMPORT NOTE:
The lower-level op implementation of image resizing-nearest neighbor in TensorFlow.js differs from the implementation of the comparable op in Keras-Python.

While the Python version of the op function always selects the bottom right cell of the sub-matrix to be used as the representative value of that region in the downscaled matrix, the JavaScript implementation defaults to the top left cell of the sub-matrix, and then preferentially shifts to the right side of the sub-matrix in all sub-matrices past the lateral halfway point ( calculated by floor((length-1)/2) ), and the bottom side of the sub-matrix in all sub-matrices past the vertical halfway point, when considering the top-left side of the parent matrix as the origin.

This causes a slight variation in the output values from nearest neighbor downscaling between the Python and JavaScript versions of the code as it currently stands, and the unit tests for the resizing layer has been implemented to reflect this difference in op-function behavior.

Co-authored-by: Adam Lang (@AdamLang96) adamglang96@gmail.com
Co-authored-by: Brian Zheng (@Brianzheng123) brianzheng345@gmail.com
  • Loading branch information
koyykdy committed Oct 11, 2022
1 parent 18be40c commit 9752734
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 2 deletions.
32 changes: 30 additions & 2 deletions tfjs-layers/src/exports_layers.ts
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding';
import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling';
import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent';
import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers';
import { Rescaling, RescalingArgs } from './layers/preprocessing/image_preprocessing';
import { CategoryEncoding, CategoryEncodingArgs } from './layers/preprocessing/category_encoding';
import {Rescaling, RescalingArgs} from './layers/preprocessing/image_preprocessing';
import {Resizing, ResizingArgs} from './layers/preprocessing/image_resizing';
import {CategoryEncoding, CategoryEncodingArgs} from './layers/preprocessing/category_encoding';

// TODO(cais): Add doc string to all the public static functions in this
// class; include exectuable JavaScript code snippets where applicable
// (b/74074458).
Expand Down Expand Up @@ -1730,6 +1732,32 @@ export function rescaling(args?: RescalingArgs) {
return new Rescaling(args);
}

/**
* A preprocessing layer which resizes images.
* This layer resizes an image input to a target height and width. The input
* should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"`
* format. Input pixel values can be of any range (e.g. `[0., 1.)` or `[0,
* 255]`) and of interger or floating point dtype. By default, the layer will
* output floats.
*
* Arguments:
* - `height`: number, the height for the output tensor.
* - `width`: number, the width for the output tensor.
* - `interpolation`: string, the method for image resizing interpolation.
* - `cropToAspectRatio`: boolean, whether to keep image aspect ratio.
*
* Input shape:
* Arbitrary.
*
* Output shape:
* height, width, num channels.
*
* @doc {heading: 'Layers', subheading: 'Resizing', namespace: 'layers'}
*/
export function resizing(args?: ResizingArgs) {
return new Resizing(args);
}

/**
* A preprocessing layer which encodes integer features.
*
Expand Down
10 changes: 10 additions & 0 deletions tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

import { Tensor, randomNormal, mul, add} from '@tensorflow/tfjs-core';
import { Rescaling } from './image_preprocessing';
import { describeMathCPUAndGPU, expectTensorsClose } from '../../utils/test_utils';
Expand Down
102 changes: 102 additions & 0 deletions tfjs-layers/src/layers/preprocessing/image_resizing.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

import {image, Rank, serialization, Tensor, tidy} from '@tensorflow/tfjs-core'; // mul, add

import {Layer, LayerArgs} from '../../engine/topology';
import {ValueError} from '../../errors';
import {Shape} from '../../keras_format/common';
import {Kwargs} from '../../types';
import {getExactlyOneShape} from '../../utils/types_utils'; //, getExactlyOneTensor

// tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5',
// 'gaussian', 'mitchellcubic'
const INTERPOLATION_KEYS = ['bilinear', 'nearest'] as const;
const INTERPOLATION_METHODS = new Set(INTERPOLATION_KEYS);
type InterpolationType = typeof INTERPOLATION_KEYS[number];

export declare interface ResizingArgs extends LayerArgs {
height: number;
width: number;
interpolation?: InterpolationType; // default = 'bilinear';
cropToAspectRatio?: boolean; // default = false;
}

/**
* Preprocessing Resizing Layer
*
* This resizes images by a scaling and offset factor
*/

export class Resizing extends Layer {
/** @nocollapse */
static className = 'Resizing';
private readonly height: number;
private readonly width: number;
// method of interpolation to be used; default = "bilinear";
private readonly interpolation: InterpolationType;
// toggle whether the aspect ratio should be preserved; default = false;
private readonly cropToAspectRatio: boolean;

constructor(args: ResizingArgs) {
super(args);

this.height = args.height;
this.width = args.width;

if (args.interpolation) {
if (INTERPOLATION_METHODS.has(args.interpolation)) {
this.interpolation = args.interpolation;
} else {
throw new ValueError(`Invalid interpolation parameter: ${
args.interpolation} is not implemented`);
}
} else {
this.interpolation = 'bilinear';
}
this.cropToAspectRatio = Boolean(args.cropToAspectRatio);
}

computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
inputShape = getExactlyOneShape(inputShape);
const numChannels = inputShape[2];
return [this.height, this.width, numChannels];
}

getConfig(): serialization.ConfigDict {
const config: serialization.ConfigDict = {
'height': this.height,
'width': this.width,
'interpolation': this.interpolation,
'cropToAspectRatio': this.cropToAspectRatio
};

const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}

call(inputs: Tensor<Rank.R3>|Tensor<Rank.R4>, kwargs: Kwargs):
Tensor[]|Tensor {
return tidy(() => {
const size: [number, number] = [this.height, this.width];
if (this.interpolation === 'bilinear') {
return image.resizeBilinear(inputs, size, !this.cropToAspectRatio);
} else if (this.interpolation === 'nearest') {
return image.resizeNearestNeighbor(
inputs, size, !this.cropToAspectRatio);
} else {
throw new Error(`Interpolation is ${this.interpolation} but only ${[...INTERPOLATION_METHODS]} are supported`);
}
});
}
}

serialization.registerClass(Resizing);
125 changes: 125 additions & 0 deletions tfjs-layers/src/layers/preprocessing/image_resizing_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

/**
* Unit Tests for image resizing layer.
*/

import {image, Rank, Tensor, tensor, zeros, range, reshape} from '@tensorflow/tfjs-core';

// import {Shape} from '../../keras_format/common';
import {describeMathCPUAndGPU, expectTensorsClose} from '../../utils/test_utils';

import {Resizing, ResizingArgs} from './image_resizing';

describeMathCPUAndGPU('Resizing Layer', () => {
it('Check if output shape matches specifications', () => {
// resize and check output shape
const maxHeight = 40;
const height = Math.floor(Math.random() * maxHeight);
const maxWidth = 60;
const width = Math.floor(Math.random() * maxWidth);
const numChannels = 3;
const inputTensor = zeros([height * 2, width * 2, numChannels]);
const expectedOutputShape = [height, width, numChannels];
const resizingLayer = new Resizing({height, width});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
expect(layerOutputTensor.shape).toEqual(expectedOutputShape);
});

it('Returns correctly downscaled tensor', () => {
// resize and check output content (not batched)
const rangeTensor = range(0, 16);
const inputTensor = reshape(rangeTensor, [4,4,1]);
const height = 2;
const width = 2;
const interpolation = 'nearest';
const resizingLayer = new Resizing({height, width, interpolation});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
const expectedArr = [[0, 3], [12, 15]];
const expectedOutput = tensor(expectedArr, [2,2,1]);
expectTensorsClose(layerOutputTensor, expectedOutput);
});

it('Returns correctly downscaled tensor', () => {
// resize and check output content (batched)
const rangeTensor = range(0, 36);
const inputTensor = reshape(rangeTensor, [1,6,6,1]);
const height = 3;
const width = 3;
const interpolation = 'nearest';
const resizingLayer = new Resizing({height, width, interpolation});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
const expectedArr = [[0,3,5], [18,21,23], [30,33,35]];
const expectedOutput = tensor([expectedArr], [1,3,3,1]);
expectTensorsClose(layerOutputTensor, expectedOutput);
});

it('Returns correctly upscaled tensor', () => {
const rangeTensor = range(0, 4);
const inputTensor = reshape(rangeTensor, [1, 2, 2, 1]);
const height = 4;
const width = 4;
const interpolation = 'nearest';
const resizingLayer = new Resizing({height, width, interpolation});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
const expectedArr = [[0,0,1,1], [0,0,1,1], [2,2,3,3], [2,2,3,3]];
const expectedOutput = tensor([expectedArr], [1,4,4,1]);
expectTensorsClose(layerOutputTensor, expectedOutput);
});

it('Returns the same tensor when given same shape as input', () => {
// create a resizing layer with same shape as input
const height = 64;
const width = 32;
const numChannels = 1;
const rangeTensor = range(0, height * width);
const inputTensor = reshape(rangeTensor, [height, width, numChannels]);
const resizingLayer = new Resizing({height, width});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
expectTensorsClose(layerOutputTensor, inputTensor);
});

it('Returns a tensor of the correct dtype', () => {
// do a same resizing operation, cheeck tensors dtypes and content
const height = 40;
const width = 60;
const numChannels = 3;
const inputTensor: Tensor<Rank.R3> =
zeros([height, width, numChannels]);
const size: [number, number] = [height, width];
const expectedOutputTensor = image.resizeBilinear(inputTensor, size);
const resizingLayer = new Resizing({height, width});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
expect(layerOutputTensor.dtype).toBe(inputTensor.dtype);
expectTensorsClose(layerOutputTensor, expectedOutputTensor);
});

it('Throws an error given incorrect parameters', () => {
// pass incorrect interpolation method string to layer init
const height = 16;
const width = 16;
const interpolation = 'unimplemented';
const incorrectArgs = {height, width, interpolation};
const expectedError =
`Invalid interpolation parameter: ${interpolation} is not implemented`;
expect(() => new Resizing(incorrectArgs as ResizingArgs))
.toThrowError(expectedError);
});

it('Config holds correct name', () => {
// layer name property set properly
const height = 40;
const width = 60;
const resizingLayer = new Resizing({height, width, name:'Resizing'});
const config = resizingLayer.getConfig();
expect(config.name).toEqual('Resizing');
});
});

0 comments on commit 9752734

Please sign in to comment.