diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 1227f74890..ab78ef7777 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -17,7 +17,7 @@ import './flags_webgpu'; -import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core'; +import {backend_util, BackendValues, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core'; import {AdapterInfo} from './adapter_info'; import {BufferManager} from './buffer_manager'; @@ -46,7 +46,7 @@ export type TextureInfo = { }; type TensorData = { - values: backend_util.BackendValues, + values: BackendValues, dtype: DataType, shape: number[], refCount: number, @@ -290,9 +290,8 @@ export class WebGPUBackend extends KernelBackend { } } - override write( - values: backend_util.BackendValues, shape: number[], - dtype: DataType): DataId { + override write(values: BackendValues, shape: number[], dtype: DataType): + DataId { if (dtype === 'complex64' && values != null) { throw new Error( `Cannot write to a complex64 dtype. ` + @@ -304,8 +303,8 @@ export class WebGPUBackend extends KernelBackend { } override move( - dataId: DataId, values: backend_util.BackendValues, shape: number[], - dtype: DataType, refCount: number): void { + dataId: DataId, values: BackendValues, shape: number[], dtype: DataType, + refCount: number): void { if (dtype === 'complex64') { throw new Error( `Cannot write to a complex64 dtype. ` + @@ -386,8 +385,8 @@ export class WebGPUBackend extends KernelBackend { return values; } - private convertAndCacheOnCPU(dataId: DataId, data: backend_util.TypedArray): - backend_util.TypedArray { + private convertAndCacheOnCPU(dataId: DataId, data: BackendValues): + BackendValues { const tensorData = this.tensorMap.get(dataId); this.releaseResource(dataId); tensorData.values = data; @@ -396,7 +395,7 @@ export class WebGPUBackend extends KernelBackend { // TODO: Remove once this is fixed: // https://github.com/tensorflow/tfjs/issues/1595 - override readSync(dataId: object): backend_util.BackendValues { + override readSync(dataId: object): BackendValues { const tensorData = this.tensorMap.get(dataId); const {values} = tensorData; @@ -408,7 +407,7 @@ export class WebGPUBackend extends KernelBackend { return values; } - override async read(dataId: object): Promise { + override async read(dataId: object): Promise { if (!this.tensorMap.has(dataId)) { throw new Error(`Tensor ${dataId} was not registered!`); } @@ -417,15 +416,11 @@ export class WebGPUBackend extends KernelBackend { const {values} = tensorData; if (values != null) { - // TODO(xing.xu@intel.com): Merge backend_util.BackendValues and - // backend_util.TypedArray. - return this.convertAndCacheOnCPU( - dataId, values as backend_util.TypedArray) as - backend_util.BackendValues; + return this.convertAndCacheOnCPU(dataId, values); } // Download the values from the GPU. - let vals: backend_util.BackendValues; + let vals: BackendValues; if (tensorData.dtype === 'complex64') { const ps = await Promise.all([ this.read(tensorData.complexTensorInfos.real.dataId), @@ -441,7 +436,7 @@ export class WebGPUBackend extends KernelBackend { const data = await this.getBufferData(bufferInfo.buffer, bufferInfo.size); vals = util.convertBackendValuesAndArrayBuffer(data, tensorData.dtype); } - this.convertAndCacheOnCPU(dataId, vals as backend_util.TypedArray); + this.convertAndCacheOnCPU(dataId, vals); return vals; } @@ -604,13 +599,12 @@ export class WebGPUBackend extends KernelBackend { makeTensorInfo( shape: number[], dtype: DataType, - values?: backend_util.BackendValues|string[]): TensorInfo { + values?: BackendValues|string[]): TensorInfo { if (dtype === 'string' && values != null && values.length > 0 && util.isString(values[0])) { values = (values as unknown as string[]).map(d => util.encodeString(d)); } - const dataId = - this.write(values as backend_util.BackendValues, shape, dtype); + const dataId = this.write(values as BackendValues, shape, dtype); return {dataId, shape, dtype}; }