Please unseal trait Dimension #1326

Open
bionicles opened this issue Sep 16, 2023 · 2 comments

Comments

@bionicles

bionicles commented Sep 16, 2023

Hi, today I wanted to impl Dimension in a macro so I could make custom dimensions. Long story short, I couldn't make a custom newtype of a (usize) axis by implementing Dimension, because the trait has a __private__ method. Could you please remove the __private__ method from the Dimension trait, provide some builder for new Dimensions, or tell me how I'm being a noob and it's easy? I reckon users of the library could empower rustc to catch shape bugs and wrong-axis indexing bugs at compile time instead of with panics at runtime, and I always feel dumb when my code crashes at runtime for something I could theoretically have caught at compile time but didn't (like when you mix up the TickersAxis and the TimestampsAxis... whoops).

```rust
use std::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};

use ndarray::{Array, Axis, Dim, DimAdd, DimMax, Dimension, IxDynImpl};

macro_rules! implement_dimension_traits {
    ($type:ident) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
        pub struct $type(usize);

        impl Add for $type {
            type Output = Self;
            fn add(self, rhs: Self) -> Self {
                Self(self.0 + rhs.0)
            }
        }

        impl AddAssign for $type {
            fn add_assign(&mut self, rhs: Self) {
                self.0 += rhs.0;
            }
        }

        impl AddAssign<&$type> for $type {
            fn add_assign(&mut self, rhs: &$type) {
                self.0 += rhs.0;
            }
        }

        impl Sub for $type {
            type Output = Self;
            fn sub(self, rhs: Self) -> Self {
                Self(self.0 - rhs.0)
            }
        }

        impl SubAssign for $type {
            fn sub_assign(&mut self, rhs: Self) {
                self.0 -= rhs.0;
            }
        }

        impl SubAssign<&$type> for $type {
            fn sub_assign(&mut self, rhs: &$type) {
                self.0 -= rhs.0;
            }
        }

        impl Mul<$type> for $type {
            type Output = Self;
            fn mul(self, rhs: Self) -> Self {
                Self(self.0 * rhs.0)
            }
        }

        impl MulAssign for $type {
            fn mul_assign(&mut self, rhs: Self) {
                self.0 *= rhs.0;
            }
        }

        impl MulAssign<&$type> for $type {
            fn mul_assign(&mut self, rhs: &$type) {
                self.0 *= rhs.0;
            }
        }

        impl DimAdd<Dim<[usize; 1]>> for $type {
            type Output = $type;
        }
        impl DimAdd<Dim<[usize; 0]>> for $type {
            type Output = $type;
        }
        impl DimAdd<Dim<IxDynImpl>> for $type {
            type Output = Dim<IxDynImpl>;
        }
        impl DimAdd<$type> for $type {
            type Output = $type;
        }

        impl DimMax<Dim<[usize; 0]>> for $type {
            type Output = $type;
        }
        impl DimMax<Dim<[usize; 1]>> for $type {
            type Output = $type;
        }
        impl DimMax<Dim<IxDynImpl>> for $type {
            type Output = Dim<IxDynImpl>;
        }

        impl Mul<usize> for $type {
            type Output = Self;
            fn mul(self, rhs: usize) -> Self {
                Self(self.0 * rhs)
            }
        }

        impl MulAssign<usize> for $type {
            fn mul_assign(&mut self, rhs: usize) {
                self.0 *= rhs;
            }
        }

        impl Index<usize> for $type {
            type Output = usize;
            fn index(&self, index: usize) -> &Self::Output {
                if index != 0 {
                    panic!("Index out of bounds");
                }
                &self.0
            }
        }
        impl IndexMut<usize> for $type {
            fn index_mut(&mut self, index: usize) -> &mut Self::Output {
                if index != 0 {
                    panic!("Index out of bounds");
                }
                &mut self.0
            }
        }
        impl Dimension for $type {
            const NDIM: Option<usize> = Some(1); // 1D dimension

            type Smaller = Self; // No smaller dimension for 1D
            type Larger = Self; // No larger dimension for 1D
            type Pattern = usize; // For 1D, a single usize should suffice as a pattern

            // This is the blocker: ndarray's `private` module is not public,
            // so code outside ndarray cannot name `PrivateMarker` here.
            fn __private__(&self) -> ndarray::private::PrivateMarker {
                ndarray::private::PrivateMarker
            }

            fn into_pattern(self) -> Self::Pattern {
                // Returning the usize directly
                self.0
            }

            fn slice(&self) -> Self {
                // Slicing does not make sense for a 1D axis, so return self
                *self
            }

            fn slice_mut(&mut self) -> &mut Self {
                // Slicing does not make sense for a 1D axis, so return self
                self
            }

            fn zeros() -> Self {
                // A 1D axis with size 0
                Self(0)
            }

            fn insert_axis(&mut self, _axis: Axis) {
                // No-op for a 1D axis
            }

            fn try_remove_axis(&mut self, _axis: Axis) -> Result<(), String> {
                // No-op for a 1D axis
                Ok(())
            }

            fn ndim(&self) -> usize {
                1
            }

            fn size(&self) -> usize {
                self.0
            }
        }
    };
}

// Using the macro to implement traits for TickersAxis and TimestampsAxis
// TickersAxis is the number of stocks we're trading.
// TimestampsAxis is the number of timestamps in the data (candles).
implement_dimension_traits!(TickersAxis);
implement_dimension_traits!(TimestampsAxis);

// Proposed wrappers (hypothetical: tuples such as (D1, D2) are not
// Dimensions in ndarray today, so N2 and N3 do not compile as written).
struct N1<T, D1>(Array<T, D1>)
where
    D1: Dimension;

struct N2<T, D1, D2>(Array<T, (D1, D2)>)
where
    D1: Dimension,
    D2: Dimension;

struct N3<T, D1, D2, D3>(Array<T, (D1, D2, D3)>)
where
    D1: Dimension,
    D2: Dimension,
    D3: Dimension;
```

Something like that would let us make named dimensions. Maybe it doesn't add too much complexity, but it would let us be explicit about our axes and prevent a lot of issues, just by moving information about shape variables from comments into the type system:

    /// close prices of each (ticker, timestamp)
    closes: Array2<Pennies>,

vs

    /// close prices of each (ticker, timestamp)
    closes: N2<Pennies, TickersAxis, TimestampsAxis>,

or (better with the parens, IMHO):

    closes: N2<Pennies, (TickersAxis, TimestampsAxis)>,
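
Until something like this exists in ndarray, part of the benefit can be approximated without implementing Dimension at all. The sketch below is std-only and does not use ndarray; `Named2`, `Idx`, and the zero-sized marker structs are hypothetical names. It checks at compile time only *which axis* an index belongs to; bounds are still checked at runtime.

```rust
use std::marker::PhantomData;
use std::ops::{Index, IndexMut};

// Zero-sized marker types naming the axes (names borrowed from this issue).
#[derive(Debug, Clone, Copy)]
pub struct TickersAxis;
#[derive(Debug, Clone, Copy)]
pub struct TimestampsAxis;

/// An index along axis `A`. An `Idx<TickersAxis>` is a different type from
/// an `Idx<TimestampsAxis>`, so the two cannot be swapped silently.
#[derive(Debug, Clone, Copy)]
pub struct Idx<A>(usize, PhantomData<A>);

impl<A> Idx<A> {
    pub fn new(i: usize) -> Self {
        Idx(i, PhantomData)
    }
}

/// Row-major 2-D array whose row axis is tagged `R` and column axis `C`.
pub struct Named2<T, R, C> {
    data: Vec<T>,
    rows: usize,
    cols: usize,
    _axes: PhantomData<(R, C)>,
}

impl<T: Clone, R, C> Named2<T, R, C> {
    /// Creates a `rows` x `cols` array filled with `value`.
    pub fn filled(rows: usize, cols: usize, value: T) -> Self {
        Named2 { data: vec![value; rows * cols], rows, cols, _axes: PhantomData }
    }
}

impl<T, R, C> Named2<T, R, C> {
    /// Bounds are checked at runtime; only axis identity is checked by the compiler.
    fn offset(&self, r: Idx<R>, c: Idx<C>) -> usize {
        assert!(r.0 < self.rows && c.0 < self.cols, "index out of bounds");
        r.0 * self.cols + c.0
    }
}

impl<T, R, C> Index<(Idx<R>, Idx<C>)> for Named2<T, R, C> {
    type Output = T;
    fn index(&self, (r, c): (Idx<R>, Idx<C>)) -> &T {
        &self.data[self.offset(r, c)]
    }
}

impl<T, R, C> IndexMut<(Idx<R>, Idx<C>)> for Named2<T, R, C> {
    fn index_mut(&mut self, (r, c): (Idx<R>, Idx<C>)) -> &mut T {
        let i = self.offset(r, c);
        &mut self.data[i]
    }
}

type Pennies = u64;

fn main() {
    // close prices of each (ticker, timestamp)
    let mut closes: Named2<Pennies, TickersAxis, TimestampsAxis> = Named2::filled(2, 3, 0);
    let ticker = Idx::<TickersAxis>::new(1);
    let time = Idx::<TimestampsAxis>::new(2);
    closes[(ticker, time)] = 12_345;
    // closes[(time, ticker)] would be a compile-time type error: axes swapped.
    println!("{}", closes[(ticker, time)]);
}
```

The trade-off is that every index has to be wrapped in `Idx`, which is exactly the kind of ergonomics a first-class named-dimension feature could smooth over.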

Could you please tell me whether something like this is already possible with Rust's ndarray, and if not, could we please enable it? I'm guessing I'm just a noob and it's in the docs somewhere. If it's doable, maybe it belongs as a bullet point above the fold on the front page, because any time we can make rustc check our work, it helps. Or perhaps I'm just a tryhard and this isn't necessary, or I'm not appreciating the downstream issues it might cause?

Anyway, thank you for the crate! I enjoy it!

@bluss
Member

bluss commented Mar 6, 2024

Unsealing Dimension is not realistic at this time, I'm sorry; keeping it sealed has let us avoid a lot of breaking version bumps. The next step for dimensions is moving them over to const generics, so they will be redesigned. I'm not available to help you with the design right now, I'm sorry.
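
For readers unfamiliar with the pattern, here is a minimal single-file sketch of how a method returning an unnameable type seals a trait, loosely modelled on the `__private__` method the issue ran into. The module `mini_ndarray` and its types are hypothetical stand-ins, not ndarray's actual source. It also shows why sealing avoids breaking releases: since no one outside the module can implement the trait, adding a required method breaks no downstream impls.

```rust
mod mini_ndarray {
    // `private` is a private module, so code outside `mini_ndarray`
    // cannot name anything inside it, even though the struct is `pub`.
    mod private {
        pub struct PrivateMarker;
    }

    /// A sealed trait: implementing it requires writing `__private__`,
    /// whose return type cannot be named outside this module.
    pub trait Dimension {
        fn ndim(&self) -> usize;
        /// Because no downstream crate can implement the trait, a required
        /// method like this one can be added later without a breaking release.
        fn size(&self) -> usize;
        #[doc(hidden)]
        fn __private__(&self) -> private::PrivateMarker;
    }

    #[derive(Debug, Clone, Copy)]
    pub struct Ix1(pub usize);

    #[derive(Debug, Clone, Copy)]
    pub struct Ix2(pub usize, pub usize);

    impl Dimension for Ix1 {
        fn ndim(&self) -> usize { 1 }
        fn size(&self) -> usize { self.0 }
        fn __private__(&self) -> private::PrivateMarker { private::PrivateMarker }
    }

    impl Dimension for Ix2 {
        fn ndim(&self) -> usize { 2 }
        fn size(&self) -> usize { self.0 * self.1 }
        fn __private__(&self) -> private::PrivateMarker { private::PrivateMarker }
    }
}

use mini_ndarray::{Dimension, Ix1, Ix2};

// Downstream code can still *use* the trait generically...
fn describe<D: Dimension>(d: &D) -> String {
    format!("ndim={} size={}", d.ndim(), d.size())
}

// ...but cannot implement it. Uncommenting this fails to compile,
// because `mini_ndarray::private` is private:
//
// struct TickersAxis(usize);
// impl Dimension for TickersAxis {
//     fn ndim(&self) -> usize { 1 }
//     fn size(&self) -> usize { self.0 }
//     fn __private__(&self) -> mini_ndarray::private::PrivateMarker { todo!() }
// }

fn main() {
    println!("{}", describe(&Ix1(5)));
    println!("{}", describe(&Ix2(2, 3)));
}
```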

Named dimensions would be great. I enjoy xarray (Python) a lot.

@bionicles
Author

I did get compile-time bounds checking of named tensors working in Rust, but without type aliases the errors are unreadable HLists of typenum types.

Const generics are an interesting way to improve compile-time bounds checking, but without custom named dimensions (literally just a new integer type that implements your Dimension trait), it's a crapshoot to remember which axis has which name. We end up always indexing with integers instead of names, and users will continue to struggle with runtime shape bugs.
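
To make that trade-off concrete, here is a sketch of compile-time bounds checking with plain const generics (std only, no ndarray; `Grid`, `InBounds`, and `get` are hypothetical names). A constant index that is out of range is rejected when the call is compiled, but swapping two in-range indices still compiles, which is the gap named dimensions would close. The check only applies when the indices are compile-time constants.

```rust
/// Compile-time bounds check: evaluating `OK` fails to compile if `I >= N`.
struct InBounds<const I: usize, const N: usize>;

impl<const I: usize, const N: usize> InBounds<I, N> {
    const OK: () = assert!(I < N, "index out of bounds");
}

/// A fixed-size R x C grid; the shape lives in the type.
struct Grid<T, const R: usize, const C: usize> {
    data: [[T; C]; R],
}

impl<T: Copy, const R: usize, const C: usize> Grid<T, R, C> {
    /// Reads element (I, J); both indices are checked when `get::<I, J>` is compiled.
    fn get<const I: usize, const J: usize>(&self) -> T {
        // Referencing the associated consts forces their evaluation at compile time.
        let () = InBounds::<I, R>::OK;
        let () = InBounds::<J, C>::OK;
        self.data[I][J]
    }
}

fn main() {
    // 2 tickers x 3 timestamps; only this comment says which axis is which.
    let closes = Grid { data: [[100u64, 101, 102], [200, 201, 202]] };
    println!("{}", closes.get::<1, 2>());
    // closes.get::<2, 1>() is rejected at compile time (row 2 of 2),
    // but closes.get::<1, 0>() vs closes.get::<0, 1>() mix-ups still compile.
}
```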

There might be a way for you to keep suitable control of your Dimension trait by using macros to define Dimension newtypes, basically assigning each axis a one-letter struct and a full-name trait. Then einsum looks extremely nice and readable, and you can catch the bugs at compile time.
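
One possible reading of that idea, as a std-only sketch (not ndarray or Burn API; every name here is hypothetical): a `macro_rules!` gives each axis a one-letter marker struct plus its full name, here via a shared `NamedAxis` trait rather than one trait per axis, and an einsum-style contraction for the single pattern `ij,jk->ik` requires the shared axis to be the same type. It is not a general einsum.

```rust
use std::marker::PhantomData;

/// Every axis marker carries its full, human-readable name.
pub trait NamedAxis {
    const NAME: &'static str;
}

/// Defines a one-letter zero-sized axis marker together with its full name.
macro_rules! axis {
    ($short:ident, $full:literal) => {
        #[derive(Debug, Clone, Copy)]
        pub struct $short;
        impl NamedAxis for $short {
            const NAME: &'static str = $full;
        }
    };
}

axis!(T, "tickers");
axis!(S, "timestamps");
axis!(F, "features");

/// Row-major matrix with row axis `R` and column axis `C`.
pub struct M<R, C> {
    rows: usize,
    cols: usize,
    data: Vec<f64>,
    _axes: PhantomData<(R, C)>,
}

impl<R: NamedAxis, C: NamedAxis> M<R, C> {
    pub fn from_vec(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
        M { rows, cols, data, _axes: PhantomData }
    }
    pub fn get(&self, r: usize, c: usize) -> f64 {
        self.data[r * self.cols + c]
    }
    pub fn describe(&self) -> String {
        format!("({}: {}, {}: {})", R::NAME, self.rows, C::NAME, self.cols)
    }
}

/// einsum("ij,jk->ik"): contracts over the shared axis `J`.
/// Passing matrices whose inner axes differ is a type error.
pub fn contract<I: NamedAxis, J: NamedAxis, K: NamedAxis>(a: &M<I, J>, b: &M<J, K>) -> M<I, K> {
    // The axis *types* match by construction; their extents are checked at runtime.
    assert_eq!(a.cols, b.rows, "extent of shared axis {} differs", J::NAME);
    let mut out = vec![0.0_f64; a.rows * b.cols];
    for i in 0..a.rows {
        for k in 0..b.cols {
            out[i * b.cols + k] = (0..a.cols).map(|j| a.get(i, j) * b.get(j, k)).sum();
        }
    }
    M::from_vec(a.rows, b.cols, out)
}

fn main() {
    // Hypothetical data: prices are tickers x timestamps, weights are timestamps x features.
    let prices: M<T, S> = M::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    let weights: M<S, F> = M::from_vec(3, 1, vec![1.0, 0.0, 1.0]);
    let out = contract(&prices, &weights); // M<T, F>
    println!("{} -> {:?}", out.describe(), out.data);
    // contract(&weights, &prices) is rejected: inner axes F and T differ.
}
```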

Burn has a proto-version of this, but IMHO nobody has figured it out properly yet, and it's the difference between runtime crashes and none. I'd meekly suggest it's a good project, and I'm happy to share some code.

Probably the answer (for me) is to quit Rust and use Zig: compile-time functions could be dramatically more readable than the nonsense hoops we must jump through for type-level Rust right now. Alas, the rest of the Zig ecosystem doesn't seem to have what we need.

TBH, eliminating runtime shape bugs with rustc is a tantalizingly close objective, but yeah, it's a lot of work.
