Skip to content

Commit

Permalink
Impl a lifetime-relaxed broadcast for ArrayView
Browse files Browse the repository at this point in the history
ArrayView::broadcast has a lifetime that depends on &self instead of its
internal buffer. This prevents writing some types of functions in an
allocation-free way. For instance, take the numpy `meshgrid` function:
It could be implemented like so:

```rust
fn meshgrid_2d<'a, 'b>(coords_x: ArrayView1<'a, X>, coords_y: ArrayView1<'b, X>) -> (ArrayView2<'a, X>, ArrayView2<'b, X>) {
    let x_len = coords_x.shape()[0];
    let y_len = coords_y.shape()[0];

    let coords_x_s = coords_x.into_shape((1, y_len)).unwrap();
    let coords_x_b = coords_x_s.broadcast((x_len, y_len)).unwrap();
    let coords_y_s = coords_y.into_shape((x_len, 1)).unwrap();
    let coords_y_b = coords_y_s.broadcast((x_len, y_len)).unwrap();

    (coords_x_b, coords_y_b)
}
```

Unfortunately, this doesn't work, because `coords_x_b` is bound to the
lifetime of `coord_x_s`, instead of being bound to 'a.

This commit introduces a new function, broadcast_ref, that does just
that.
  • Loading branch information
roblabla committed Oct 10, 2022
1 parent e080d62 commit c7cdf52
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
81 changes: 81 additions & 0 deletions src/impl_views/methods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright 2014-2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::imp_prelude::*;
use crate::dimension::IntoDimension;
use crate::dimension::size_of_shape_checked;

impl<'a, A, D> ArrayView<'a, A, D>
where
D: Dimension,
{
/// Broadcasts an arrayview.
pub fn broadcast_ref<E>(&self, dim: E) -> Option<ArrayView<'a, A, E::Dim>>
where
E: IntoDimension,
{
/// Return new stride when trying to grow `from` into shape `to`
///
/// Broadcasting works by returning a "fake stride" where elements
/// to repeat are in axes with 0 stride, so that several indexes point
/// to the same element.
///
/// **Note:** Cannot be used for mutable iterators, since repeating
/// elements would create aliasing pointers.
fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
// Make sure the product of non-zero axis lengths does not exceed
// `isize::MAX`. This is the only safety check we need to perform
// because all the other constraints of `ArrayBase` are guaranteed
// to be met since we're starting from a valid `ArrayBase`.
let _ = size_of_shape_checked(to).ok()?;

let mut new_stride = to.clone();
// begin at the back (the least significant dimension)
// size of the axis has to either agree or `from` has to be 1
if to.ndim() < from.ndim() {
return None;
}

{
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
for ((er, es), dr) in from
.slice()
.iter()
.rev()
.zip(stride.slice().iter().rev())
.zip(new_stride_iter.by_ref())
{
/* update strides */
if *dr == *er {
/* keep stride */
*dr = *es;
} else if *er == 1 {
/* dead dimension, zero stride */
*dr = 0
} else {
return None;
}
}

/* set remaining strides to zero */
for dr in new_stride_iter {
*dr = 0;
}
}
Some(new_stride)
}
let dim = dim.into_dimension();

// Note: zero strides are safe precisely because we return an read-only view
let broadcast_strides = match upcast(&dim, &self.dim, &self.strides) {
Some(st) => st,
None => return None,
};
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
}
}
1 change: 1 addition & 0 deletions src/impl_views/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod constructors;
mod conversions;
mod indexing;
mod methods;
mod splitting;

pub use constructors::*;
Expand Down

0 comments on commit c7cdf52

Please sign in to comment.