Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resizing Layer #6879

Merged
merged 13 commits into from
Oct 11, 2022
Merged
32 changes: 30 additions & 2 deletions tfjs-layers/src/exports_layers.ts
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding';
import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling';
import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent';
import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers';
import { Rescaling, RescalingArgs } from './layers/preprocessing/image_preprocessing';
import { CategoryEncoding, CategoryEncodingArgs } from './layers/preprocessing/category_encoding';
import {Rescaling, RescalingArgs} from './layers/preprocessing/image_preprocessing';
import {Resizing, ResizingArgs} from './layers/preprocessing/image_resizing';
import {CategoryEncoding, CategoryEncodingArgs} from './layers/preprocessing/category_encoding';

// TODO(cais): Add doc string to all the public static functions in this
// class; include exectuable JavaScript code snippets where applicable
// (b/74074458).
Expand Down Expand Up @@ -1730,6 +1732,32 @@ export function rescaling(args?: RescalingArgs) {
return new Rescaling(args);
}

/**
* A preprocessing layer which resizes images.
* This layer resizes an image input to a target height and width. The input
* should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"`
* format. Input pixel values can be of any range (e.g. `[0., 1.)` or `[0,
* 255]`) and of interger or floating point dtype. By default, the layer will
* output floats.
*
* Arguments:
* - `height`: number, the height for the output tensor.
* - `width`: number, the width for the output tensor.
* - `interpolation`: string, the method for image resizing interpolation.
* - `cropToAspectRatio`: boolean, whether to keep image aspect ratio.
*
* Input shape:
* Arbitrary.
*
* Output shape:
* height, width, num channels.
*
* @doc {heading: 'Layers', subheading: 'Resizing', namespace: 'layers'}
*/
export function resizing(args?: ResizingArgs) {
return new Resizing(args);
}

/**
* A preprocessing layer which encodes integer features.
*
Expand Down
10 changes: 10 additions & 0 deletions tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

import { Tensor, randomNormal, mul, add} from '@tensorflow/tfjs-core';
import { Rescaling } from './image_preprocessing';
import { describeMathCPUAndGPU, expectTensorsClose } from '../../utils/test_utils';
Expand Down
102 changes: 102 additions & 0 deletions tfjs-layers/src/layers/preprocessing/image_resizing.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

import {image, Rank, serialization, Tensor, tidy} from '@tensorflow/tfjs-core'; // mul, add

import {Layer, LayerArgs} from '../../engine/topology';
import {ValueError} from '../../errors';
import {Shape} from '../../keras_format/common';
import {Kwargs} from '../../types';
import {getExactlyOneShape} from '../../utils/types_utils'; //, getExactlyOneTensor

// tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5',
// 'gaussian', 'mitchellcubic'
const INTERPOLATION_KEYS = ['bilinear', 'nearest'] as const;
const INTERPOLATION_METHODS = new Set(INTERPOLATION_KEYS);
type InterpolationType = typeof INTERPOLATION_KEYS[number];

export declare interface ResizingArgs extends LayerArgs {
height: number;
width: number;
interpolation?: InterpolationType; // default = 'bilinear';
cropToAspectRatio?: boolean; // default = false;
}

/**
* Preprocessing Resizing Layer
*
* This resizes images by a scaling and offset factor
*/

export class Resizing extends Layer {
/** @nocollapse */
static className = 'Resizing';
private readonly height: number;
private readonly width: number;
// method of interpolation to be used; default = "bilinear";
private readonly interpolation: InterpolationType;
// toggle whether the aspect ratio should be preserved; default = false;
private readonly cropToAspectRatio: boolean;

constructor(args: ResizingArgs) {
super(args);

this.height = args.height;
this.width = args.width;

if (args.interpolation) {
if (INTERPOLATION_METHODS.has(args.interpolation)) {
this.interpolation = args.interpolation;
} else {
throw new ValueError(`Invalid interpolation parameter: ${
args.interpolation} is not implemented`);
}
} else {
this.interpolation = 'bilinear';
}
this.cropToAspectRatio = Boolean(args.cropToAspectRatio);
}

computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
inputShape = getExactlyOneShape(inputShape);
const numChannels = inputShape[2];
return [this.height, this.width, numChannels];
}

getConfig(): serialization.ConfigDict {
const config: serialization.ConfigDict = {
'height': this.height,
'width': this.width,
'interpolation': this.interpolation,
'cropToAspectRatio': this.cropToAspectRatio
};

const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}

