Skip to content

Commit

Permalink
Refactor type conversion for read back (tensorflow#7044)
Browse files Browse the repository at this point in the history
* Refactor type conversion for read back

Bug: tensorflow#6965

* Cleanup

Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
  • Loading branch information
axinging and pyu10055 committed Dec 9, 2022
1 parent e6f19ef commit b92f803
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 47 deletions.
7 changes: 4 additions & 3 deletions tfjs-backend-cpu/src/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ export class MathBackendCPU extends KernelBackend {
this.data = new DataStorage(this, engine());
}

override write(values: backend_util.BackendValues, shape: number[],
override write(
values: backend_util.BackendValues, shape: number[],
dtype: DataType): DataId {
if (this.firstUse) {
this.firstUse = false;
Expand Down Expand Up @@ -138,8 +139,8 @@ export class MathBackendCPU extends KernelBackend {
this.readSync(complexTensorInfos.imag.dataId) as Float32Array;
return backend_util.mergeRealAndImagArrays(realValues, imagValues);
}

return this.data.get(dataId).values;
return util.convertBackendValuesAndArrayBuffer(
this.data.get(dataId).values, dtype);
}

bufferSync<R extends Rank, D extends DataType>(t: TensorInfo):
Expand Down
41 changes: 17 additions & 24 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import './flags_webgpu';

import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core';
import {backend_util, BackendValues, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core';

import {AdapterInfo} from './adapter_info';
import {BufferManager} from './buffer_manager';
Expand Down Expand Up @@ -46,7 +46,7 @@ export type TextureInfo = {
};

type TensorData = {
values: backend_util.BackendValues,
values: BackendValues,
dtype: DataType,
shape: number[],
refCount: number,
Expand Down Expand Up @@ -290,9 +290,8 @@ export class WebGPUBackend extends KernelBackend {
}
}

override write(
values: backend_util.BackendValues, shape: number[],
dtype: DataType): DataId {
override write(values: BackendValues, shape: number[], dtype: DataType):
DataId {
if (dtype === 'complex64' && values != null) {
throw new Error(
`Cannot write to a complex64 dtype. ` +
Expand All @@ -304,8 +303,8 @@ export class WebGPUBackend extends KernelBackend {
}

override move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType, refCount: number): void {
dataId: DataId, values: BackendValues, shape: number[], dtype: DataType,
refCount: number): void {
if (dtype === 'complex64') {
throw new Error(
`Cannot write to a complex64 dtype. ` +
Expand Down Expand Up @@ -357,7 +356,7 @@ export class WebGPUBackend extends KernelBackend {
}

public async getBufferData(buffer: GPUBuffer, size: number):
Promise<backend_util.BackendValues> {
Promise<ArrayBuffer> {
const staging = this.bufferManager.acquireBuffer(
size, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
this.ensureCommandEncoderReady();
Expand All @@ -383,11 +382,11 @@ export class WebGPUBackend extends KernelBackend {
this.dummyContext.getCurrentTexture();
}

return values as backend_util.BackendValues;
return values;
}

private convertAndCacheOnCPU(dataId: DataId, data: backend_util.TypedArray):
backend_util.TypedArray {
private convertAndCacheOnCPU(dataId: DataId, data: BackendValues):
BackendValues {
const tensorData = this.tensorMap.get(dataId);
this.releaseResource(dataId);
tensorData.values = data;
Expand All @@ -396,7 +395,7 @@ export class WebGPUBackend extends KernelBackend {

// TODO: Remove once this is fixed:
// https://github.com/tensorflow/tfjs/issues/1595
override readSync(dataId: object): backend_util.BackendValues {
override readSync(dataId: object): BackendValues {
const tensorData = this.tensorMap.get(dataId);
const {values} = tensorData;

Expand All @@ -408,7 +407,7 @@ export class WebGPUBackend extends KernelBackend {
return values;
}

override async read(dataId: object): Promise<backend_util.BackendValues> {
override async read(dataId: object): Promise<BackendValues> {
if (!this.tensorMap.has(dataId)) {
throw new Error(`Tensor ${dataId} was not registered!`);
}
Expand All @@ -417,15 +416,11 @@ export class WebGPUBackend extends KernelBackend {
const {values} = tensorData;

if (values != null) {
// TODO(xing.xu@intel.com): Merge backend_util.BackendValues and
// backend_util.TypedArray.
return this.convertAndCacheOnCPU(
dataId, values as backend_util.TypedArray) as
backend_util.BackendValues;
return this.convertAndCacheOnCPU(dataId, values);
}

// Download the values from the GPU.
let vals: backend_util.BackendValues;
let vals: BackendValues;
if (tensorData.dtype === 'complex64') {
const ps = await Promise.all([
this.read(tensorData.complexTensorInfos.real.dataId),
Expand All @@ -439,8 +434,7 @@ export class WebGPUBackend extends KernelBackend {
} else {
const bufferInfo = tensorData.resourceInfo as BufferInfo;
const data = await this.getBufferData(bufferInfo.buffer, bufferInfo.size);
vals = webgpu_util.ArrayBufferToTypedArray(
data as ArrayBuffer, tensorData.dtype);
vals = util.convertBackendValuesAndArrayBuffer(data, tensorData.dtype);
}
this.convertAndCacheOnCPU(dataId, vals);
return vals;
Expand Down Expand Up @@ -605,13 +599,12 @@ export class WebGPUBackend extends KernelBackend {

makeTensorInfo(
shape: number[], dtype: DataType,
values?: backend_util.BackendValues|string[]): TensorInfo {
values?: BackendValues|string[]): TensorInfo {
if (dtype === 'string' && values != null && values.length > 0 &&
util.isString(values[0])) {
values = (values as unknown as string[]).map(d => util.encodeString(d));
}
const dataId =
this.write(values as backend_util.BackendValues, shape, dtype);
const dataId = this.write(values as BackendValues, shape, dtype);
return {dataId, shape, dtype};
}

Expand Down
13 changes: 6 additions & 7 deletions tfjs-backend-webgpu/src/backend_webgpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ const {expectArraysEqual, expectArraysClose} = test_util;

import {WebGPUBackend, WebGPUMemoryInfo} from './backend_webgpu';
import {describeWebGPU} from './test_util';
import * as webgpu_util from './webgpu_util';

describeWebGPU('backend webgpu cpu forwarding turned on', () => {
let cpuForwardFlagSaved: boolean;
Expand Down Expand Up @@ -274,8 +273,8 @@ describeWebGPU('keeping data on gpu ', () => {
`Expected: float32`);
}
const resData = await webGPUBackend.getBufferData(res.buffer, res.bufSize);
const values = webgpu_util.ArrayBufferToTypedArray(
resData as ArrayBuffer, res.tensorRef.dtype);
const values = tf.util.convertBackendValuesAndArrayBuffer(
resData, res.tensorRef.dtype);
expectArraysEqual(values, data);
});

Expand All @@ -294,8 +293,8 @@ describeWebGPU('keeping data on gpu ', () => {
`Expected: float32`);
}
const resData = await webGPUBackend.getBufferData(res.buffer, res.bufSize);
const values = webgpu_util.ArrayBufferToTypedArray(
resData as ArrayBuffer, res.tensorRef.dtype);
const values = tf.util.convertBackendValuesAndArrayBuffer(
resData, res.tensorRef.dtype);
expectArraysEqual(values, data);
});

Expand Down Expand Up @@ -340,8 +339,8 @@ describeWebGPU('keeping data on gpu ', () => {

const res = result as unknown as GPUData;
const resData = await webGPUBackend.getBufferData(res.buffer, res.bufSize);
const values = webgpu_util.ArrayBufferToTypedArray(
resData as ArrayBuffer, res.tensorRef.dtype);
const values = tf.util.convertBackendValuesAndArrayBuffer(
resData, res.tensorRef.dtype);
expectArraysEqual(values, data);
});

Expand Down
12 changes: 0 additions & 12 deletions tfjs-backend-webgpu/src/webgpu_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,6 @@ export function GPUBytesPerElement(dtype: DataType): number {
}
}

export function ArrayBufferToTypedArray(data: ArrayBuffer, dtype: DataType) {
if (dtype === 'float32') {
return new Float32Array(data);
} else if (dtype === 'int32') {
return new Int32Array(data);
} else if (dtype === 'bool' || dtype === 'string') {
return Uint8Array.from(new Int32Array(data));
} else {
throw new Error(`Unknown dtype ${dtype}`);
}
}

export function isWebGPUSupported(): boolean {
return ((typeof window !== 'undefined') ||
//@ts-ignore
Expand Down
19 changes: 18 additions & 1 deletion tfjs-core/src/util_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';
import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';

/**
* Shuffles the array in-place using Fisher-Yates algorithm.
Expand Down Expand Up @@ -661,6 +661,23 @@ export function toNestedArray(
return createNestedArray(0, shape, a, isComplex);
}

export function convertBackendValuesAndArrayBuffer(
data: BackendValues|ArrayBuffer, dtype: DataType) {
// If is type Uint8Array[], return it directly.
if (Array.isArray(data)) {
return data;
}
if (dtype === 'float32') {
return data instanceof Float32Array ? data : new Float32Array(data);
} else if (dtype === 'int32') {
return data instanceof Int32Array ? data : new Int32Array(data);
} else if (dtype === 'bool' || dtype === 'string') {
return Uint8Array.from(new Int32Array(data));
} else {
throw new Error(`Unknown dtype ${dtype}`);
}
}

export function makeOnesTypedArray<D extends DataType>(
size: number, dtype: D): DataTypeMap[D] {
const array = makeZerosTypedArray(size, dtype);
Expand Down

0 comments on commit b92f803

Please sign in to comment.