Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use node's util.types.isUint8Array etc for isTypedArray #7181

Merged
merged 4 commits into from Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions tfjs-core/src/platforms/platform.ts
Expand Up @@ -48,4 +48,7 @@ export interface Platform {
decode(bytes: Uint8Array, encoding: string): string;

setTimeoutCustom?(functionRef: Function, delay: number): void;

isTypedArray(a: unknown): a is Float32Array|Int32Array|Uint8Array|
Uint8ClampedArray;
}
6 changes: 6 additions & 0 deletions tfjs-core/src/platforms/platform_browser.ts
Expand Up @@ -90,6 +90,12 @@ export class PlatformBrowser implements Platform {
}, true);
}
}

isTypedArray(a: unknown): a is Uint8Array | Float32Array | Int32Array
| Uint8ClampedArray {
return a instanceof Float32Array || a instanceof Int32Array ||
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}
}

if (env().get('IS_BROWSER')) {
Expand Down
15 changes: 15 additions & 0 deletions tfjs-core/src/platforms/platform_browser_test.ts
Expand Up @@ -147,4 +147,19 @@ describeWithFlags('setTimeout', BROWSER_ENVS, () => {
env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0);
}
});

it('isTypedArray returns false if not a typed array', () => {
const platform = new PlatformBrowser();
expect(platform.isTypedArray([1, 2, 3])).toBeFalse();
});

for (const typedArrayConstructor of [Float32Array, Int32Array, Uint8Array,
Uint8ClampedArray]) {
it(`isTypedArray returns true if it is a ${typedArrayConstructor.name}`,
() => {
const platform = new PlatformBrowser();
const array = new typedArrayConstructor([1,2,3]);
expect(platform.isTypedArray(array)).toBeTrue();
});
}
});
7 changes: 7 additions & 0 deletions tfjs-core/src/platforms/platform_node.ts
Expand Up @@ -79,6 +79,13 @@ export class PlatformNode implements Platform {
}
return new this.util.TextDecoder(encoding).decode(bytes);
}
isTypedArray(a: unknown): a is Float32Array | Int32Array | Uint8Array
| Uint8ClampedArray {
return this.util.types.isFloat32Array(a)
|| this.util.types.isInt32Array(a)
|| this.util.types.isUint8Array(a)
|| this.util.types.isUint8ClampedArray(a);
}
}

if (env().get('IS_NODE') && !env().get('IS_BROWSER')) {
Expand Down
15 changes: 15 additions & 0 deletions tfjs-core/src/platforms/platform_node_test.ts
Expand Up @@ -125,4 +125,19 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => {
expect(s.length).toBe(6);
expect(s).toEqual('Здраво');
});

it('isTypedArray returns false if not a typed array', () => {
const platform = new PlatformNode();
expect(platform.isTypedArray([1, 2, 3])).toBeFalse();
});

for (const typedArrayConstructor of [Float32Array, Int32Array, Uint8Array,
Uint8ClampedArray]) {
it(`isTypedArray returns true if it is a ${typedArrayConstructor.name}`,
() => {
const platform = new PlatformNode();
const array = new typedArrayConstructor([1,2,3]);
expect(platform.isTypedArray(array)).toBeTrue();
});
}
});
58 changes: 56 additions & 2 deletions tfjs-core/src/util.ts
Expand Up @@ -16,7 +16,7 @@
*/

import {env} from './environment';
import {BackendValues, DataType, TensorLike, TypedArray} from './types';
import {BackendValues, DataType, RecursiveArray, TensorLike, TypedArray} from './types';
import * as base from './util_base';
export * from './util_base';
export * from './hash_util';
Expand Down Expand Up @@ -44,7 +44,7 @@ export function toTypedArray(a: TensorLike, dtype: DataType): TypedArray {
throw new Error('Cannot convert a string[] to a TypedArray');
}
if (Array.isArray(a)) {
a = base.flatten(a);
a = flatten(a);
}

