Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor type conversion for read back #7044

Merged
merged 3 commits into from Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions tfjs-backend-cpu/src/backend_cpu.ts
Expand Up @@ -49,7 +49,8 @@ export class MathBackendCPU extends KernelBackend {
this.data = new DataStorage(this, engine());
}

override write(values: backend_util.BackendValues, shape: number[],
override write(
values: backend_util.BackendValues, shape: number[],
dtype: DataType): DataId {
if (this.firstUse) {
this.firstUse = false;
Expand Down Expand Up @@ -138,8 +139,8 @@ export class MathBackendCPU extends KernelBackend {
this.readSync(complexTensorInfos.imag.dataId) as Float32Array;
return backend_util.mergeRealAndImagArrays(realValues, imagValues);
}

return this.data.get(dataId).values;
return util.convertBackendValuesAndArrayBuffer(
this.data.get(dataId).values, dtype);
}

bufferSync<R extends Rank, D extends DataType>(t: TensorInfo):
Expand Down
41 changes: 17 additions & 24 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Expand Up @@ -17,7 +17,7 @@

import './flags_webgpu';

import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core';
import {backend_util, BackendValues, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core';

import {AdapterInfo} from './adapter_info';
import {BufferManager} from './buffer_manager';
Expand Down Expand Up @@ -46,7 +46,7 @@ export type TextureInfo = {
};

type TensorData = {
values: backend_util.BackendValues,
values: BackendValues,
dtype: DataType,
shape: number[],
refCount: number,
Expand Down Expand Up @@ -290,9 +290,8 @@ export class WebGPUBackend extends KernelBackend {
}
}

override write(
values: backend_util.BackendValues, shape: number[],
dtype: DataType): DataId {
override write(values: BackendValues, shape: number[], dtype: DataType):
DataId {
if (dtype === 'complex64' && values != null) {
throw new Error(
`Cannot write to a complex64 dtype. ` +
Expand All @@ -304,8 +303,8 @@ export class WebGPUBackend extends KernelBackend {
}

override move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType, refCount: number): void {
dataId: DataId, values: BackendValues, shape: number[], dtype: DataType,
refCount: number): void {
if (dtype === 'complex64') {
throw new Error(
`Cannot write to a complex64 dtype. ` +
Expand Down Expand Up @@ -357,7 +356,7 @@ export class WebGPUBackend extends KernelBackend {
}

public async getBufferData(buffer: GPUBuffer, size: number):
Promise<backend_util.BackendValues> {
Promise<ArrayBuffer> {
const staging = this.bufferManager.acquireBuffer(
size, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
this.ensureCommandEncoderReady();
Expand All @@ -383,11 +382,11 @@ export class WebGPUBackend extends KernelBackend {
this.dummyContext.getCurrentTexture();
}

return values as backend_util.BackendValues;
return values;
}

private convertAndCacheOnCPU(dataId: DataId, data: backend_util.TypedArray):
backend_util.TypedArray {
private convertAndCacheOnCPU(dataId: DataId, data: BackendValues):
BackendValues {
const tensorData = this.tensorMap.get(dataId);
this.releaseResource(dataId);
tensorData.values = data;
Expand All @@ -396,7 +395,7 @@ export class WebGPUBackend extends KernelBackend {

// TODO: Remove once this is fixed:
// https://github.com/tensorflow/tfjs/issues/1595
override readSync(dataId: object): backend_util.BackendValues {
override readSync(dataId: object): BackendValues {
const tensorData = this.tensorMap.get(dataId);
const {values} = tensorData;

Expand All @@ -408,7 +407,7 @@ export class WebGPUBackend extends KernelBackend {
return values;
}

override async read(dataId: object): Promise<backend_util.BackendValues> {
override async read(dataId: object): Promise<BackendValues> {
if (!this.tensorMap.has(dataId)) {
throw new Error(`Tensor ${dataId} was not registered!`);
}
Expand All @@ -417,15 +416,11 @@ export class WebGPUBackend extends KernelBackend {
const {values} = tensorData;

if (values != null) {
// TODO(xing.xu@intel.com): Merge backend_util.BackendValues and
// backend_util.TypedArray.
return this.convertAndCacheOnCPU(
dataId, values as backend_util.TypedArray) as
backend_util.BackendValues;
return this.convertAndCacheOnCPU(dataId, values);
}

// Download the values from the GPU.
let vals: backend_util.BackendValues;
let vals: BackendValues;
if (tensorData.dtype === 'complex64') {
const ps = await Promise.all([
this.read(tensorData.complexTensorInfos.real.dataId),
Expand All @@ -439,8 +434,7 @@ export class WebGPUBackend extends KernelBackend {
} else {
const bufferInfo = tensorData.resourceInfo as BufferInfo;
const data = await this.getBufferData(bufferInfo.buffer, bufferInfo.size);
vals = webgpu_util.ArrayBufferToTypedArray(
data as ArrayBuffer, tensorData.dtype);
vals = util.convertBackendValuesAndArrayBuffer(data, tensorData.dtype);
}
this.convertAndCacheOnCPU(dataId, vals);
return vals;
Expand Down Expand Up @@ -605,13 +599,12 @@ export class WebGPUBackend extends KernelBackend {

makeTensorInfo(
shape: number[], dtype: DataType,
values?: backend_util.BackendValues|string[]): TensorInfo {
values?: BackendValues|string[]): TensorInfo {
if (dtype === 'string' && values != null && values.length > 0 &&
util.isString(values[0])) {
values = (values as unknown as string[]).map(d => util.encodeString(d));
}
const dataId =
this.write(values as backend_util.BackendValues, shape, dtype);
const dataId = this.write(values as BackendValues, shape, dtype);
return {dataId, shape, dtype};
}

Expand Down
13 changes: 6 additions & 7 deletions tfjs-backend-webgpu/src/backend_webgpu_test.ts
Expand Up @@ -22,7 +22,6 @@ const {expectArraysEqual, expectArraysClose} = test_util;

import {WebGPUBackend, WebGPUMemoryInfo} from './backend_webgpu';
import {describeWebGPU} from './test_util';
import * as webgpu_util from './webgpu_util';

describeWebGPU('backend webgpu cpu forwarding turned on', () => {
let cpuForwardFlagSaved: boolean;
Expand Down Expand Up @@ -274,8 +273,8 @@ describeWebGPU('keeping data on gpu ', () => {
`Expected: float32`);
}
const resData = await webGPUBackend.getBufferData(res.buffer, res.bufSize);
const values = webgpu_util.ArrayBufferToTypedArray(
resData as ArrayBuffer, res.tensorRef.dtype);
const values = tf.util.convertBackendValuesAndArrayBuffer(
resData, res.tensorRef.dtype);
expectArraysEqual(values, data);
});

Expand All @@ -294,8 +293,8 @@ describeWebGPU('keeping data on gpu ', () => {
`Expected: float32`);
}
const resData = await webGPUBackend.getBufferData(res.buffer, res.bufSize);
const values = webgpu_util.ArrayBufferToTypedArray(
resData as ArrayBuffer, res.tensorRef.dtype);
const values = tf.util.convertBackendValuesAndArrayBuffer(
resData, res.tensorRef.dtype);
expectArraysEqual(values, data);
});

Expand Down Expand Up @@ -340,8 +339,8 @@ describeWebGPU('keeping data on gpu ', () => {

const res = result as unknown as GPUData;
const resData = await webGPUBackend.getBufferData(res.buffer, res.bufSize);
const values = webgpu_util.ArrayBufferToTypedArray(
resData as ArrayBuffer, res.tensorRef.dtype);
const values = tf.util.convertBackendValuesAndArrayBuffer(
resData, res.tensorRef.dtype);
expectArraysEqual(values, data);
});

Expand Down
12 changes: 0 additions & 12 deletions tfjs-backend-webgpu/src/webgpu_util.ts
Expand Up @@ -154,18 +154,6 @@ export function GPUBytesPerElement(dtype: DataType): number {
}
}

export function ArrayBufferToTypedArray(data: ArrayBuffer, dtype: DataType) {
if (dtype === 'float32') {
return new Float32Array(data);
} else if (dtype === 'int32') {
return new Int32Array(data);
} else if (dtype === 'bool' || dtype === 'string') {
return Uint8Array.from(new Int32Array(data));
} else {
throw new Error(`Unknown dtype ${dtype}`);
}
}

export function isWebGPUSupported(): boolean {
return ((typeof window !== 'undefined') ||
//@ts-ignore
Expand Down
19 changes: 18 additions & 1 deletion tfjs-core/src/util_base.ts
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';
import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';

/**
* Shuffles the array in-place using Fisher-Yates algorithm.
Expand Down Expand Up @@ -661,6 +661,23 @@ export function toNestedArray(
return createNestedArray(0, shape, a, isComplex);
}

export function convertBackendValuesAndArrayBuffer(
data: BackendValues|ArrayBuffer, dtype: DataType) {
// If is type Uint8Array[], return it directly.
if (Array.isArray(data)) {
return data;
}
if (dtype === 'float32') {
return data instanceof Float32Array ? data : new Float32Array(data);
} else if (dtype === 'int32') {
return data instanceof Int32Array ? data : new Int32Array(data);
} else if (dtype === 'bool' || dtype === 'string') {
return Uint8Array.from(new Int32Array(data));
} else {
throw new Error(`Unknown dtype ${dtype}`);
}
}

export function makeOnesTypedArray<D extends DataType>(
size: number, dtype: D): DataTypeMap[D] {
const array = makeZerosTypedArray(size, dtype);
Expand Down