Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Dec 7, 2022
1 parent 69a434e commit a74bd16
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Expand Up @@ -17,7 +17,7 @@

import './flags_webgpu';

import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core';
import {backend_util, BackendValues, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core';

import {AdapterInfo} from './adapter_info';
import {BufferManager} from './buffer_manager';
Expand Down Expand Up @@ -46,7 +46,7 @@ export type TextureInfo = {
};

type TensorData = {
values: backend_util.BackendValues,
values: BackendValues,
dtype: DataType,
shape: number[],
refCount: number,
Expand Down Expand Up @@ -290,9 +290,8 @@ export class WebGPUBackend extends KernelBackend {
}
}

override write(
values: backend_util.BackendValues, shape: number[],
dtype: DataType): DataId {
override write(values: BackendValues, shape: number[], dtype: DataType):
DataId {
if (dtype === 'complex64' && values != null) {
throw new Error(
`Cannot write to a complex64 dtype. ` +
Expand All @@ -304,8 +303,8 @@ export class WebGPUBackend extends KernelBackend {
}

override move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType, refCount: number): void {
dataId: DataId, values: BackendValues, shape: number[], dtype: DataType,
refCount: number): void {
if (dtype === 'complex64') {
throw new Error(
`Cannot write to a complex64 dtype. ` +
Expand Down Expand Up @@ -386,8 +385,8 @@ export class WebGPUBackend extends KernelBackend {
return values;
}

private convertAndCacheOnCPU(dataId: DataId, data: backend_util.TypedArray):
backend_util.TypedArray {
private convertAndCacheOnCPU(dataId: DataId, data: BackendValues):
BackendValues {
const tensorData = this.tensorMap.get(dataId);
this.releaseResource(dataId);
tensorData.values = data;
Expand All @@ -396,7 +395,7 @@ export class WebGPUBackend extends KernelBackend {

// TODO: Remove once this is fixed:
// https://github.com/tensorflow/tfjs/issues/1595
override readSync(dataId: object): backend_util.BackendValues {
override readSync(dataId: object): BackendValues {
const tensorData = this.tensorMap.get(dataId);
const {values} = tensorData;

Expand All @@ -408,7 +407,7 @@ export class WebGPUBackend extends KernelBackend {
return values;
}

override async read(dataId: object): Promise<backend_util.BackendValues> {
override async read(dataId: object): Promise<BackendValues> {
if (!this.tensorMap.has(dataId)) {
throw new Error(`Tensor ${dataId} was not registered!`);
}
Expand All @@ -417,15 +416,11 @@ export class WebGPUBackend extends KernelBackend {
const {values} = tensorData;

if (values != null) {
// TODO(xing.xu@intel.com): Merge backend_util.BackendValues and
// backend_util.TypedArray.
return this.convertAndCacheOnCPU(
dataId, values as backend_util.TypedArray) as
backend_util.BackendValues;
return this.convertAndCacheOnCPU(dataId, values);
}

// Download the values from the GPU.
let vals: backend_util.BackendValues;
let vals: BackendValues;
if (tensorData.dtype === 'complex64') {
const ps = await Promise.all([
this.read(tensorData.complexTensorInfos.real.dataId),
Expand All @@ -441,7 +436,7 @@ export class WebGPUBackend extends KernelBackend {
const data = await this.getBufferData(bufferInfo.buffer, bufferInfo.size);
vals = util.convertBackendValuesAndArrayBuffer(data, tensorData.dtype);
}
this.convertAndCacheOnCPU(dataId, vals as backend_util.TypedArray);
this.convertAndCacheOnCPU(dataId, vals);
return vals;
}

Expand Down Expand Up @@ -604,13 +599,12 @@ export class WebGPUBackend extends KernelBackend {

makeTensorInfo(
shape: number[], dtype: DataType,
values?: backend_util.BackendValues|string[]): TensorInfo {
values?: BackendValues|string[]): TensorInfo {
if (dtype === 'string' && values != null && values.length > 0 &&
util.isString(values[0])) {
values = (values as unknown as string[]).map(d => util.encodeString(d));
}
const dataId =
this.write(values as backend_util.BackendValues, shape, dtype);
const dataId = this.write(values as BackendValues, shape, dtype);
return {dataId, shape, dtype};
}

Expand Down

0 comments on commit a74bd16

Please sign in to comment.