From bdba7ed0945e730db809f47a5510b93754913638 Mon Sep 17 00:00:00 2001
From: Adrian Branescu <52697528+adrian-branescu@users.noreply.github.com>
Date: Tue, 29 Nov 2022 19:31:19 +0200
Subject: [PATCH] Fix #7104 - tf.initializers.Uniform() ignores seed argument
 & add tests that replicated the issue, fix wrong serialization name
 registered for LeCunUniform initializer class (#7108)

Fix #7104 - tf.initializers.Uniform() ignores seed argument & add tests that
replicated the issue, fix wrong serialization name registered for LeCunUniform
initializer class
---
 tfjs-layers/src/initializers.ts      |   6 +-
 tfjs-layers/src/initializers_test.ts | 188 +++++++++++++++++++++++++++
 2 files changed, 191 insertions(+), 3 deletions(-)

diff --git a/tfjs-layers/src/initializers.ts b/tfjs-layers/src/initializers.ts
index 976054e58e..c2e7e01e62 100644
--- a/tfjs-layers/src/initializers.ts
+++ b/tfjs-layers/src/initializers.ts
@@ -128,7 +128,7 @@ export class RandomUniform extends Initializer {
   }
 
   apply(shape: Shape, dtype?: DataType): Tensor {
-    return randomUniform(shape, this.minval, this.maxval, dtype);
+    return randomUniform(shape, this.minval, this.maxval, dtype, this.seed);
   }
 
   override getConfig(): serialization.ConfigDict {
@@ -352,7 +352,7 @@ export class VarianceScaling extends Initializer {
       return truncatedNormal(shape, 0, stddev, dtype, this.seed);
     } else {
       const limit = Math.sqrt(3 * scale);
-      return randomUniform(shape, -limit, limit, dtype);
+      return randomUniform(shape, -limit, limit, dtype, this.seed);
     }
   }
 
@@ -498,7 +498,7 @@ serialization.registerClass(LeCunNormal);
 
 export class LeCunUniform extends VarianceScaling {
   /** @nocollapse */
-  static override className = 'LeCunNormal';
+  static override className = 'LeCunUniform';
 
   constructor(args?: SeedOnlyInitializerArgs) {
     super({
diff --git a/tfjs-layers/src/initializers_test.ts b/tfjs-layers/src/initializers_test.ts
index ef82f8eb64..14401c8ac1 100644
--- a/tfjs-layers/src/initializers_test.ts
+++ b/tfjs-layers/src/initializers_test.ts
@@ -180,6 +180,25 @@ describeMathCPU('RandomUniform initializer', () => {
     expect(weights.dtype).toEqual('float32');
     expectTensorsValuesInRange(weights, 17, 47);
   });
+
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'RandomUniform',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply(shape, 'float32');
+    const actual = actualInitializer.apply(shape, 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(() => getInitializer('RandomUniform').apply([3]), 1);
   });
@@ -214,6 +233,25 @@ describeMathCPU('RandomNormal initializer', () => {
     expect(weights.dtype).toEqual('float32');
     // TODO(bileschi): Add test to assert the values match expectations.
   });
+
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'RandomNormal',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply(shape, 'float32');
+    const actual = actualInitializer.apply(shape, 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(() => getInitializer('RandomNormal').apply([3]), 1);
   });
@@ -239,6 +277,24 @@ describeMathCPU('HeNormal initializer', () => {
     expectTensorsValuesInRange(weights, -2 * stddev, 2 * stddev);
   });
 
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'HeNormal',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply(shape, 'float32');
+    const actual = actualInitializer.apply(shape, 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(() => getInitializer('HeNormal').apply([3]), 1);
   });
@@ -264,6 +320,24 @@ describeMathCPU('HeUniform initializer', () => {
     expectTensorsValuesInRange(weights, -bound, bound);
   });
 
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'HeUniform',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply(shape, 'float32');
+    const actual = actualInitializer.apply(shape, 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(() => getInitializer('heUniform').apply([3]), 1);
   });
@@ -289,6 +363,24 @@ describeMathCPU('LecunNormal initializer', () => {
     expectTensorsValuesInRange(weights, -2 * stddev, 2 * stddev);
   });
 
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'LeCunNormal',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply(shape, 'float32');
+    const actual = actualInitializer.apply(shape, 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(() => getInitializer('LeCunNormal').apply([3]), 1);
   });
@@ -314,6 +406,24 @@ describeMathCPU('LeCunUniform initializer', () => {
     expectTensorsValuesInRange(weights, -bound, bound);
   });
 
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'LeCunUniform',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply(shape, 'float32');
+    const actual = actualInitializer.apply(shape, 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(() => getInitializer('LeCunUniform').apply([3]), 1);
   });
@@ -348,6 +458,25 @@ describeMathCPU('TruncatedNormal initializer', () => {
     expect(weights.dtype).toEqual('float32');
     expectTensorsValuesInRange(weights, 0.0, 2.0);
   });
+
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'TruncatedNormal',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply(shape, 'float32');
+    const actual = actualInitializer.apply(shape, 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(
         () => getInitializer('TruncatedNormal').apply([3]), 1);
@@ -403,6 +532,25 @@ describeMathCPU('Glorot uniform initializer', () => {
           .toBeGreaterThan(-limit);
     });
   });
+
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'GlorotUniform',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply([7, 2], 'float32');
+    const actual = actualInitializer.apply([7, 2], 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(() => getInitializer('GlorotUniform').apply([3]), 1);
   });
@@ -429,6 +577,27 @@ describeMathCPU('VarianceScaling initializer', () => {
       const newConfig = newInit.getConfig();
       expect(newConfig['distribution']).toEqual(baseConfig['distribution']);
     });
+
+    it(`${distribution} with configured seed`, () => {
+
+      const initializerConfig: serialization.ConfigDict = {
+        className: 'VarianceScaling',
+        config: {
+          distribution,
+          seed: 42
+        }
+      };
+
+      const expectedInitializer = getInitializer(initializerConfig);
+      const actualInitializer = getInitializer(initializerConfig);
+
+      const expected = expectedInitializer.apply([7, 2], 'float32');
+      const actual = actualInitializer.apply([7, 2], 'float32');
+
+      expect(actual.shape).toEqual(expected.shape);
+      expect(actual.dtype).toEqual(expected.dtype);
+      expectTensorsClose(actual, expected);
+    });
   });
 });
 
@@ -485,6 +654,25 @@ describeMathCPU('Glorot normal initializer', () => {
       expect(variance2).toBeLessThan(variance1);
     });
   });
+
+  it('with configured seed', () => {
+
+    const initializerConfig: serialization.ConfigDict = {
+      className: 'GlorotNormal',
+      config: { seed: 42 }
+    };
+
+    const expectedInitializer = getInitializer(initializerConfig);
+    const actualInitializer = getInitializer(initializerConfig);
+
+    const expected = expectedInitializer.apply([7, 2], 'float32');
+    const actual = actualInitializer.apply([7, 2], 'float32');
+
+    expect(actual.shape).toEqual(expected.shape);
+    expect(actual.dtype).toEqual(expected.dtype);
+    expectTensorsClose(actual, expected);
+  });
+
   it('Does not leak', () => {
     expectNoLeakedTensors(() => getInitializer('GlorotNormal').apply([3]), 1);
   });
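Reviewer note (illustrative sketch, not part of the patch): the snippet below shows how the fixed seeding behavior can be checked from user code, outside the Jasmine suite above. It assumes the public tf.initializers.randomUniform() factory and the Initializer.apply() method that the new tests exercise; the variable names and the [2, 3] shape are made up for this example.

import * as tf from '@tensorflow/tfjs';

// Before this patch, RandomUniform.apply() dropped the configured seed, so two
// initializers built from the same seeded config produced different tensors.
const initA = tf.initializers.randomUniform({minval: -1, maxval: 1, seed: 42});
const initB = tf.initializers.randomUniform({minval: -1, maxval: 1, seed: 42});

// With the seed now forwarded to randomUniform(), both draws should match.
const a = initA.apply([2, 3], 'float32');
const b = initB.apply([2, 3], 'float32');

// Maximum element-wise difference; expected to print 0 once the seed is honored.
a.sub(b).abs().max().print();

This is the same determinism that the new 'with configured seed' tests assert via expectTensorsClose.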