From e66bc051d5388ac3765bb5da225034326d455b5c Mon Sep 17 00:00:00 2001 From: Ahmed Sabie Date: Thu, 17 Nov 2022 20:23:03 -0500 Subject: [PATCH] Fix flatten implementation on objects --- tfjs-core/src/util_base.ts | 18 ++++++++++++++++-- tfjs-core/src/util_test.ts | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/util_base.ts b/tfjs-core/src/util_base.ts index f4a6f32d22..f13cdef48a 100644 --- a/tfjs-core/src/util_base.ts +++ b/tfjs-core/src/util_base.ts @@ -192,12 +192,26 @@ flatten|TypedArray>( if (result == null) { result = []; } - if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { + if (typeof arr === 'boolean' || typeof arr === 'number' || + typeof arr === 'string' || isPromise(arr) || arr == null || + isTypedArray(arr) && skipTypedArray) { + result.push(arr as T); + } else if (Array.isArray(arr) || isTypedArray(arr)) { for (let i = 0; i < arr.length; ++i) { flatten(arr[i], result, skipTypedArray); } } else { - result.push(arr as T); + let maxIndex = -1; + for (const key of Object.keys(arr)) { + // 0 or positive integer. + if (/^([1-9]+[0-9]*|0)$/.test(key)) { + maxIndex = Math.max(maxIndex, Number(key)); + } + } + for (let i = 0; i <= maxIndex; i++) { + // tslint:disable-next-line: no-unnecessary-type-assertion + flatten((arr as RecursiveArray)[i], result, skipTypedArray); + } } return result; } diff --git a/tfjs-core/src/util_test.ts b/tfjs-core/src/util_test.ts index 8683e76bdd..9d9f0a229f 100644 --- a/tfjs-core/src/util_test.ts +++ b/tfjs-core/src/util_test.ts @@ -136,6 +136,11 @@ describe('Util', () => { }); describe('util.flatten', () => { + it('empty', () => { + const data: number[] = []; + expect(util.flatten(data)).toEqual([]); + }); + it('nested number arrays', () => { expect(util.flatten([[1, 2, 3], [4, 5, 6]])).toEqual([1, 2, 3, 4, 5, 6]); expect(util.flatten([[[1, 2], [3, 4], [5, 6], [7, 8]]])).toEqual([ @@ -169,6 +174,20 @@ describe('util.flatten', () => { new Uint8Array([7, 8]) ]); }); + + it('Int8Array', () => { + const data = [new Int8Array([1, 2])]; + expect(util.flatten(data)).toEqual([1, 2]); + }); + + it('index signature', () => { + const data: {[index: number]: number} = {0: 1, 1: 2}; + // Will be ignored since array iteration ignores negatives. + data[-1] = -1; + // Will be ignored since non-integer array keys are ignored. + data[3.2] = 4; + expect(util.flatten(data)).toEqual([1, 2]); + }); }); function encodeStrings(a: string[]): Uint8Array[] {