if (env().getBool('DEBUG')) {
Expand Down Expand Up @@ -131,3 +131,57 @@ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string {
encoding = encoding || 'utf-8';
return env().platform.decode(bytes, encoding);
}

export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
Uint8ClampedArray {
return env().platform.isTypedArray(a);
}

// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
* Flattens an arbitrarily nested array.
*
* ```js
* const a = [[1, 2], [3, 4], [5, [6, [7]]]];
* const flat = tf.util.flatten(a);
* console.log(flat);
* ```
*
* @param arr The nested array to flatten.
* @param result The destination array which holds the elements.
* @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
* to false.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
export function
flatten<T extends number|boolean|string|Promise<number>|TypedArray>(
arr: T|RecursiveArray<T>, result: T[] = [], skipTypedArray = false): T[] {
if (result == null) {
result = [];
}
if (typeof arr === 'boolean' || typeof arr === 'number' ||
typeof arr === 'string' || base.isPromise(arr) || arr == null ||
isTypedArray(arr) && skipTypedArray) {
result.push(arr as T);
} else if (Array.isArray(arr) || isTypedArray(arr)) {
for (let i = 0; i < arr.length; ++i) {
flatten(arr[i], result, skipTypedArray);
}
} else {
let maxIndex = -1;
for (const key of Object.keys(arr)) {
// 0 or positive integer.
if (/^([1-9]+[0-9]*|0)$/.test(key)) {
maxIndex = Math.max(maxIndex, Number(key));
}
}
for (let i = 0; i <= maxIndex; i++) {
// tslint:disable-next-line: no-unnecessary-type-assertion
flatten((arr as RecursiveArray<T>)[i], result, skipTypedArray);
}
}
return result;
}
57 changes: 1 addition & 56 deletions tfjs-core/src/util_base.ts
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';
import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';

/**
* Shuffles the array in-place using Fisher-Yates algorithm.
Expand Down Expand Up @@ -167,55 +167,6 @@ export function assertNonNull(a: TensorLike): void {
() => `The input to the tensor constructor must be a non-null value.`);
}

// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
* Flattens an arbitrarily nested array.
*
* ```js
* const a = [[1, 2], [3, 4], [5, [6, [7]]]];
* const flat = tf.util.flatten(a);
* console.log(flat);
* ```
*
* @param arr The nested array to flatten.
* @param result The destination array which holds the elements.
* @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
* to false.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
export function
flatten<T extends number|boolean|string|Promise<number>|TypedArray>(
arr: T|RecursiveArray<T>, result: T[] = [], skipTypedArray = false): T[] {
if (result == null) {
result = [];
}
if (typeof arr === 'boolean' || typeof arr === 'number' ||
typeof arr === 'string' || isPromise(arr) || arr == null ||
isTypedArray(arr) && skipTypedArray) {
result.push(arr as T);
} else if (Array.isArray(arr) || isTypedArray(arr)) {
for (let i = 0; i < arr.length; ++i) {
flatten(arr[i], result, skipTypedArray);
}
} else {
let maxIndex = -1;
for (const key of Object.keys(arr)) {
// 0 or positive integer.
if (/^([1-9]+[0-9]*|0)$/.test(key)) {
maxIndex = Math.max(maxIndex, Number(key));
}
}
for (let i = 0; i <= maxIndex; i++) {
// tslint:disable-next-line: no-unnecessary-type-assertion
flatten((arr as RecursiveArray<T>)[i], result, skipTypedArray);
}
}
return result;
}

/**
* Returns the size (number of elements) of the tensor given its shape.
*
Expand Down Expand Up @@ -527,12 +478,6 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean {
return true;
}

export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
Uint8ClampedArray {
return a instanceof Float32Array || a instanceof Int32Array ||
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}

export function bytesPerElement(dtype: DataType): number {
if (dtype === 'float32' || dtype === 'int32') {
return 4;
Expand Down