call(inputs: Tensor<Rank.R3>|Tensor<Rank.R4>, kwargs: Kwargs):
Tensor[]|Tensor {
return tidy(() => {
const size: [number, number] = [this.height, this.width];
if (this.interpolation === 'bilinear') {
return image.resizeBilinear(inputs, size, !this.cropToAspectRatio);
} else if (this.interpolation === 'nearest') {
return image.resizeNearestNeighbor(
inputs, size, !this.cropToAspectRatio);
} else {
throw new Error(`Interpolation is ${this.interpolation} but only ${[...INTERPOLATION_METHODS]} are supported`);
}
});
}
}

serialization.registerClass(Resizing);
125 changes: 125 additions & 0 deletions tfjs-layers/src/layers/preprocessing/image_resizing_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

/**
* Unit Tests for image resizing layer.
*/

import {image, Rank, Tensor, tensor, zeros, range, reshape} from '@tensorflow/tfjs-core';

// import {Shape} from '../../keras_format/common';
import {describeMathCPUAndGPU, expectTensorsClose} from '../../utils/test_utils';

import {Resizing, ResizingArgs} from './image_resizing';

describeMathCPUAndGPU('Resizing Layer', () => {
it('Check if output shape matches specifications', () => {
// resize and check output shape
const maxHeight = 40;
const height = Math.floor(Math.random() * maxHeight);
const maxWidth = 60;
const width = Math.floor(Math.random() * maxWidth);
const numChannels = 3;
const inputTensor = zeros([height * 2, width * 2, numChannels]);
const expectedOutputShape = [height, width, numChannels];
const resizingLayer = new Resizing({height, width});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
expect(layerOutputTensor.shape).toEqual(expectedOutputShape);
});

it('Returns correctly downscaled tensor', () => {
// resize and check output content (not batched)
const rangeTensor = range(0, 16);
const inputTensor = reshape(rangeTensor, [4,4,1]);
const height = 2;
const width = 2;
const interpolation = 'nearest';
const resizingLayer = new Resizing({height, width, interpolation});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
const expectedArr = [[0, 3], [12, 15]];
const expectedOutput = tensor(expectedArr, [2,2,1]);
expectTensorsClose(layerOutputTensor, expectedOutput);
});

it('Returns correctly downscaled tensor', () => {
// resize and check output content (batched)
const rangeTensor = range(0, 36);
const inputTensor = reshape(rangeTensor, [1,6,6,1]);
const height = 3;
const width = 3;
const interpolation = 'nearest';
const resizingLayer = new Resizing({height, width, interpolation});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
const expectedArr = [[0,3,5], [18,21,23], [30,33,35]];
const expectedOutput = tensor([expectedArr], [1,3,3,1]);
expectTensorsClose(layerOutputTensor, expectedOutput);
});

it('Returns correctly upscaled tensor', () => {
const rangeTensor = range(0, 4);
const inputTensor = reshape(rangeTensor, [1, 2, 2, 1]);
const height = 4;
const width = 4;
const interpolation = 'nearest';
const resizingLayer = new Resizing({height, width, interpolation});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
const expectedArr = [[0,0,1,1], [0,0,1,1], [2,2,3,3], [2,2,3,3]];
const expectedOutput = tensor([expectedArr], [1,4,4,1]);
expectTensorsClose(layerOutputTensor, expectedOutput);
});

it('Returns the same tensor when given same shape as input', () => {
// create a resizing layer with same shape as input
const height = 64;
const width = 32;
const numChannels = 1;
const rangeTensor = range(0, height * width);
const inputTensor = reshape(rangeTensor, [height, width, numChannels]);
const resizingLayer = new Resizing({height, width});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
expectTensorsClose(layerOutputTensor, inputTensor);
});

it('Returns a tensor of the correct dtype', () => {
// do a same resizing operation, cheeck tensors dtypes and content
const height = 40;
const width = 60;
const numChannels = 3;
const inputTensor: Tensor<Rank.R3> =
zeros([height, width, numChannels]);
const size: [number, number] = [height, width];
const expectedOutputTensor = image.resizeBilinear(inputTensor, size);
const resizingLayer = new Resizing({height, width});
const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
expect(layerOutputTensor.dtype).toBe(inputTensor.dtype);
expectTensorsClose(layerOutputTensor, expectedOutputTensor);
});

it('Throws an error given incorrect parameters', () => {
// pass incorrect interpolation method string to layer init
const height = 16;
const width = 16;
const interpolation = 'unimplemented';
const incorrectArgs = {height, width, interpolation};
const expectedError =
`Invalid interpolation parameter: ${interpolation} is not implemented`;
expect(() => new Resizing(incorrectArgs as ResizingArgs))
.toThrowError(expectedError);
});

it('Config holds correct name', () => {
// layer name property set properly
const height = 40;
const width = 60;
const resizingLayer = new Resizing({height, width, name:'Resizing'});
const config = resizingLayer.getConfig();
expect(config.name).toEqual('Resizing');
});
});