Skip to content

Commit

Permalink
Use static getters to get optimizer class names (#7168)
Browse files Browse the repository at this point in the history
Each `Optimizer` lists its class name as a static property of the class so it can be serialized and deserialized. This prevents the class from being tree-shaken because bundlers will compile it like this:

```
class SomeOptimizer {
  ...
}

// The bundler can not remove this assignment because
// SomeOptimizer.className could be a setter with a side effect.
SomeOptimizer.className = 'SomeOptimizer';
```

This PR uses a static getter for the class name instead, which bundlers can tree-shake properly.
  • Loading branch information
mattsoulanille committed Dec 12, 2022
1 parent 031582f commit f200cb8
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 7 deletions.
7 changes: 6 additions & 1 deletion tfjs-core/src/optimizers/adadelta_optimizer.ts
Expand Up @@ -31,7 +31,12 @@ import {Optimizer, OptimizerVariable} from './optimizer';
/** @doclink Optimizer */
export class AdadeltaOptimizer extends Optimizer {
/** @nocollapse */
static className = 'Adadelta'; // Name matters for Python compatibility.
static get className() {
// Name matters for Python compatibility.
// This is a getter instead of a property because when it's a property, it
// prevents the entire class from being tree-shaken.
return 'Adadelta';
}
private accumulatedGrads: OptimizerVariable[] = [];
private accumulatedUpdates: OptimizerVariable[] = [];

Expand Down
7 changes: 6 additions & 1 deletion tfjs-core/src/optimizers/adagrad_optimizer.ts
Expand Up @@ -31,7 +31,12 @@ import {Optimizer, OptimizerVariable} from './optimizer';
/** @doclink Optimizer */
export class AdagradOptimizer extends Optimizer {
/** @nocollapse */
static className = 'Adagrad'; // Note: Name matters for Python compatibility.
static get className() {
// Name matters for Python compatibility.
// This is a getter instead of a property because when it's a property, it
// prevents the entire class from being tree-shaken.
return 'Adagrad';
}

private accumulatedGrads: OptimizerVariable[] = [];

Expand Down
7 changes: 6 additions & 1 deletion tfjs-core/src/optimizers/adam_optimizer.ts
Expand Up @@ -34,7 +34,12 @@ import {Optimizer, OptimizerVariable} from './optimizer';

export class AdamOptimizer extends Optimizer {
/** @nocollapse */
static className = 'Adam'; // Note: Name matters for Python compatibility.
static get className() {
// Name matters for Python compatibility.
// This is a getter instead of a property because when it's a property, it
// prevents the entire class from being tree-shaken.
return 'Adam';
}
private accBeta1: Variable;
private accBeta2: Variable;

Expand Down
7 changes: 6 additions & 1 deletion tfjs-core/src/optimizers/adamax_optimizer.ts
Expand Up @@ -33,7 +33,12 @@ import {Optimizer, OptimizerVariable} from './optimizer';

export class AdamaxOptimizer extends Optimizer {
/** @nocollapse */
static className = 'Adamax'; // Note: Name matters for Python compatbility.
static get className() {
// Name matters for Python compatibility.
// This is a getter instead of a property because when it's a property, it
// prevents the entire class from being tree-shaken.
return 'Adamax';
}
private accBeta1: Variable;
private iteration: Variable;

Expand Down
7 changes: 6 additions & 1 deletion tfjs-core/src/optimizers/momentum_optimizer.ts
Expand Up @@ -32,7 +32,12 @@ import {SGDOptimizer} from './sgd_optimizer';
export class MomentumOptimizer extends SGDOptimizer {
/** @nocollapse */
// Name matters for Python compatibility.
static override className = 'Momentum';
static override get className() {
// Name matters for Python compatibility.
// This is a getter instead of a property because when it's a property, it
// prevents the entire class from being tree-shaken.
return 'Momentum';
}
private m: Scalar;
private accumulations: OptimizerVariable[] = [];

Expand Down
7 changes: 6 additions & 1 deletion tfjs-core/src/optimizers/rmsprop_optimizer.ts
Expand Up @@ -32,7 +32,12 @@ import {Optimizer, OptimizerVariable} from './optimizer';
/** @doclink Optimizer */
export class RMSPropOptimizer extends Optimizer {
/** @nocollapse */
static className = 'RMSProp'; // Note: Name matters for Python compatibility.
static get className() {
// Name matters for Python compatibility.
// This is a getter instead of a property because when it's a property, it
// prevents the entire class from being tree-shaken.
return 'RMSProp';
}
private centered: boolean;

private accumulatedMeanSquares: OptimizerVariable[] = [];
Expand Down
7 changes: 6 additions & 1 deletion tfjs-core/src/optimizers/sgd_optimizer.ts
Expand Up @@ -29,7 +29,12 @@ import {Optimizer} from './optimizer';
/** @doclink Optimizer */
export class SGDOptimizer extends Optimizer {
/** @nocollapse */
static className = 'SGD'; // Note: Name matters for Python compatibility.
static get className() {
// Name matters for Python compatibility.
// This is a getter instead of a property because when it's a property, it
// prevents the entire class from being tree-shaken.
return 'SGD';
}
protected c: Scalar;

constructor(protected learningRate: number) {
Expand Down

0 comments on commit f200cb8

Please sign in to comment.