Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional support for borsh serialisation #1335

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Expand Up @@ -43,6 +43,7 @@ libc = { version = "0.2.82", optional = true }

matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] }

borsh = { version = "1.2", optional = true, default-features = false }
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }

Expand All @@ -66,7 +67,7 @@ serde-1 = ["serde"]
test = []

# This feature is used for docs
docs = ["approx", "approx-0_5", "serde", "rayon"]
docs = ["approx", "approx-0_5", "serde", "borsh", "rayon"]

std = ["num-traits/std", "matrixmultiply/std"]
rayon = ["rayon_", "std"]
Expand Down
101 changes: 101 additions & 0 deletions src/array_borsh.rs
@@ -0,0 +1,101 @@
use crate::imp_prelude::*;
use crate::IntoDimension;
use alloc::vec::Vec;
use borsh::{BorshDeserialize, BorshSerialize};
use core::ops::Deref;

/// **Requires crate feature `"borsh"`**
impl<I> BorshSerialize for Dim<I>
where
I: BorshSerialize,
{
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
<I as BorshSerialize>::serialize(&self.ix(), writer)
}
}

/// **Requires crate feature `"borsh"`**
impl<I> BorshDeserialize for Dim<I>
where
I: BorshDeserialize,
{
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
<I as BorshDeserialize>::deserialize_reader(reader).map(Dim::new)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be worried that we have to break compatibility here if we change how dimensions are stored?

}
}

/// **Requires crate feature `"borsh"`**
impl BorshSerialize for IxDyn {
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
let elts = self.ix().deref();
// Output length of dimensions.
<usize as BorshSerialize>::serialize(&elts.len(), writer)?;
// Followed by actual data.
for elt in elts {
<Ix as BorshSerialize>::serialize(elt, writer)?;
}
Ok(())
}
}

/// **Requires crate feature `"borsh"`**
impl BorshDeserialize for IxDyn {
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
// Deserialize the length.
let len = <usize as BorshDeserialize>::deserialize_reader(reader)?;
// Deserialize the given number of elements. We assume the source is
// trusted so we use a capacity hint...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why trusted?

let mut elts = Vec::with_capacity(len);
for _ix in 0..len {
elts.push(<Ix as BorshDeserialize>::deserialize_reader(reader)?);
}
Ok(elts.into_dimension())
}
}

/// **Requires crate feature `"borsh"`**
impl<A, D, S> BorshSerialize for ArrayBase<S, D>
where
A: BorshSerialize,
D: Dimension + BorshSerialize,
S: Data<Elem = A>,
{
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
// Dimensions
<D as BorshSerialize>::serialize(&self.raw_dim(), writer)?;
// Followed by length of data
let iter = self.iter();
<usize as BorshSerialize>::serialize(&iter.len(), writer)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why both length and dimensions?

// Followed by data itself.
for elt in iter {
<A as BorshSerialize>::serialize(elt, writer)?;
}
Ok(())
}
}

/// **Requires crate feature `"borsh"`**
impl<A, D, S> BorshDeserialize for ArrayBase<S, D>
where
A: BorshDeserialize,
D: BorshDeserialize + Dimension,
S: DataOwned<Elem = A>,
{
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
// Dimensions
let dim = <D as BorshDeserialize>::deserialize_reader(reader)?;
// Followed by length of data
let len = <usize as BorshDeserialize>::deserialize_reader(reader)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you handle integer size here? What if it was serialized on a 64-bit usize platform but you deserialize on a 32-bit?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would look at indexmap-rs/indexmap#313 for inspiration - fixed size integers and check error cases

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

borsh says that container sizes should be stored as u32. Unsure how it interacts with ndarray. For me, developing this PR is not my priority unfortunately. Help from others welcome.

// Followed by data itself.
let mut data = Vec::with_capacity(len);
for _ix in 0..len {
data.push(<A as BorshDeserialize>::deserialize_reader(reader)?);
}
ArrayBase::from_shape_vec(dim, data).map_err(|_shape_err| {
borsh::io::Error::new(
borsh::io::ErrorKind::InvalidData,
"data and dimensions must match in size",
)
})
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Expand Up @@ -164,6 +164,8 @@ mod aliases;
#[macro_use]
mod itertools;
mod argument_traits;
#[cfg(feature = "borsh")]
mod array_borsh;
#[cfg(feature = "serde")]
mod array_serde;
mod arrayformat;
Expand Down
6 changes: 5 additions & 1 deletion xtest-serialization/Cargo.toml
Expand Up @@ -8,7 +8,7 @@ publish = false
test = false

[dependencies]
ndarray = { path = "..", features = ["serde"] }
ndarray = { path = "..", features = ["serde", "borsh"] }

[features]
default = ["ron"]
Expand All @@ -23,6 +23,10 @@ version = "1.0.40"
[dev-dependencies.rmp-serde]
version = "0.14.0"

[dev-dependencies.borsh]
version = "1.2"
default-features = false

[dependencies.ron]
version = "0.5.1"
optional = true
70 changes: 70 additions & 0 deletions xtest-serialization/tests/serialize.rs
Expand Up @@ -9,6 +9,8 @@ extern crate rmp_serde;
#[cfg(feature = "ron")]
extern crate ron;

extern crate borsh;

use ndarray::{arr0, arr1, arr2, s, ArcArray, ArcArray2, ArrayD, IxDyn};

#[test]
Expand Down Expand Up @@ -218,3 +220,71 @@ fn serial_many_dim_ron() {
assert_eq!(a, a_de);
}
}

#[test]
fn serial_ixdyn_borsh() {
{
let a = arr0::<f32>(2.72).into_dyn();
let serial = borsh::to_vec(&a).unwrap();
println!("Borsh encode {:?} => {:?}", a, serial);
let res = borsh::from_slice::<ArcArray<f32, _>>(&serial);
println!("{:?}", res);
assert_eq!(a, res.unwrap());
}

{
let a = arr1::<f32>(&[2.72, 1., 2.]).into_dyn();
let serial = borsh::to_vec(&a).unwrap();
println!("Borsh encode {:?} => {:?}", a, serial);
let res = borsh::from_slice::<ArrayD<f32>>(&serial);
println!("{:?}", res);
assert_eq!(a, res.unwrap());
}

{
let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]])
.into_shape(IxDyn(&[3, 1, 1, 1, 2, 1]))
.unwrap();
let serial = borsh::to_vec(&a).unwrap();
println!("Borsh encode {:?} => {:?}", a, serial);
let res = borsh::from_slice::<ArrayD<f32>>(&serial);
println!("{:?}", res);
assert_eq!(a, res.unwrap());
}
}

#[test]
fn serial_many_dim_borsh() {
use borsh::from_slice as borsh_deserialize;
use borsh::to_vec as borsh_serialize;

{
let a = arr0::<f32>(2.72);
let a_s = borsh_serialize(&a).unwrap();
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
assert_eq!(a, a_de);
}

{
let a = arr1::<f32>(&[2.72, 1., 2.]);
let a_s = borsh_serialize(&a).unwrap();
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
assert_eq!(a, a_de);
}

{
let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]]);
let a_s = borsh_serialize(&a).unwrap();
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
assert_eq!(a, a_de);
}

{
// Test a sliced array.
let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
a.slice_collapse(s![..;-1, .., .., ..2]);
let a_s = borsh_serialize(&a).unwrap();
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
assert_eq!(a, a_de);
}
}