Implement Queue::write_buffer_with #2777

Merged
merged 4 commits on Jun 28, 2022
Changes from 1 commit
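
For context, `Queue::write_buffer_with` is the user-facing entry point this PR backs. A rough caller-side sketch follows; the exact wgpu signature and the mapped-view behavior are assumptions based on the PR title, not part of this diff:

```rust
// Hypothetical caller-side sketch of Queue::write_buffer_with. It hands the
// caller a mapped staging view to fill, instead of copying from a &[u8] as
// Queue::write_buffer does. `queue`, `buffer`, and `payload` are stand-ins.
let size = wgpu::BufferSize::new(payload.len() as u64).unwrap();
let mut view = queue.write_buffer_with(
    &buffer, // destination buffer; must have the COPY_DST usage
    0,       // destination offset in bytes
    size,    // number of bytes to write
);
view.copy_from_slice(&payload); // write straight into staging memory
drop(view); // on drop, the staged bytes are queued to be copied into `buffer`
```

Compared to `write_buffer`, this lets the caller generate data directly in staging memory rather than building a temporary slice first.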
240 changes: 171 additions & 69 deletions wgpu-core/src/device/queue.rs
@@ -8,10 +8,10 @@ use crate::{
conv,
device::{DeviceError, WaitIdleError},
get_lowest_common_denom,
- hub::{Global, GlobalIdentityHandlerFactory, HalApi, Token},
+ hub::{Global, GlobalIdentityHandlerFactory, HalApi, Input, Token},
id,
init_tracker::{has_copy_partial_init_tracker_coverage, TextureInitRange},
- resource::{BufferAccessError, BufferMapState, TextureInner},
+ resource::{BufferAccessError, BufferMapState, StagingBuffer, TextureInner},
track, FastHashSet, SubmissionIndex,
};

@@ -86,28 +86,6 @@ pub struct WrappedSubmissionIndex {
pub index: SubmissionIndex,
}

- struct StagingData<A: hal::Api> {
- buffer: A::Buffer,
- }
-
- impl<A: hal::Api> StagingData<A> {
- unsafe fn write(
- &self,
- device: &A::Device,
- offset: wgt::BufferAddress,
- data: &[u8],
- ) -> Result<(), hal::DeviceError> {
- let mapping = device.map_buffer(&self.buffer, offset..offset + data.len() as u64)?;
- ptr::copy_nonoverlapping(data.as_ptr(), mapping.ptr.as_ptr(), data.len());
- if !mapping.is_coherent {
- device
- .flush_mapped_ranges(&self.buffer, iter::once(offset..offset + data.len() as u64));
- }
- device.unmap_buffer(&self.buffer)?;
- Ok(())
- }
- }

#[derive(Debug)]
pub enum TempResource<A: hal::Api> {
Buffer(A::Buffer),
@@ -178,8 +156,8 @@ impl<A: hal::Api> PendingWrites<A> {
self.temp_resources.push(resource);
}

- fn consume(&mut self, stage: StagingData<A>) {
- self.temp_resources.push(TempResource::Buffer(stage.buffer));
+ fn consume(&mut self, buffer: StagingBuffer<A>) {
+ self.temp_resources.push(TempResource::Buffer(buffer.raw));
}

#[must_use]
@@ -240,16 +218,38 @@ impl<A: hal::Api> PendingWrites<A> {
}

impl<A: HalApi> super::Device<A> {
- fn prepare_stage(&mut self, size: wgt::BufferAddress) -> Result<StagingData<A>, DeviceError> {
- profiling::scope!("prepare_stage");
+ fn prepare_staging_buffer(
+ &mut self,
+ size: wgt::BufferAddress,
+ ) -> Result<(StagingBuffer<A>, *mut u8), DeviceError> {
+ profiling::scope!("prepare_staging_buffer");
let stage_desc = hal::BufferDescriptor {
label: Some("(wgpu internal) Staging"),
size,
usage: hal::BufferUses::MAP_WRITE | hal::BufferUses::COPY_SRC,
memory_flags: hal::MemoryFlags::TRANSIENT,
};

let buffer = unsafe { self.raw.create_buffer(&stage_desc)? };
- Ok(StagingData { buffer })
+ let mapping = unsafe { self.raw.map_buffer(&buffer, 0..size) }?;
+
+ let staging_buffer = StagingBuffer {
+ raw: buffer,
+ size,
+ is_coherent: mapping.is_coherent,
+ };
+
+ Ok((staging_buffer, mapping.ptr.as_ptr()))
}
}

impl<A: hal::Api> StagingBuffer<A> {
unsafe fn flush(&self, device: &A::Device) -> Result<(), DeviceError> {
if !self.is_coherent {
device.flush_mapped_ranges(&self.raw, iter::once(0..self.size));
}
device.unmap_buffer(&self.raw)?;
Ok(())
}
}
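
Together, `prepare_staging_buffer` and `StagingBuffer::flush` replace the old `StagingData::write`: the buffer is now mapped once at creation and the raw pointer handed back to the caller, so the payload can be written in place before a single flush/unmap. A condensed sketch of the lifecycle, with `device` and `data` as stand-ins for the surrounding context:

```rust
// Condensed sketch of the new staging write path, pieced together from the
// diff above; `device` and `data` stand in for the real surrounding state.
let (staging_buffer, staging_buffer_ptr) =
    device.prepare_staging_buffer(data.len() as wgt::BufferAddress)?;
unsafe {
    // Write payload bytes straight through the persistently mapped pointer.
    ptr::copy_nonoverlapping(data.as_ptr(), staging_buffer_ptr, data.len());
    // Flushes mapped ranges only for non-coherent memory, then unmaps.
    staging_buffer.flush(&device.raw)?;
}
// Ownership moves into the queue's pending writes; PendingWrites::consume
// keeps the raw hal buffer alive until the submission completes.
device.pending_writes.consume(staging_buffer);
```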

@@ -305,30 +305,141 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.map_err(|_| DeviceError::Invalid)?;
let (buffer_guard, _) = hub.buffers.read(&mut token);

+ let data_size = data.len() as wgt::BufferAddress;
+
#[cfg(feature = "trace")]
if let Some(ref trace) = device.trace {
let mut trace = trace.lock();
let data_path = trace.make_binary("bin", data);
trace.add(Action::WriteBuffer {
id: buffer_id,
data: data_path,
- range: buffer_offset..buffer_offset + data.len() as wgt::BufferAddress,
+ range: buffer_offset..buffer_offset + data_size,
queued: true,
});
}

- let data_size = data.len() as wgt::BufferAddress;
if data_size == 0 {
log::trace!("Ignoring write_buffer of size 0");
return Ok(());
}

- let stage = device.prepare_stage(data_size)?;
+ let (staging_buffer, staging_buffer_ptr) = device.prepare_staging_buffer(data_size)?;

+ profiling::scope!("copy");
unsafe {
- profiling::scope!("copy");
- stage.write(&device.raw, 0, data)
+ ptr::copy_nonoverlapping(data.as_ptr(), staging_buffer_ptr, data.len());
+ staging_buffer.flush(&device.raw)?;
};

let mut trackers = device.trackers.lock();
let (dst, transition) = trackers
.buffers
.set_single(&*buffer_guard, buffer_id, hal::BufferUses::COPY_DST)
.ok_or(TransferError::InvalidBuffer(buffer_id))?;
let dst_raw = dst
.raw
.as_ref()
.ok_or(TransferError::InvalidBuffer(buffer_id))?;
if !dst.usage.contains(wgt::BufferUsages::COPY_DST) {
return Err(TransferError::MissingCopyDstUsageFlag(Some(buffer_id), None).into());
}
- .map_err(DeviceError::from)?;
dst.life_guard.use_at(device.active_submission_index + 1);

if data_size % wgt::COPY_BUFFER_ALIGNMENT != 0 {
return Err(TransferError::UnalignedCopySize(data_size).into());
}
if buffer_offset % wgt::COPY_BUFFER_ALIGNMENT != 0 {
return Err(TransferError::UnalignedBufferOffset(buffer_offset).into());
}
if buffer_offset + data_size > dst.size {
return Err(TransferError::BufferOverrun {
start_offset: buffer_offset,
end_offset: buffer_offset + data_size,
buffer_size: dst.size,
side: CopySide::Destination,
}
.into());
}

let region = wgt::BufferSize::new(data_size).map(|size| hal::BufferCopy {
src_offset: 0,
dst_offset: buffer_offset,
size,
});
let barriers = iter::once(hal::BufferBarrier {
buffer: &staging_buffer.raw,
usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC,
})
.chain(transition.map(|pending| pending.into_hal(dst)));
let encoder = device.pending_writes.activate();
unsafe {
encoder.transition_buffers(barriers);
encoder.copy_buffer_to_buffer(&staging_buffer.raw, dst_raw, region.into_iter());
}

device.pending_writes.consume(staging_buffer);
device.pending_writes.dst_buffers.insert(buffer_id);

// Ensure the overwritten bytes are marked as initialized so they don't need to be nulled prior to mapping or binding.
{
drop(buffer_guard);
let (mut buffer_guard, _) = hub.buffers.write(&mut token);

let dst = buffer_guard.get_mut(buffer_id).unwrap();
dst.initialization_status
.drain(buffer_offset..(buffer_offset + data_size));
}

Ok(())
}
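
A quick illustration of the alignment rule enforced in the validation above (`wgt::COPY_BUFFER_ALIGNMENT` is 4, so both the write size and the destination offset must be 4-byte multiples or the transfer is rejected before any copy is encoded):

```rust
// Quick illustration of the alignment rule enforced by queue_write_buffer.
assert_eq!(wgt::COPY_BUFFER_ALIGNMENT, 4);
assert_eq!(12 % wgt::COPY_BUFFER_ALIGNMENT, 0); // ok: 12-byte write
assert_ne!(14 % wgt::COPY_BUFFER_ALIGNMENT, 0); // rejected: UnalignedCopySize(14)
assert_ne!(2 % wgt::COPY_BUFFER_ALIGNMENT, 0);  // rejected: UnalignedBufferOffset(2)
```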

pub fn queue_create_staging_buffer<A: HalApi>(
&self,
queue_id: id::QueueId,
buffer_size: wgt::BufferSize,
id_in: Input<G, id::StagingBufferId>,
) -> Result<(id::StagingBufferId, *mut u8), QueueWriteError> {
let hub = A::hub(self);
let mut token = Token::root();
let fid = hub.staging_buffers.prepare(id_in);
let (mut device_guard, mut token) = hub.devices.write(&mut token);
let device = device_guard
.get_mut(queue_id)
.map_err(|_| DeviceError::Invalid)?;

let data_size = buffer_size.get();

let (staging_buffer, staging_buffer_ptr) = device.prepare_staging_buffer(data_size)?;

let id = fid.assign(staging_buffer, &mut token);
Ok((id.0, staging_buffer_ptr))
}
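
Together with `queue_write_staging_buffer` defined just below, this exposes `write_buffer_with` as a two-phase protocol at the wgpu-core layer. A rough caller sequence, where `global`, `queue_id`, `buffer_id`, and `id_in` are assumed stand-ins for the caller's real handles:

```rust
// Sketch of the two-phase protocol these entry points expose.
let size = wgt::BufferSize::new(1024).unwrap();
// Phase 1: allocate a mapped staging buffer; get its id and raw pointer back.
let (staging_id, ptr) = global.queue_create_staging_buffer::<A>(queue_id, size, id_in)?;
// The caller (e.g. a write_buffer_with view) now fills ptr[0..size].
unsafe { std::ptr::write_bytes(ptr, 0, size.get() as usize) };
// Phase 2: flush the staging buffer and encode the copy into the destination.
global.queue_write_staging_buffer::<A>(queue_id, buffer_id, 0, staging_id)?;
```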

pub fn queue_write_staging_buffer<A: HalApi>(
&self,
queue_id: id::QueueId,
buffer_id: id::BufferId,
buffer_offset: wgt::BufferAddress,
staging_buffer: id::StagingBufferId,
) -> Result<(), QueueWriteError> {
profiling::scope!("write_buffer_with", "Queue");

let hub = A::hub(self);
let mut token = Token::root();
let (mut device_guard, mut token) = hub.devices.write(&mut token);
let device = device_guard
.get_mut(queue_id)
.map_err(|_| DeviceError::Invalid)?;

let (src_buffer, _) = hub.staging_buffers.unregister(staging_buffer, &mut token);
let src_buffer = src_buffer.ok_or(TransferError::InvalidBuffer(buffer_id))?;

let data_size = src_buffer.size;

unsafe { src_buffer.flush(&device.raw)? };

let (buffer_guard, _) = hub.buffers.read(&mut token);

let mut trackers = device.trackers.lock();
let (dst, transition) = trackers
@@ -360,23 +471,23 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.into());
}

- let region = wgt::BufferSize::new(data.len() as u64).map(|size| hal::BufferCopy {
+ let region = wgt::BufferSize::new(data_size).map(|size| hal::BufferCopy {
src_offset: 0,
dst_offset: buffer_offset,
size,
});
let barriers = iter::once(hal::BufferBarrier {
- buffer: &stage.buffer,
+ buffer: &src_buffer.raw,
usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC,
})
.chain(transition.map(|pending| pending.into_hal(dst)));
let encoder = device.pending_writes.activate();
unsafe {
encoder.transition_buffers(barriers);
- encoder.copy_buffer_to_buffer(&stage.buffer, dst_raw, region.into_iter());
+ encoder.copy_buffer_to_buffer(&src_buffer.raw, dst_raw, region.into_iter());
}

- device.pending_writes.consume(stage);
+ device.pending_writes.consume(src_buffer);
device.pending_writes.dst_buffers.insert(buffer_id);

// Ensure the overwritten bytes are marked as initialized so they don't need to be nulled prior to mapping or binding.
@@ -469,7 +580,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let block_rows_in_copy =
(size.depth_or_array_layers - 1) * block_rows_per_image + height_blocks;
let stage_size = stage_bytes_per_row as u64 * block_rows_in_copy as u64;
- let stage = device.prepare_stage(stage_size)?;
+ let (staging_buffer, staging_buffer_ptr) = device.prepare_staging_buffer(stage_size)?;

let dst = texture_guard.get_mut(destination.texture).unwrap();
if !dst.desc.usage.contains(wgt::TextureUsages::COPY_DST) {
@@ -538,30 +649,30 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
width_blocks * format_desc.block_size as u32
};

- let mapping = unsafe { device.raw.map_buffer(&stage.buffer, 0..stage_size) }
- .map_err(DeviceError::from)?;
- unsafe {
- if stage_bytes_per_row == bytes_per_row {
- profiling::scope!("copy aligned");
- // Fast path if the data is already being aligned optimally.
+ if stage_bytes_per_row == bytes_per_row {
+ profiling::scope!("copy aligned");
+ // Fast path if the data is already being aligned optimally.
+ unsafe {
ptr::copy_nonoverlapping(
data.as_ptr().offset(data_layout.offset as isize),
- mapping.ptr.as_ptr(),
+ staging_buffer_ptr,
stage_size as usize,
);
- } else {
- profiling::scope!("copy chunked");
- // Copy row by row into the optimal alignment.
- let copy_bytes_per_row = stage_bytes_per_row.min(bytes_per_row) as usize;
- for layer in 0..size.depth_or_array_layers {
- let rows_offset = layer * block_rows_per_image;
- for row in 0..height_blocks {
+ }
+ } else {
+ profiling::scope!("copy chunked");
+ // Copy row by row into the optimal alignment.
+ let copy_bytes_per_row = stage_bytes_per_row.min(bytes_per_row) as usize;
+ for layer in 0..size.depth_or_array_layers {
+ let rows_offset = layer * block_rows_per_image;
+ for row in 0..height_blocks {
+ unsafe {
ptr::copy_nonoverlapping(
data.as_ptr().offset(
data_layout.offset as isize
+ (rows_offset + row) as isize * bytes_per_row as isize,
),
- mapping.ptr.as_ptr().offset(
+ staging_buffer_ptr.offset(
(rows_offset + row) as isize * stage_bytes_per_row as isize,
),
copy_bytes_per_row,
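
The chunked branch above exists because the caller's `bytes_per_row` can be tighter than the row pitch the backend requires; an illustrative calculation with hypothetical numbers:

```rust
// Illustrative numbers for the chunked path: a 100-texel-wide RGBA8 texture
// packs rows at bytes_per_row = 400, but a backend demanding a 256-byte row
// pitch yields stage_bytes_per_row = 512, so each row is copied individually
// to its own padded offset in the staging buffer.
let bytes_per_row: u32 = 400; // tightly packed source rows
let stage_bytes_per_row: u32 = 512; // 400 rounded up to a 256-byte multiple
let copy_bytes_per_row = stage_bytes_per_row.min(bytes_per_row) as usize;
let row = 3_isize;
let src_offset = row * bytes_per_row as isize; // offset in the user data
let dst_offset = row * stage_bytes_per_row as isize; // offset in staging memory
assert_eq!((copy_bytes_per_row, src_offset, dst_offset), (400, 1200, 1536));
```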
@@ -570,17 +681,8 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
}
}
- unsafe {
- if !mapping.is_coherent {
- device
- .raw
- .flush_mapped_ranges(&stage.buffer, iter::once(0..stage_size));
- }
- device
- .raw
- .unmap_buffer(&stage.buffer)
- .map_err(DeviceError::from)?;
- }

+ unsafe { staging_buffer.flush(&device.raw) }?;

let regions = (0..array_layer_count).map(|rel_array_layer| {
let mut texture_base = dst_base.clone();
Expand All @@ -598,7 +700,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
});
let barrier = hal::BufferBarrier {
- buffer: &stage.buffer,
+ buffer: &staging_buffer.raw,
usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC,
};

Expand All @@ -611,10 +713,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
encoder
.transition_textures(transition.map(|pending| pending.into_hal(dst)).into_iter());
encoder.transition_buffers(iter::once(barrier));
- encoder.copy_buffer_to_texture(&stage.buffer, dst_raw, regions);
+ encoder.copy_buffer_to_texture(&staging_buffer.raw, dst_raw, regions);
}

- device.pending_writes.consume(stage);
+ device.pending_writes.consume(staging_buffer);
device
.pending_writes
.dst_textures