Scalar operations with complex array and complex scalars #781

Open
g-bauer opened this issue Feb 15, 2020 · 10 comments · May be fixed by #782

@g-bauer

g-bauer commented Feb 15, 2020

Hi there,

thanks a lot for the nice library! I hope this is the right place to ask questions! Please feel free to close if it's not.

I am having problems understanding why the following is not working.

use num_complex::Complex64;
use ndarray::arr1;

fn main() {
    let complex_array = arr1(&[Complex64::from(1.0), Complex64::from(2.0)]);
    let complex_scalar = Complex64::from(1.0);
    let f_scalar = 1.0;
    let f_array = arr1(&[1.0, 2.0]);
    println!("{}", &complex_array * f_scalar);
    println!("{}", &complex_array * complex_scalar);
    println!("{}", complex_scalar * &complex_array);
    // println!("{}", complex_scalar * &f_array); // <- not working
    // println!("{}", f_scalar * complex_array); // <- not working
}

I could circumvent these problems by manually iterating through the arrays, but that makes the code less readable and less ergonomic. Maybe someone could help me understand why these operations are a problem.

@LukeMathWalker
Member

Hey @g-bauer!
Can you also post the error messages you are getting from the compiler?

@g-bauer
Author

g-bauer commented Feb 15, 2020

Hey,

thank you for the answer.
The error message is

error[E0271]: type mismatch resolving `<ndarray::OwnedRepr<num_complex::Complex<f64>> as ndarray::data_traits::RawData>::Elem == f64`
  --> src/main.rs:13:29
   |
13 |     println!("{}", f_scalar * &complex_array); // <- not working
   |                             ^ expected struct `num_complex::Complex`, found `f64`
   |
   = note: expected type `num_complex::Complex<f64>`
              found type `f64`
   = note: required because of the requirements on the impl of `std::ops::Mul<&ndarray::ArrayBase<ndarray::OwnedRepr<num_complex::Complex<f64>>, ndarray::dimension::dim::Dim<[usize; 1]>>>` for `f64`


@LukeMathWalker
Member

Ah, I see.
There are some limitations when it comes to the possible combinations of binary operations, scalar types and array types.
You can have a look here to see what is supported: https://docs.rs/ndarray/0.13.0/ndarray/struct.ArrayBase.html#binary-operators-with-two-arrays

@g-bauer
Author

g-bauer commented Feb 15, 2020

Ah, OK. I thought ScalarOperand and impl_scalar_lhs_op! would also take care of these combinations, as long as the binary operations between the scalar types themselves are already implemented.

Is that something one could implement, or is there a reason I'm not seeing that prevents an implementation of scalar * vec and vec * scalar where the elements of vec have a different data type from scalar?

@jturner314
Member

I've created #782 to add the necessary implementations to compile the example (plus a few more). Note that even with #782, it will be possible to do complex_scalar * &f_array (as in the example) but not complex_scalar * f_array, due to implementation details. (Once Rust has specialization, we could add a complex_scalar * f_array implementation.)

@g-bauer
Author

g-bauer commented Feb 16, 2020

That's great! Thank you very much for the help & effort.

I'm still confused by some behavior, but I guess it stems from my limited knowledge of ndarray and Rust in general.

use ndarray::arr1;
use num_complex::Complex64;

fn main() {
    let complex_array = arr1(&[Complex64::from(1.0), Complex64::from(2.0)]);
    let complex_scalar = Complex64::from(1.0);
    let f_scalar = 1.0;
    let f_array = arr1(&[1.0, 2.0]);
    println!("{}", &f_array * complex_scalar * f_scalar);
    // println!("{}", f_scalar * &f_array * complex_scalar); // not working: type mismatch
    println!("{}", &(f_scalar * &f_array) * complex_scalar); // fixes above
    println!("{}", complex_scalar * &f_array * f_scalar); // this works as well

    println!("{}", &complex_array * &f_array); // this works
    // println!("{}", &f_array * &complex_array); // this will not
}

I am actually using Complex here as a placeholder for my own data type (Dual numbers). In my code, I have a lot of operations between arrays and scalars of floats and Dual numbers. Even with these changes, I am having difficulties properly implementing my equations (due to my limited knowledge).
Again, thank you very much for your help!

@jturner314
Member

f_scalar * &f_array * complex_scalar is f64 * &Array1<f64> * Complex64. Because * is left-associative, the left multiplication is evaluated first, so we get Array1<f64> * Complex64. Multiplication is not implemented for this combination of types. The closest implementation is this one:

impl<A, S, D, B> Mul<B> for ArrayBase<S, D>
where
    A: Clone + Mul<B, Output = A>,
    S: DataOwned<Elem = A> + DataMut,
    D: Dimension,
    B: ScalarOperand,

It almost matches, but the Output = A bound in A: Clone + Mul<B, Output = A> isn't met because the output of multiplying an f64 and a Complex64 is a Complex64, not an f64. The reason for the Output = A restriction is that this implementation works by modifying the array in place (to avoid allocating a new array), and it's not possible to change the element type in place. Until Rust has specialization, I don't see a way for us to remove this restriction without always allocating a new array.

As you discovered, the workaround is to write &(f_scalar * &f_array) * complex_scalar (after #782) because the relevant implementation doesn't have the Output = A restriction, since it always needs to allocate a new array.

&f_array * &complex_array is &Array1<f64> * &Array1<Complex64>. The closest implementation is this one:

impl<'a, A, B, S, S2, D, E> Mul<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + Mul<B, Output = A>,
    B: Clone,
    S: Data<Elem = A>,
    S2: Data<Elem = B>,
    D: Dimension,
    E: Dimension,

As with the previous case, the reason why this implementation doesn't apply is the Output = A restriction. However, in this case, the implementation always needs to allocate a new array, so there's no reason why we need to keep the Output = A restriction. I have plans to remove this restriction in the future using the (currently experimental) nditer crate, but it's not ready yet.

As you discovered, the workaround in this case is to swap the left and right sides of the multiplication. For non-commutative arithmetic operations where you can't swap the two sides, I'd suggest allocating an array for the result of the operation and then zipping the three arrays together to perform the operation:

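use ndarray::prelude::*;
use num_complex::Complex64;

fn main() {
    let complex_array = arr1(&[Complex64::from(1.0), Complex64::from(2.0)]);
    let f_array = arr1(&[1.0, 2.0]);
    // Allocate the output up front; its element type is the type of f * c.
    let mut result = Array1::<Complex64>::zeros(f_array.len());
    // One pass over all three arrays: result[i] = f_array[i] * complex_array[i].
    azip!((res in &mut result, &f in &f_array, &c in &complex_array) *res = f * c);
    println!("{}", result);
}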

> I am actually using Complex here as a placeholder for my own data type (Dual numbers).

See #783 for the implementations you need to add for your type. Please feel free to ask questions if that isn't clear.

@jturner314
Member

jturner314 commented Feb 17, 2020

@g-bauer Note that in most cases, if you have an expression involving more than a single arithmetic operation, Zip/azip will provide better performance than using multiple arithmetic operations on the arrays, because Zip needs to iterate only once and you can avoid allocating arrays for intermediate subexpressions.

@g-bauer
Author

g-bauer commented Feb 17, 2020

Thank you very much for the details. I think using Zip/azip is my best option at the moment - I'll have to play around with it and see if that works.

Regarding #783, will this allow for impls outside of ndarray? I currently clone the repo and implement all operations for my structs similarly to Complex, which is a bit unwieldy.

Again, thanks for the help!

@jturner314
Member

> Regarding #783, will this allow for impls outside of ndarray?

This is already possible. What I'm proposing in #783 is providing a macro to make it easier.

Here's an example of a crate outside ndarray that provides the necessary implementations to treat MyType as a scalar in addition operations (to provide multiplication, division, etc., you'd need to provide analogous implementations for std::ops::Mul, std::ops::Div, etc.):

use ndarray::prelude::*;
use ndarray::{Data, DataMut, DataOwned, ScalarOperand};
use std::ops::Add;

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MyType(f32);

impl Add<f32> for MyType {
    type Output = MyType;

    fn add(self, rhs: f32) -> MyType {
        MyType(self.0 + rhs)
    }
}

impl Add<MyType> for f32 {
    type Output = MyType;

    fn add(self, rhs: MyType) -> MyType {
        MyType(self + rhs.0)
    }
}

impl Add<MyType> for MyType {
    type Output = MyType;

    fn add(self, rhs: MyType) -> MyType {
        MyType(self.0 + rhs.0)
    }
}

impl ScalarOperand for MyType {}

macro_rules! impl_scalar_lhs_op {
    ($scalar:ty, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
        /// Perform elementwise
        #[doc=$doc]
        /// between the scalar `self` and array `rhs`,
        /// and return the result (based on `self`).
        impl<A, S, D> $trt<ArrayBase<S, D>> for $scalar
        where
            $scalar: Clone + $trt<A, Output=A>,
            A: Clone,
            S: DataOwned<Elem=A> + DataMut,
            D: Dimension,
        {
            type Output = ArrayBase<S, D>;
            fn $mth(self, mut rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
                rhs.map_inplace(move |elt| *elt = self.clone() $operator elt.clone());
                rhs
            }
        }

        /// Perform elementwise
        #[doc=$doc]
        /// between the scalar `self` and array `rhs`,
        /// and return the result as a new `Array`.
        impl<'a, A, S, D, B> $trt<&'a ArrayBase<S, D>> for $scalar
        where
            $scalar: Clone + $trt<A, Output=B>,
            A: Clone,
            S: Data<Elem=A>,
            D: Dimension,
        {
            type Output = Array<B, D>;
            fn $mth(self, rhs: &ArrayBase<S, D>) -> Array<B, D> {
                rhs.map(move |elt| self.clone() $operator elt.clone())
            }
        }
    );
}

impl_scalar_lhs_op!(MyType, +, Add, add, "addition");

fn main() {
    let f = 1.;
    let my = MyType(2.);
    let arr_f = array![10., 20., 30.];
    let arr_my = array![MyType(1.), MyType(2.), MyType(3.)];

    // f32 and MyType
    assert_eq!(f + my, MyType(3.));
    assert_eq!(my + f, MyType(3.));

    // f32 and array of MyType
    assert_eq!(f + arr_my.clone(), array![MyType(2.), MyType(3.), MyType(4.)]);
    assert_eq!(f + &arr_my, array![MyType(2.), MyType(3.), MyType(4.)]);
    assert_eq!(arr_my.clone() + f, array![MyType(2.), MyType(3.), MyType(4.)]);
    assert_eq!(&arr_my + f, array![MyType(2.), MyType(3.), MyType(4.)]);

    // MyType and array of f32
    // assert_eq!(my + arr_f.clone(), array![MyType(12.), MyType(22.), MyType(32.)]); // doesn't work
    assert_eq!(my + &arr_f, array![MyType(12.), MyType(22.), MyType(32.)]);
    // assert_eq!(arr_f.clone() + my, array![MyType(12.), MyType(22.), MyType(32.)]); // doesn't work
    assert_eq!(&arr_f + my, array![MyType(12.), MyType(22.), MyType(32.)]);

    // MyType and array of MyType
    assert_eq!(my + arr_my.clone(), array![MyType(3.), MyType(4.), MyType(5.)]);
    assert_eq!(my + &arr_my, array![MyType(3.), MyType(4.), MyType(5.)]);
    assert_eq!(arr_my.clone() + my, array![MyType(3.), MyType(4.), MyType(5.)]);
    assert_eq!(&arr_my + my, array![MyType(3.), MyType(4.), MyType(5.)]);

    // array of f32 and array of MyType
    // assert_eq!(arr_f.clone() + arr_my.clone(), array![MyType(11.), MyType(22.), MyType(33.)]); // doesn't work
    // assert_eq!(arr_f.clone() + &arr_my, array![MyType(11.), MyType(22.), MyType(33.)]); // doesn't work
    // assert_eq!(&arr_f + &arr_my, array![MyType(11.), MyType(22.), MyType(33.)]); // doesn't work
    assert_eq!(arr_my.clone() + arr_f.clone(), array![MyType(11.), MyType(22.), MyType(33.)]);
    assert_eq!(arr_my.clone() + &arr_f, array![MyType(11.), MyType(22.), MyType(33.)]);
    assert_eq!(&arr_my + &arr_f, array![MyType(11.), MyType(22.), MyType(33.)]);

    // array of MyType and array of MyType
    assert_eq!(arr_my.clone() + arr_my.clone(), array![MyType(2.), MyType(4.), MyType(6.)]);
    assert_eq!(arr_my.clone() + &arr_my, array![MyType(2.), MyType(4.), MyType(6.)]);
    assert_eq!(&arr_my + &arr_my, array![MyType(2.), MyType(4.), MyType(6.)]);
}

(Note that some of those expressions rely on #782 to work.)
