diff --git a/tfjs-core/src/platforms/platform_browser.ts b/tfjs-core/src/platforms/platform_browser.ts index 1a1c0d9943..92a6799231 100644 --- a/tfjs-core/src/platforms/platform_browser.ts +++ b/tfjs-core/src/platforms/platform_browser.ts @@ -93,11 +93,16 @@ export class PlatformBrowser implements Platform { isTypedArray(a: unknown): a is Uint8Array | Float32Array | Int32Array | Uint8ClampedArray { - return a instanceof Float32Array || a instanceof Int32Array || - a instanceof Uint8Array || a instanceof Uint8ClampedArray; + return isTypedArrayBrowser(a); } } +export function isTypedArrayBrowser(a: unknown): a is Uint8Array + | Float32Array | Int32Array | Uint8ClampedArray { + return a instanceof Float32Array || a instanceof Int32Array || + a instanceof Uint8Array || a instanceof Uint8ClampedArray; +} + if (env().get('IS_BROWSER')) { env().setPlatform('browser', new PlatformBrowser()); diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index f60349b846..493a4dbc8a 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -16,6 +16,7 @@ */ import {env} from './environment'; +import { isTypedArrayBrowser, PlatformBrowser } from './platforms/platform_browser'; import {BackendValues, DataType, RecursiveArray, TensorLike, TypedArray} from './types'; import * as base from './util_base'; export * from './util_base'; @@ -134,7 +135,8 @@ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string { export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array| Uint8ClampedArray { - return env().platform.isTypedArray(a); + const isTypedArray = env().platform.isTypedArray || isTypedArrayBrowser; + return isTypedArray(a); } // NOTE: We explicitly type out what T extends instead of any so that