Skip to content

Commit

Permalink
FEAT: Let RandomExt methods apply to array views if possible (sample)
Browse files Browse the repository at this point in the history
The RandomExt methods for sampling were unintentionally restricted to
owned arrays only (like the original random constructors). Now the
methods which can also apply to array views.
  • Loading branch information
bluss committed Dec 21, 2020
1 parent 37e4070 commit 2dfa40f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
14 changes: 11 additions & 3 deletions ndarray-rand/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::rand::seq::index;
use crate::rand::{thread_rng, Rng, SeedableRng};

use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
use ndarray::{ArrayBase, DataOwned, Dimension};
use ndarray::{ArrayBase, DataOwned, RawData, Data, Dimension};
#[cfg(feature = "quickcheck")]
use quickcheck::{Arbitrary, Gen};

Expand Down Expand Up @@ -63,7 +63,7 @@ pub mod rand_distr {
/// [`.random_using()`](#tymethod.random_using).
pub trait RandomExt<S, A, D>
where
S: DataOwned<Elem = A>,
S: RawData<Elem = A>,
D: Dimension,
{
/// Create an array with shape `dim` with elements drawn from
Expand All @@ -87,6 +87,7 @@ where
fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>;

/// Create an array with shape `dim` with elements drawn from
Expand Down Expand Up @@ -117,6 +118,7 @@ where
where
IdS: Distribution<S::Elem>,
R: Rng + ?Sized,
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>;

/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
Expand Down Expand Up @@ -163,6 +165,7 @@ where
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis;

/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
Expand Down Expand Up @@ -223,17 +226,19 @@ where
where
R: Rng + ?Sized,
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis;
}

impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
where
S: DataOwned<Elem = A>,
S: RawData<Elem = A>,
D: Dimension,
{
fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>,
{
Self::random_using(shape, dist, &mut get_rng())
Expand All @@ -243,6 +248,7 @@ where
where
IdS: Distribution<S::Elem>,
R: Rng + ?Sized,
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>,
{
Self::from_shape_simple_fn(shape, move || dist.sample(rng))
Expand All @@ -251,6 +257,7 @@ where
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis,
{
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
Expand All @@ -266,6 +273,7 @@ where
where
R: Rng + ?Sized,
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis,
{
let indices: Vec<_> = match strategy {
Expand Down
7 changes: 7 additions & 0 deletions ndarray-rand/tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ fn test_dim_f() {
}
}

#[test]
fn sample_axis_on_view() {
let m = 5;
let a = Array::random((m, 4), Uniform::new(0., 2.));
let _samples = a.view().sample_axis(Axis(0), m, SamplingStrategy::WithoutReplacement);
}

#[test]
#[should_panic]
fn oversampling_without_replacement_should_panic() {
Expand Down

0 comments on commit 2dfa40f

Please sign in to comment.