Skip to content

Commit

Permalink
Fix #7104 - tf.initializers.<random | glorot | he | leCunn>Uniform() …
Browse files Browse the repository at this point in the history
…ignores seed argument & add tests that replicated the issue, fix wrong serialization name registered for LeCunUniform initializer class (#7108)

Fix #7104 - tf.initializers.<random | glorot | he | leCunn>Uniform() ignores seed argument & add tests that replicated the issue, fix wrong serialization name registered for LeCunUniform initializer class
  • Loading branch information
adrian-branescu committed Nov 29, 2022
1 parent bea721d commit aa14065
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tfjs-layers/src/initializers.ts
Expand Up @@ -128,7 +128,7 @@ export class RandomUniform extends Initializer {
}

apply(shape: Shape, dtype?: DataType): Tensor {
return randomUniform(shape, this.minval, this.maxval, dtype);
return randomUniform(shape, this.minval, this.maxval, dtype, this.seed);
}

override getConfig(): serialization.ConfigDict {
Expand Down Expand Up @@ -352,7 +352,7 @@ export class VarianceScaling extends Initializer {
return truncatedNormal(shape, 0, stddev, dtype, this.seed);
} else {
const limit = Math.sqrt(3 * scale);
return randomUniform(shape, -limit, limit, dtype);
return randomUniform(shape, -limit, limit, dtype, this.seed);
}
}

Expand Down Expand Up @@ -498,7 +498,7 @@ serialization.registerClass(LeCunNormal);

export class LeCunUniform extends VarianceScaling {
/** @nocollapse */
static override className = 'LeCunNormal';
static override className = 'LeCunUniform';

constructor(args?: SeedOnlyInitializerArgs) {
super({
Expand Down
188 changes: 188 additions & 0 deletions tfjs-layers/src/initializers_test.ts
Expand Up @@ -180,6 +180,25 @@ describeMathCPU('RandomUniform initializer', () => {
expect(weights.dtype).toEqual('float32');
expectTensorsValuesInRange(weights, 17, 47);
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'RandomUniform',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply(shape, 'float32');
const actual = actualInitializer.apply(shape, 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(() => getInitializer('RandomUniform').apply([3]), 1);
});
Expand Down Expand Up @@ -214,6 +233,25 @@ describeMathCPU('RandomNormal initializer', () => {
expect(weights.dtype).toEqual('float32');
// TODO(bileschi): Add test to assert the values match expectations.
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'RandomNormal',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply(shape, 'float32');
const actual = actualInitializer.apply(shape, 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(() => getInitializer('RandomNormal').apply([3]), 1);
});
Expand All @@ -239,6 +277,24 @@ describeMathCPU('HeNormal initializer', () => {
expectTensorsValuesInRange(weights, -2 * stddev, 2 * stddev);
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'HeNormal',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply(shape, 'float32');
const actual = actualInitializer.apply(shape, 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(() => getInitializer('HeNormal').apply([3]), 1);
});
Expand All @@ -264,6 +320,24 @@ describeMathCPU('HeUniform initializer', () => {
expectTensorsValuesInRange(weights, -bound, bound);
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'HeUniform',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply(shape, 'float32');
const actual = actualInitializer.apply(shape, 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(() => getInitializer('heUniform').apply([3]), 1);
});
Expand All @@ -289,6 +363,24 @@ describeMathCPU('LecunNormal initializer', () => {
expectTensorsValuesInRange(weights, -2 * stddev, 2 * stddev);
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'LeCunNormal',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply(shape, 'float32');
const actual = actualInitializer.apply(shape, 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(() => getInitializer('LeCunNormal').apply([3]), 1);
});
Expand All @@ -314,6 +406,24 @@ describeMathCPU('LeCunUniform initializer', () => {
expectTensorsValuesInRange(weights, -bound, bound);
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'LeCunUniform',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply(shape, 'float32');
const actual = actualInitializer.apply(shape, 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(() => getInitializer('LeCunUniform').apply([3]), 1);
});
Expand Down Expand Up @@ -348,6 +458,25 @@ describeMathCPU('TruncatedNormal initializer', () => {
expect(weights.dtype).toEqual('float32');
expectTensorsValuesInRange(weights, 0.0, 2.0);
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'TruncatedNormal',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply(shape, 'float32');
const actual = actualInitializer.apply(shape, 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(
() => getInitializer('TruncatedNormal').apply([3]), 1);
Expand Down Expand Up @@ -403,6 +532,25 @@ describeMathCPU('Glorot uniform initializer', () => {
.toBeGreaterThan(-limit);
});
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'GlorotUniform',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply([7, 2], 'float32');
const actual = actualInitializer.apply([7, 2], 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(() => getInitializer('GlorotUniform').apply([3]), 1);
});
Expand All @@ -429,6 +577,27 @@ describeMathCPU('VarianceScaling initializer', () => {
const newConfig = newInit.getConfig();
expect(newConfig['distribution']).toEqual(baseConfig['distribution']);
});

it(`${distribution} with configured seed`, () => {

const initializerConfig: serialization.ConfigDict = {
className: 'VarianceScaling',
config: {
distribution,
seed: 42
}
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply([7, 2], 'float32');
const actual = actualInitializer.apply([7, 2], 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});
});
});

Expand Down Expand Up @@ -485,6 +654,25 @@ describeMathCPU('Glorot normal initializer', () => {
expect(variance2).toBeLessThan(variance1);
});
});

it('with configured seed', () => {

const initializerConfig: serialization.ConfigDict = {
className: 'GlorotNormal',
config: { seed: 42 }
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply([7, 2], 'float32');
const actual = actualInitializer.apply([7, 2], 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});

it('Does not leak', () => {
expectNoLeakedTensors(() => getInitializer('GlorotNormal').apply([3]), 1);
});
Expand Down

0 comments on commit aa14065

Please sign in to comment.