diff --git a/package.json b/package.json index aabd535640..80c5af87a8 100644 --- a/package.json +++ b/package.json @@ -19,7 +19,7 @@ "@types/js-yaml": "^4.0.5", "@types/long": "4.0.1", "@types/mkdirp": "^0.5.2", - "@types/node": "^12.7.5", + "@types/node": "^18.11.15", "@types/node-fetch": "~2.1.2", "@types/offscreencanvas": "^2019.7.0", "@types/rollup-plugin-visualizer": "^4.2.1", diff --git a/tfjs-core/src/BUILD.bazel b/tfjs-core/src/BUILD.bazel index 49001f47fb..e10b620163 100644 --- a/tfjs-core/src/BUILD.bazel +++ b/tfjs-core/src/BUILD.bazel @@ -31,6 +31,7 @@ TEST_ENTRYPOINTS = [ "setup_test.ts", "worker_test.ts", "worker_node_test.ts", + "platforms/platform_node_test.ts", "ops/from_pixels_worker_test.ts", ] @@ -185,6 +186,26 @@ jasmine_node_test( ], ) +ts_library( + name = "platform_node_test_lib", + srcs = [ + "platforms/platform_node_test.ts", + ], + deps = [ + ":tfjs-core_lib", + ":tfjs-core_src_lib", + "//tfjs-backend-cpu/src:tfjs-backend-cpu_lib", + "@npm//@types/node", + ], +) + +jasmine_node_test( + name = "platform_node_test", + deps = [ + ":platform_node_test_lib", + ], +) + ts_library( name = "worker_test_lib", srcs = [ diff --git a/tfjs-core/src/platforms/platform.ts b/tfjs-core/src/platforms/platform.ts index 0b0ebbf4c5..60934f39d5 100644 --- a/tfjs-core/src/platforms/platform.ts +++ b/tfjs-core/src/platforms/platform.ts @@ -48,4 +48,7 @@ export interface Platform { decode(bytes: Uint8Array, encoding: string): string; setTimeoutCustom?(functionRef: Function, delay: number): void; + + isTypedArray(a: unknown): a is Float32Array|Int32Array|Uint8Array| + Uint8ClampedArray; } diff --git a/tfjs-core/src/platforms/platform_browser.ts b/tfjs-core/src/platforms/platform_browser.ts index e0cce4492a..1a1c0d9943 100644 --- a/tfjs-core/src/platforms/platform_browser.ts +++ b/tfjs-core/src/platforms/platform_browser.ts @@ -90,6 +90,12 @@ export class PlatformBrowser implements Platform { }, true); } } + + isTypedArray(a: unknown): a is Uint8Array | Float32Array | Int32Array + | Uint8ClampedArray { + return a instanceof Float32Array || a instanceof Int32Array || + a instanceof Uint8Array || a instanceof Uint8ClampedArray; + } } if (env().get('IS_BROWSER')) { diff --git a/tfjs-core/src/platforms/platform_browser_test.ts b/tfjs-core/src/platforms/platform_browser_test.ts index 46ec45569e..6b7661a3de 100644 --- a/tfjs-core/src/platforms/platform_browser_test.ts +++ b/tfjs-core/src/platforms/platform_browser_test.ts @@ -147,4 +147,19 @@ describeWithFlags('setTimeout', BROWSER_ENVS, () => { env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0); } }); + + it('isTypedArray returns false if not a typed array', () => { + const platform = new PlatformBrowser(); + expect(platform.isTypedArray([1, 2, 3])).toBeFalse(); + }); + + for (const typedArrayConstructor of [Float32Array, Int32Array, Uint8Array, + Uint8ClampedArray]) { + it(`isTypedArray returns true if it is a ${typedArrayConstructor.name}`, + () => { + const platform = new PlatformBrowser(); + const array = new typedArrayConstructor([1,2,3]); + expect(platform.isTypedArray(array)).toBeTrue(); + }); + } }); diff --git a/tfjs-core/src/platforms/platform_node.ts b/tfjs-core/src/platforms/platform_node.ts index 05994d383b..361eca12f6 100644 --- a/tfjs-core/src/platforms/platform_node.ts +++ b/tfjs-core/src/platforms/platform_node.ts @@ -79,6 +79,13 @@ export class PlatformNode implements Platform { } return new this.util.TextDecoder(encoding).decode(bytes); } + isTypedArray(a: unknown): a is Float32Array | Int32Array | Uint8Array + | Uint8ClampedArray { + return this.util.types.isFloat32Array(a) + || this.util.types.isInt32Array(a) + || this.util.types.isUint8Array(a) + || this.util.types.isUint8ClampedArray(a); + } } if (env().get('IS_NODE') && !env().get('IS_BROWSER')) { diff --git a/tfjs-core/src/platforms/platform_node_test.ts b/tfjs-core/src/platforms/platform_node_test.ts index b34a0300bf..8e7d5808fe 100644 --- a/tfjs-core/src/platforms/platform_node_test.ts +++ b/tfjs-core/src/platforms/platform_node_test.ts @@ -16,11 +16,11 @@ */ import * as tf from '../index'; -import {describeWithFlags, NODE_ENVS} from '../jasmine_util'; import * as platform_node from './platform_node'; import {PlatformNode} from './platform_node'; +import * as vm from 'node:vm'; -describeWithFlags('PlatformNode', NODE_ENVS, () => { +describe('PlatformNode', () => { it('fetch should use global.fetch if defined', async () => { const globalFetch = tf.env().global.fetch; @@ -125,4 +125,33 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => { expect(s.length).toBe(6); expect(s).toEqual('Здраво'); }); + + describe('isTypedArray', () => { + let platform: PlatformNode; + beforeEach(() => { + platform = new PlatformNode(); + }); + + it('returns false if not a typed array', () => { + expect(platform.isTypedArray([1, 2, 3])).toBeFalse(); + }); + + for (const typedArrayConstructor of [Float32Array, Int32Array, Uint8Array, + Uint8ClampedArray]) { + it(`returns true if it is a ${typedArrayConstructor.name}`, + () => { + const array = new typedArrayConstructor([1,2,3]); + expect(platform.isTypedArray(array)).toBeTrue(); + }); + } + + it('works on values created in a new node context', async () => { + const array = await new Promise((resolve) => { + const code = `resolve(new Uint8Array([1, 2, 3]));`; + vm.runInNewContext(code, {resolve}); + }); + + expect(platform.isTypedArray(array)).toBeTrue(); + }); + }); }); diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index 16ee267186..f60349b846 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -16,7 +16,7 @@ */ import {env} from './environment'; -import {BackendValues, DataType, TensorLike, TypedArray} from './types'; +import {BackendValues, DataType, RecursiveArray, TensorLike, TypedArray} from './types'; import * as base from './util_base'; export * from './util_base'; export * from './hash_util'; @@ -44,7 +44,7 @@ export function toTypedArray(a: TensorLike, dtype: DataType): TypedArray { throw new Error('Cannot convert a string[] to a TypedArray'); } if (Array.isArray(a)) { - a = base.flatten(a); + a = flatten(a); } if (env().getBool('DEBUG')) { @@ -131,3 +131,57 @@ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string { encoding = encoding || 'utf-8'; return env().platform.decode(bytes, encoding); } + +export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array| + Uint8ClampedArray { + return env().platform.isTypedArray(a); +} + +// NOTE: We explicitly type out what T extends instead of any so that +// util.flatten on a nested array of number doesn't try to infer T as a +// number[][], causing us to explicitly type util.flatten(). +/** + * Flattens an arbitrarily nested array. + * + * ```js + * const a = [[1, 2], [3, 4], [5, [6, [7]]]]; + * const flat = tf.util.flatten(a); + * console.log(flat); + * ``` + * + * @param arr The nested array to flatten. + * @param result The destination array which holds the elements. + * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults + * to false. + * + * @doc {heading: 'Util', namespace: 'util'} + */ +export function +flatten|TypedArray>( + arr: T|RecursiveArray, result: T[] = [], skipTypedArray = false): T[] { + if (result == null) { + result = []; + } + if (typeof arr === 'boolean' || typeof arr === 'number' || + typeof arr === 'string' || base.isPromise(arr) || arr == null || + isTypedArray(arr) && skipTypedArray) { + result.push(arr as T); + } else if (Array.isArray(arr) || isTypedArray(arr)) { + for (let i = 0; i < arr.length; ++i) { + flatten(arr[i], result, skipTypedArray); + } + } else { + let maxIndex = -1; + for (const key of Object.keys(arr)) { + // 0 or positive integer. + if (/^([1-9]+[0-9]*|0)$/.test(key)) { + maxIndex = Math.max(maxIndex, Number(key)); + } + } + for (let i = 0; i <= maxIndex; i++) { + // tslint:disable-next-line: no-unnecessary-type-assertion + flatten((arr as RecursiveArray)[i], result, skipTypedArray); + } + } + return result; +} diff --git a/tfjs-core/src/util_base.ts b/tfjs-core/src/util_base.ts index 6d4ec37779..1d81a0ab77 100644 --- a/tfjs-core/src/util_base.ts +++ b/tfjs-core/src/util_base.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types'; +import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, TensorLike, TypedArray, WebGLData, WebGPUData} from './types'; /** * Shuffles the array in-place using Fisher-Yates algorithm. @@ -167,55 +167,6 @@ export function assertNonNull(a: TensorLike): void { () => `The input to the tensor constructor must be a non-null value.`); } -// NOTE: We explicitly type out what T extends instead of any so that -// util.flatten on a nested array of number doesn't try to infer T as a -// number[][], causing us to explicitly type util.flatten(). -/** - * Flattens an arbitrarily nested array. - * - * ```js - * const a = [[1, 2], [3, 4], [5, [6, [7]]]]; - * const flat = tf.util.flatten(a); - * console.log(flat); - * ``` - * - * @param arr The nested array to flatten. - * @param result The destination array which holds the elements. - * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults - * to false. - * - * @doc {heading: 'Util', namespace: 'util'} - */ -export function -flatten|TypedArray>( - arr: T|RecursiveArray, result: T[] = [], skipTypedArray = false): T[] { - if (result == null) { - result = []; - } - if (typeof arr === 'boolean' || typeof arr === 'number' || - typeof arr === 'string' || isPromise(arr) || arr == null || - isTypedArray(arr) && skipTypedArray) { - result.push(arr as T); - } else if (Array.isArray(arr) || isTypedArray(arr)) { - for (let i = 0; i < arr.length; ++i) { - flatten(arr[i], result, skipTypedArray); - } - } else { - let maxIndex = -1; - for (const key of Object.keys(arr)) { - // 0 or positive integer. - if (/^([1-9]+[0-9]*|0)$/.test(key)) { - maxIndex = Math.max(maxIndex, Number(key)); - } - } - for (let i = 0; i <= maxIndex; i++) { - // tslint:disable-next-line: no-unnecessary-type-assertion - flatten((arr as RecursiveArray)[i], result, skipTypedArray); - } - } - return result; -} - /** * Returns the size (number of elements) of the tensor given its shape. * @@ -527,12 +478,6 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean { return true; } -export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array| - Uint8ClampedArray { - return a instanceof Float32Array || a instanceof Int32Array || - a instanceof Uint8Array || a instanceof Uint8ClampedArray; -} - export function bytesPerElement(dtype: DataType): number { if (dtype === 'float32' || dtype === 'int32') { return 4; diff --git a/yarn.lock b/yarn.lock index 18a53916fc..21c37745cc 100644 --- a/yarn.lock +++ b/yarn.lock @@ -384,10 +384,10 @@ resolved "https://registry.yarnpkg.com/@types/node/-/node-10.17.60.tgz#35f3d6213daed95da7f0f73e75bcc6980e90597b" integrity sha512-F0KIgDJfy2nA3zMLmWGKxcH2ZVEtCZXHHdOQs2gSaQ27+lNeEfGxzkIw90aXswATX7AZ33tahPbzy6KAfUreVw== -"@types/node@^12.7.5": - version "12.20.28" - resolved "https://registry.yarnpkg.com/@types/node/-/node-12.20.28.tgz#4b20048c6052b5f51a8d5e0d2acbf63d5a17e1e2" - integrity sha512-cBw8gzxUPYX+/5lugXIPksioBSbE42k0fZ39p+4yRzfYjN6++eq9kAPdlY9qm+MXyfbk9EmvCYAYRn380sF46w== +"@types/node@^18.11.15": + version "18.11.15" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.15.tgz#de0e1fbd2b22b962d45971431e2ae696643d3f5d" + integrity sha512-VkhBbVo2+2oozlkdHXLrb3zjsRkpdnaU2bXmX8Wgle3PUi569eLRaHGlgETQHR7lLL1w7GiG3h9SnePhxNDecw== "@types/offscreencanvas@^2019.7.0": version "2019.7.0"