diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index b01b295e0f..3048694276 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -1304,8 +1304,9 @@ export class MathBackendWebGL extends KernelBackend { * Create a TF.js tensor out of an existing WebGL texture. A new texture will * be created. */ - override createTensorFromTexture(values: WebGLData, shape: number[], - dtype: DataType): Tensor { + override createTensorFromGPUData( + values: WebGLData, shape: number[], dtype: DataType): Tensor { + values.channels = values.channels || 'RGBA'; const {texture, height, width, channels} = values; const backend = engine().backend as MathBackendWebGL; diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 1b296a6b8d..fbe1c7ebf2 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -17,7 +17,7 @@ import './flags_webgpu'; -import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util} from '@tensorflow/tfjs-core'; +import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, Tensor, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util, WebGPUData} from '@tensorflow/tfjs-core'; import {AdapterInfo} from './adapter_info'; import {BufferManager} from './buffer_manager'; @@ -51,6 +51,9 @@ type TensorData = { shape: number[], refCount: number, resourceInfo?: BufferInfo|TextureInfo, + // external is true means we use the resource provided by users directly + // (without a copy), so users should be responsible for its release. + external?: boolean, // For complex numbers, the real and imaginary parts are stored as their own // individual tensors, with a parent joining the two with the // complexTensorInfos field. @@ -242,6 +245,11 @@ export class WebGPUBackend extends KernelBackend { if (!tensorData || !tensorData.resourceInfo) { return; } + // If tensor's resource is from external, do not release. + if (tensorData.external) { + tensorData.resourceInfo = null; + return; + } if ('texture' in tensorData.resourceInfo) { const textureInfo = tensorData.resourceInfo; if (textureInfo.texture instanceof GPUTexture) { @@ -282,7 +290,8 @@ export class WebGPUBackend extends KernelBackend { } } - override write(values: backend_util.BackendValues, shape: number[], + override write( + values: backend_util.BackendValues, shape: number[], dtype: DataType): DataId { if (dtype === 'complex64' && values != null) { throw new Error( @@ -437,6 +446,53 @@ export class WebGPUBackend extends KernelBackend { return vals; } + // The source GPUBuffer and destination GPUBuffer have the same size and + // usage. + private copyBuffer(srcBuffer: GPUBuffer, size: number, usage: number) { + const dstBuffer = this.bufferManager.acquireBuffer(size, usage); + this.ensureCommandEncoderReady(); + this.ensureComputePassEnded(); + this.currentCommandEncoder.copyBufferToBuffer( + srcBuffer, 0, dstBuffer, 0, size); + this.submitQueue(); + return dstBuffer; + } + + /** + * Create a TF.js tensor out of an existing WebGPU buffer. + */ + override createTensorFromGPUData( + values: WebGPUData, shape: number[], dtype: DataType): Tensor { + let buffer = values.buffer; + if (dtype === 'complex64') { + throw new Error(`Cannot write to a complex64 dtype. `); + } + const dataId = {id: this.nextDataId()}; + this.tensorMap.set( + dataId, + {dtype, shape, values: null, refCount: 1, external: values.zeroCopy}); + const tensorData = this.tensorMap.get(dataId); + const size = webgpu_util.GPUBytesPerElement(tensorData.dtype) * + util.sizeFromShape(tensorData.shape); + if (values.buffer.size < size) { + throw new Error(`GPUBuffer size(${ + values.buffer.size}) is smaller than tensor size(${size})!`); + } else if ( + (values.buffer.usage & + (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC)) !== + (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC)) { + throw new Error( + 'GPUBuffer.usage should include GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC!'); + } + + // Do buffer copy by default. + if (values.zeroCopy !== true) { + buffer = this.copyBuffer(buffer, size, buffer.usage); + } + tensorData.resourceInfo = {size: buffer.size, usage: buffer.usage, buffer}; + return engine().makeTensorFromDataId(dataId, shape, dtype, this); + } + /** * Read tensor to a new GPUBuffer. * @param dataId The source tensor. @@ -623,9 +679,8 @@ export class WebGPUBackend extends KernelBackend { // TODO: WebGPU doesn't support read data synchronously from GPU to CPU. // So it will report error when switching backend from WebGPU to others. // There are two situations: 1) swithcing the backend after running a - // model; 2) swithcing the backend within the model. Temporarilly keep the - // values on CPU to solve the first issue. - // tensorData.values = null; + // model; 2) swithcing the backend within the model. Temporarilly keep + // the values on CPU to solve the first issue. tensorData.values = null; } } diff --git a/tfjs-backend-webgpu/src/backend_webgpu_test.ts b/tfjs-backend-webgpu/src/backend_webgpu_test.ts index c34c9e69d7..d41b241ba3 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu_test.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu_test.ts @@ -366,3 +366,207 @@ describeWebGPU('keeping data on gpu ', () => { expect(endDataBuckets).toEqual(startDataBuckets + 1); }); }); + +function createStagingGPUBufferFromData( + device: GPUDevice, data: number[], dtype: tf.DataType) { + const bytesPerElement = 4; + const sizeInBytes = data.length * bytesPerElement; + + const gpuWriteBuffer = device.createBuffer({ + mappedAtCreation: true, + size: sizeInBytes, + usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC + }); + const arrayBuffer = gpuWriteBuffer.getMappedRange(); + if (dtype === 'float32') { + new Float32Array(arrayBuffer).set(data); + } else if (dtype === 'int32') { + new Int32Array(arrayBuffer).set(data); + } else { + throw new Error( + `Creating tensor from GPUBuffer only supports` + + `'float32'|'int32' dtype, while the dtype is ${dtype}.`); + } + gpuWriteBuffer.unmap(); + return gpuWriteBuffer; +} + +function createGPUBufferFromData( + device: GPUDevice, data: number[], dtype: tf.DataType, + bufferUsage = GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE | + GPUBufferUsage.COPY_SRC) { + const bytesPerElement = 4; + const sizeInBytes = data.length * bytesPerElement; + + const gpuWriteBuffer = createStagingGPUBufferFromData(device, data, dtype); + const gpuReadBuffer = device.createBuffer( + {mappedAtCreation: false, size: sizeInBytes, usage: bufferUsage}); + + const copyEncoder = device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + gpuWriteBuffer, 0, gpuReadBuffer, 0, sizeInBytes); + const copyCommands = copyEncoder.finish(); + device.queue.submit([copyCommands]); + gpuWriteBuffer.destroy(); + return gpuReadBuffer; +} + +async function testCreateTensorFromGPUBuffer( + dtype: tf.DataType, useDefaultShapeAndType = false, zeroCopy = false) { + const webGPUBackend = tf.backend() as WebGPUBackend; + const device = webGPUBackend.device; + const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + const bData = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]; + const expected = [2, 4, 6, 8, 6, 8, 10, 12, 10, 12, 14, 16, 14, 16, 18, 20]; + const aBuffer = createGPUBufferFromData(device, aData, dtype); + const shape: number[] = [aData.length]; + const startNumBytes = tf.memory().numBytes; + const startNumTensors = tf.memory().numTensors; + const webGPUData = {buffer: aBuffer, zeroCopy}; + const a = useDefaultShapeAndType ? tf.tensor(webGPUData) : + tf.tensor(webGPUData, shape, dtype); + if (zeroCopy !== true) { + aBuffer.destroy(); + } + const b = tf.tensor(bData, shape, dtype); + const result = tf.add(a, b); + tf.test_util.expectArraysClose(await result.data(), expected); + a.dispose(); + b.dispose(); + result.dispose(); + const endNumBytes = tf.memory().numBytes; + const endNumTensors = tf.memory().numTensors; + expect(endNumBytes - startNumBytes).toEqual(0); + expect(endNumTensors - startNumTensors).toEqual(0); + if (zeroCopy === true) { + aBuffer.destroy(); + } +} + +function createTensorFromGPUTest(zeroCopy = false) { + it('use default shape and data type(float32)', async () => { + await testCreateTensorFromGPUBuffer('float32', true, zeroCopy); + }); + + it('work for float32', async () => { + await testCreateTensorFromGPUBuffer('float32', false, zeroCopy); + }); + + it('work for int32', async () => { + await testCreateTensorFromGPUBuffer('int32', false, zeroCopy); + }); + + it('work for read', async () => { + const webGPUBackend = tf.backend() as WebGPUBackend; + const device = webGPUBackend.device; + const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + const dtype = 'float32'; + const aBuffer = createGPUBufferFromData(device, aData, dtype); + const shape: number[] = [aData.length]; + const a = tf.tensor({buffer: aBuffer, zeroCopy}, shape, dtype); + if (zeroCopy !== true) { + aBuffer.destroy(); + } + await a.data(); + if (zeroCopy === true) { + aBuffer.destroy(); + } + }); + + it('two tensors share the same GPUBuffer', async () => { + const webGPUBackend = tf.backend() as WebGPUBackend; + const device = webGPUBackend.device; + const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + const dtype = 'float32'; + const aBuffer = createGPUBufferFromData(device, aData, dtype); + const startNumBytes = tf.memory().numBytes; + const startNumTensors = tf.memory().numTensors; + const shape: number[] = [aData.length]; + const webGPUData = {buffer: aBuffer, zeroCopy}; + const a = tf.tensor(webGPUData, shape, dtype); + const b = tf.tensor(webGPUData, shape, dtype); + if (zeroCopy !== true) { + aBuffer.destroy(); + } + const result = tf.add(a, b); + const expected = + [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32]; + tf.test_util.expectArraysClose(await result.data(), expected); + a.dispose(); + b.dispose(); + result.dispose(); + const endNumBytes = tf.memory().numBytes; + const endNumTensors = tf.memory().numTensors; + expect(endNumBytes - startNumBytes).toEqual(0); + expect(endNumTensors - startNumTensors).toEqual(0); + if (zeroCopy === true) { + aBuffer.destroy(); + } + }); + + it('GPUBuffer size is bigger than tensor size', async () => { + const webGPUBackend = tf.backend() as WebGPUBackend; + const device = webGPUBackend.device; + const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + const dtype = 'float32'; + const aBuffer = createGPUBufferFromData(device, aData, dtype); + const startNumBytes = tf.memory().numBytes; + const startNumTensors = tf.memory().numTensors; + // GPUBuffer.size is bigger than shape size + const shape: number[] = [aData.length - 1]; + const webGPUData = {buffer: aBuffer, zeroCopy}; + const a = tf.tensor(webGPUData, shape, dtype); + const b = tf.tensor(webGPUData, shape, dtype); + if (zeroCopy !== true) { + aBuffer.destroy(); + } + const result = tf.add(a, b); + const expected = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]; + tf.test_util.expectArraysClose(await result.data(), expected); + a.dispose(); + b.dispose(); + result.dispose(); + const endNumBytes = tf.memory().numBytes; + const endNumTensors = tf.memory().numTensors; + expect(endNumBytes - startNumBytes).toEqual(0); + expect(endNumTensors - startNumTensors).toEqual(0); + if (zeroCopy === true) { + aBuffer.destroy(); + } + }); + + it('throw when GPUBuffer size is smaller than tensor size', async () => { + const webGPUBackend = tf.backend() as WebGPUBackend; + const device = webGPUBackend.device; + const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + const dtype = 'float32'; + const aBuffer = createGPUBufferFromData(device, aData, dtype); + // Throw when GPUBuffer.size is smaller than shape size + const shape: number[] = [aData.length + 1]; + const a = () => tf.tensor({buffer: aBuffer}, shape, dtype); + expect(a).toThrowError(); + aBuffer.destroy(); + }); + + it('throw when GPUBuffer usage is not correct', async () => { + const webGPUBackend = tf.backend() as WebGPUBackend; + const device = webGPUBackend.device; + const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + const dtype = 'float32'; + // Create a GPUBuffer without GPUBufferUsage.STORAGE. + const aBuffer = createStagingGPUBufferFromData(device, aData, dtype); + // Throw when GPUBuffer usage is not correct. + const shape: number[] = [aData.length]; + const a = () => tf.tensor({buffer: aBuffer, zeroCopy}, shape, dtype); + expect(a).toThrowError(); + aBuffer.destroy(); + }); +} + +describeWebGPU('create tensor from GPUBuffer', () => { + createTensorFromGPUTest(); +}); + +describeWebGPU('create tensor from GPUBuffer with zero copy', () => { + createTensorFromGPUTest(true); +}); diff --git a/tfjs-backend-webgpu/src/flags_webgpu.ts b/tfjs-backend-webgpu/src/flags_webgpu.ts index f639b67e4d..49ab70ffba 100644 --- a/tfjs-backend-webgpu/src/flags_webgpu.ts +++ b/tfjs-backend-webgpu/src/flags_webgpu.ts @@ -76,8 +76,7 @@ ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_DEBUG', () => false); * are dispatched, it means the hardware may be in low occupancy. * 0 means it's not set by the user. A default strategy will be applied. */ -ENV.registerFlag( - 'WEBGPU_THRESHOLD_TO_INCREASE_WORKGROUPS_FOR_MATMUL', () => 0); +ENV.registerFlag('WEBGPU_THRESHOLD_TO_INCREASE_WORKGROUPS_FOR_MATMUL', () => 0); /** * Whether we will run im2col as a separate shader for convolution. diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index 33626bd31d..6f23ec0a3e 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -17,7 +17,7 @@ import {Backend, DataToGPUOptions, GPUData, Tensor} from '../tensor'; import {DataId} from '../tensor_info'; -import {BackendValues, DataType, WebGLData} from '../types'; +import {BackendValues, DataType, WebGLData, WebGPUData} from '../types'; export const EPSILON_FLOAT32 = 1e-7; export const EPSILON_FLOAT16 = 1e-4; @@ -133,10 +133,12 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { refCount: number): void { return notYetImplemented('move'); } - createTensorFromTexture(values: WebGLData, shape: number[], dtype: DataType): - Tensor { - return notYetImplemented('createTensorFromTexture'); + + createTensorFromGPUData( + values: WebGLData|WebGPUData, shape: number[], dtype: DataType): Tensor { + return notYetImplemented('createTensorFromGPUData'); } + memory(): {unreliable: boolean; reasons?: string[]} { return notYetImplemented('memory'); } diff --git a/tfjs-core/src/base.ts b/tfjs-core/src/base.ts index c8a5ef7419..e869ede5d2 100644 --- a/tfjs-core/src/base.ts +++ b/tfjs-core/src/base.ts @@ -55,7 +55,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; export {SGDOptimizer} from './optimizers/sgd_optimizer'; export {DataToGPUOptions, DataToGPUWebGLOption, GPUData, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor'; export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types'; -export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType, WebGLData} from './types'; +export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType, WebGLData, WebGPUData} from './types'; export * from './ops/ops'; export {Reduction} from './ops/loss_ops_utils'; diff --git a/tfjs-core/src/ops/tensor.ts b/tfjs-core/src/ops/tensor.ts index cf9933251c..1a6be14cba 100644 --- a/tfjs-core/src/ops/tensor.ts +++ b/tfjs-core/src/ops/tensor.ts @@ -18,7 +18,7 @@ import {Tensor} from '../tensor'; import {inferShape} from '../tensor_util_env'; import {TensorLike} from '../types'; -import {DataType, Rank, ShapeMap, WebGLData} from '../types'; +import {DataType, Rank, ShapeMap, WebGLData, WebGPUData} from '../types'; import {makeTensor} from './tensor_ops_util'; @@ -92,20 +92,97 @@ import {makeTensor} from './tensor_ops_util'; * * const tex = a.dataToGPU(); * ``` + * + * ```js + * // Pass a `WebGPUData` object and specify a shape yourself. + * + * // This makes it possible for TF.js applications to avoid GPU / CPU sync. + * // For example, if your application includes a preprocessing step on the GPU, + * // you could upload the GPU output directly to TF.js, rather than first + * // downloading the values. Unlike WebGL, this optionally supports zero copy + * // by WebGPUData.zeroCopy. When zeroCopy is false or undefined(default), this + * // passing GPUBuffer can be destroyed after tensor is created. When zeroCopy + * // is true, this GPUBuffer is bound directly by the tensor, so do not destroy + * // this GPUBuffer until all access is done. + * + * // Example for WebGPU: + * function createGPUBufferFromData(device, data, dtype) { + * const bytesPerElement = 4; + * const sizeInBytes = data.length * bytesPerElement; + * + * const gpuWriteBuffer = device.createBuffer({ + * mappedAtCreation: true, + * size: sizeInBytes, + * usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC + * }); + * const arrayBuffer = gpuWriteBuffer.getMappedRange(); + * if (dtype === 'float32') { + * new Float32Array(arrayBuffer).set(data); + * } else if (dtype === 'int32') { + * new Int32Array(arrayBuffer).set(data); + * } else { + * throw new Error( + * `Creating tensor from GPUBuffer only supports` + + * `'float32'|'int32' dtype, while the dtype is ${dtype}.`); + * } + * gpuWriteBuffer.unmap(); + * + * const gpuReadBuffer = device.createBuffer({ + * mappedAtCreation: false, + * size: sizeInBytes, + * usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE | + * GPUBufferUsage.COPY_SRC + * }); + * + * const copyEncoder = device.createCommandEncoder(); + * copyEncoder.copyBufferToBuffer( + * gpuWriteBuffer, 0, gpuReadBuffer, 0, sizeInBytes); + * const copyCommands = copyEncoder.finish(); + * device.queue.submit([copyCommands]); + * gpuWriteBuffer.destroy(); + * return gpuReadBuffer; + * } + * + * const dtype = 'float32'; + * const device = tf.backend().device; + * const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + * const bData = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]; + * const expected = [2, 4, 6, 8, 6, 8, 10, 12, 10, 12, 14, 16, 14, 16, 18, 20]; + * const aBuffer = createGPUBufferFromData(device, aData, dtype); + * const shape = [aData.length]; + * // To use zeroCopy, use {buffer: aBuffer, zeroCopy: true} instead and destroy + * // aBuffer untill all access is done. + * const a = tf.tensor({buffer: aBuffer}, shape, dtype); + * const b = tf.tensor(bData, shape, dtype); + * const result = tf.add(a, b); + * a.dispose(); + * b.dispose(); + * result.dispose(); + * aBuffer.destroy(); + * ``` * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`, or a `WebGLData` object. If the - * values are strings, they will be encoded as utf-8 and kept as `Uint8Array[]`. - * If the values is a `WebGLData` object, the dtype could only be 'float32' or - * 'int32' and the object has to have: 1. texture, a `WebGLTexture`, the texture - * must share the same `WebGLRenderingContext` with TFJS's WebGL backend (you - * could create a custom WebGL backend from your texture's canvas) and the - * internal texture format for the input texture must be floating point or - * normalized integer; 2. height, the height of the texture; 3. width, the width - * of the texture; 4. channels, a non-empty subset of 'RGBA', indicating the - * values of which channels will be passed to the tensor, such as 'R' or 'BR' - * (The order of the channels affect the order of tensor values. ). (If the - * values passed from texture is less than the tensor size, zeros will be padded - * at the rear.) + * or a flat array, or a `TypedArray`, or a `WebGLData` object, or a + * `WebGPUData` object. If the values are strings, they will be encoded as utf-8 + * and kept as `Uint8Array[]`. If the values is a `WebGLData` object, the dtype + * could only be 'float32' or 'int32' and the object has to have: 1. texture, a + * `WebGLTexture`, the texture must share the same `WebGLRenderingContext` with + * TFJS's WebGL backend (you could create a custom WebGL backend from your + * texture's canvas) and the internal texture format for the input texture must + * be floating point or normalized integer; 2. height, the height of the + * texture; 3. width, the width of the texture; 4. channels, a non-empty subset + * of 'RGBA', indicating the values of which channels will be passed to the + * tensor, such as 'R' or 'BR' (The order of the channels affect the order of + * tensor values. ). (If the values passed from texture is less than the tensor + * size, zeros will be padded at the rear.). If the values is a `WebGPUData` + * object, the dtype could only be 'float32' or 'int32 and the object has to + * have: buffer, a `GPUBuffer`. The buffer must: 1. share the same `GPUDevice` + * with TFJS's WebGPU backend; 2. buffer.usage should at least support + * GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC; 3. buffer.size should not + * be smaller than the byte size of tensor shape. WebGPUData optionally supports + * zero copy by flag zeroCopy. When zeroCopy is false or undefined(default), + * this passing GPUBuffer can be destroyed after tensor is created. When + * zeroCopy is true, this GPUBuffer is bound directly by the tensor, so do not + * destroy this GPUBuffer until all access is done. * @param shape The shape of the tensor. Optional. If not provided, * it is inferred from `values`. * @param dtype The data type. @@ -113,7 +190,7 @@ import {makeTensor} from './tensor_ops_util'; * @doc {heading: 'Tensors', subheading: 'Creation'} */ export function tensor( - values: TensorLike|WebGLData, shape?: ShapeMap[R], + values: TensorLike|WebGLData|WebGPUData, shape?: ShapeMap[R], dtype?: DataType): Tensor { const inferredShape = inferShape(values, dtype); return makeTensor(values, shape, inferredShape, dtype) as Tensor; diff --git a/tfjs-core/src/ops/tensor_ops_util.ts b/tfjs-core/src/ops/tensor_ops_util.ts index 1b497d4ceb..197ccf1e30 100644 --- a/tfjs-core/src/ops/tensor_ops_util.ts +++ b/tfjs-core/src/ops/tensor_ops_util.ts @@ -17,32 +17,34 @@ import {ENGINE} from '../engine'; import {Tensor} from '../tensor'; -import {TensorLike, TypedArray, WebGLData} from '../types'; +import {TensorLike, TypedArray, WebGLData, WebGPUData} from '../types'; import {DataType} from '../types'; import {assert, assertNonNegativeIntegerDimensions, flatten, inferDtype, isTypedArray, sizeFromShape, toTypedArray} from '../util'; /** This is shared code across all tensor creation methods. */ export function makeTensor( - values: TensorLike|WebGLData, shape: number[], inferredShape: number[], - dtype?: DataType): Tensor { + values: TensorLike|WebGLData|WebGPUData, shape: number[], + inferredShape: number[], dtype?: DataType): Tensor { if (dtype == null) { dtype = inferDtype(values); - } - if (dtype === 'complex64') { + } else if (dtype === 'complex64') { throw new Error( `Cannot construct a complex64 tensor directly. ` + `Please use tf.complex(real, imag).`); } - if (typeof values === 'object' && 'texture' in values) { + + if (typeof values === 'object' && + ('texture' in values || + ('buffer' in values && !(values.buffer instanceof ArrayBuffer)))) { if (dtype !== 'float32' && dtype !== 'int32') { throw new Error( - `Creating tensor from texture only supports ` + + `Creating tensor from GPU data only supports ` + `'float32'|'int32' dtype, while the dtype is ${dtype}.`); } - values.channels = values.channels || 'RGBA'; - return ENGINE.backend.createTensorFromTexture( - values, shape || inferredShape, dtype); + return ENGINE.backend.createTensorFromGPUData( + values as WebGLData | WebGPUData, shape || inferredShape, dtype); } + if (!isTypedArray(values) && !Array.isArray(values) && typeof values !== 'number' && typeof values !== 'boolean' && typeof values !== 'string') { diff --git a/tfjs-core/src/tensor_util_env.ts b/tfjs-core/src/tensor_util_env.ts index 139257d491..e7be429742 100644 --- a/tfjs-core/src/tensor_util_env.ts +++ b/tfjs-core/src/tensor_util_env.ts @@ -18,19 +18,25 @@ import {ENGINE} from './engine'; import {env} from './environment'; import {Tensor} from './tensor'; -import {DataType, TensorLike, WebGLData} from './types'; +import {DataType, TensorLike, WebGLData, WebGPUData} from './types'; import {assert, flatten, inferDtype, isTypedArray, toTypedArray} from './util'; +import {bytesPerElement} from './util_base'; export function inferShape( - val: TensorLike|WebGLData, dtype?: DataType): number[] { + val: TensorLike|WebGLData|WebGPUData, dtype?: DataType): number[] { let firstElem: typeof val = val; if (isTypedArray(val)) { return dtype === 'string' ? [] : [val.length]; } - if (typeof val === 'object' && 'texture' in val) { - const usedChannels = val.channels || 'RGBA'; - return [val.height, val.width * usedChannels.length]; + const isObject = typeof val === 'object'; + if (isObject) { + if ('texture' in val) { + const usedChannels = val.channels || 'RGBA'; + return [val.height, val.width * usedChannels.length]; + } else if ('buffer' in val && !(val.buffer instanceof ArrayBuffer)) { + return [val.buffer.size / (dtype == null ? 4 : bytesPerElement(dtype))]; + } } if (!Array.isArray(val)) { return []; // Scalar. diff --git a/tfjs-core/src/types.ts b/tfjs-core/src/types.ts index 7e416e3c81..2d3fe88dda 100644 --- a/tfjs-core/src/types.ts +++ b/tfjs-core/src/types.ts @@ -182,3 +182,17 @@ export interface WebGLData { width: number; channels: WebGLChannels; } + +/** + * Type for representing a buffer data to create a tensor. Buffer usage should + * at least support GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC. When + * zeroCopy is false or undefined (default), this GPUBuffer will be copied to + * the tensor's resource buffer. When zeroCopy is true, tensor will use this + * GPUBuffer as tensor's resource buffer, user should not destroy this GPUBuffer + * until all access is done. If not specified at creating a tensor, tensor type + * is float32. + */ +export interface WebGPUData { + buffer: GPUBuffer; + zeroCopy?: boolean; +} diff --git a/tfjs-core/src/util_base.ts b/tfjs-core/src/util_base.ts index f4a6f32d22..132cc713d3 100644 --- a/tfjs-core/src/util_base.ts +++ b/tfjs-core/src/util_base.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData} from './types'; +import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types'; /** * Shuffles the array in-place using Fisher-Yates algorithm. @@ -559,7 +559,7 @@ export function isNumber(value: {}): boolean { return typeof value === 'number'; } -export function inferDtype(values: TensorLike|WebGLData): DataType { +export function inferDtype(values: TensorLike|WebGLData|WebGPUData): DataType { if (Array.isArray(values)) { return inferDtype(values[0]); }