Skip to content

Commit

Permalink
Add optional support for borsh serialisation
Browse files Browse the repository at this point in the history
Behind a feature flag.
  • Loading branch information
Fuuzetsu committed Nov 27, 2023
1 parent 40bb0b2 commit 847c12c
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Expand Up @@ -43,6 +43,7 @@ libc = { version = "0.2.82", optional = true }

matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] }

borsh = { version = "1.2", optional = true, default-features = false }
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }

Expand All @@ -66,7 +67,7 @@ serde-1 = ["serde"]
test = []

# This feature is used for docs
docs = ["approx", "approx-0_5", "serde", "rayon"]
docs = ["approx", "approx-0_5", "serde", "borsh", "rayon"]

std = ["num-traits/std", "matrixmultiply/std"]
rayon = ["rayon_", "std"]
Expand Down
101 changes: 101 additions & 0 deletions src/array_borsh.rs
@@ -0,0 +1,101 @@
use crate::imp_prelude::*;
use crate::IntoDimension;
use alloc::vec::Vec;
use borsh::{BorshDeserialize, BorshSerialize};
use core::ops::Deref;

/// **Requires crate feature `"borsh"`**
impl<I> BorshSerialize for Dim<I>
where
I: BorshSerialize,
{
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
<I as BorshSerialize>::serialize(&self.ix(), writer)
}
}

/// **Requires crate feature `"borsh"`**
impl<I> BorshDeserialize for Dim<I>
where
I: BorshDeserialize,
{
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
<I as BorshDeserialize>::deserialize_reader(reader).map(Dim::new)
}
}

/// **Requires crate feature `"borsh"`**
impl BorshSerialize for IxDyn {
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
let elts = self.ix().deref();
// Output length of dimensions.
<usize as BorshSerialize>::serialize(&elts.len(), writer)?;
// Followed by actual data.
for elt in elts {
<Ix as BorshSerialize>::serialize(elt, writer)?;
}
Ok(())
}
}

/// **Requires crate feature `"borsh"`**
impl BorshDeserialize for IxDyn {
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
// Deserialize the length.
let len = <usize as BorshDeserialize>::deserialize_reader(reader)?;
// Deserialize the given number of elements. We assume the source is
// trusted so we use a capacity hint...
let mut elts = Vec::with_capacity(len);
for _ix in 0..len {
elts.push(<Ix as BorshDeserialize>::deserialize_reader(reader)?);
}
Ok(elts.into_dimension())
}
}

/// **Requires crate feature `"borsh"`**
impl<A, D, S> BorshSerialize for ArrayBase<S, D>
where
A: BorshSerialize,
D: Dimension + BorshSerialize,
S: Data<Elem = A>,
{
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
// Dimensions
<D as BorshSerialize>::serialize(&self.raw_dim(), writer)?;
// Followed by length of data
let iter = self.iter();
<usize as BorshSerialize>::serialize(&iter.len(), writer)?;
// Followed by data itself.
for elt in iter {
<A as BorshSerialize>::serialize(elt, writer)?;
}
Ok(())
}
}

/// **Requires crate feature `"borsh"`**
impl<A, D, S> BorshDeserialize for ArrayBase<S, D>
where
A: BorshDeserialize,
D: BorshDeserialize + Dimension,
S: DataOwned<Elem = A>,
{
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
// Dimensions
let dim = <D as BorshDeserialize>::deserialize_reader(reader)?;
// Followed by length of data
let len = <usize as BorshDeserialize>::deserialize_reader(reader)?;
// Followed by data itself.
let mut data = Vec::with_capacity(len);
for _ix in 0..len {
data.push(<A as BorshDeserialize>::deserialize_reader(reader)?);
}
ArrayBase::from_shape_vec(dim, data).map_err(|_shape_err| {
borsh::io::Error::new(
borsh::io::ErrorKind::InvalidData,
"data and dimensions must match in size",
)
})
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Expand Up @@ -164,6 +164,8 @@ mod aliases;
#[macro_use]
mod itertools;
mod argument_traits;
#[cfg(feature = "borsh")]
mod array_borsh;
#[cfg(feature = "serde")]
mod array_serde;
mod arrayformat;
Expand Down
6 changes: 5 additions & 1 deletion xtest-serialization/Cargo.toml
Expand Up @@ -8,7 +8,7 @@ publish = false
test = false

[dependencies]
ndarray = { path = "..", features = ["serde"] }
ndarray = { path = "..", features = ["serde", "borsh"] }

[features]
default = ["ron"]
Expand All @@ -23,6 +23,10 @@ version = "1.0.40"
[dev-dependencies.rmp-serde]
version = "0.14.0"

[dev-dependencies.borsh]
version = "1.2"
default-features = false

[dependencies.ron]
version = "0.5.1"
optional = true
70 changes: 70 additions & 0 deletions xtest-serialization/tests/serialize.rs
Expand Up @@ -9,6 +9,8 @@ extern crate rmp_serde;
#[cfg(feature = "ron")]
extern crate ron;

extern crate borsh;

use ndarray::{arr0, arr1, arr2, s, ArcArray, ArcArray2, ArrayD, IxDyn};

#[test]
Expand Down Expand Up @@ -218,3 +220,71 @@ fn serial_many_dim_ron() {
assert_eq!(a, a_de);
}
}

#[test]
fn serial_ixdyn_borsh() {
{
let a = arr0::<f32>(2.72).into_dyn();
let serial = borsh::to_vec(&a).unwrap();
println!("Borsh encode {:?} => {:?}", a, serial);
let res = borsh::from_slice::<ArcArray<f32, _>>(&serial);
println!("{:?}", res);
assert_eq!(a, res.unwrap());
}

{
let a = arr1::<f32>(&[2.72, 1., 2.]).into_dyn();
let serial = borsh::to_vec(&a).unwrap();
println!("Borsh encode {:?} => {:?}", a, serial);
let res = borsh::from_slice::<ArrayD<f32>>(&serial);
println!("{:?}", res);
assert_eq!(a, res.unwrap());
}

{
let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]])
.into_shape(IxDyn(&[3, 1, 1, 1, 2, 1]))
.unwrap();
let serial = borsh::to_vec(&a).unwrap();
println!("Borsh encode {:?} => {:?}", a, serial);
let res = borsh::from_slice::<ArrayD<f32>>(&serial);
println!("{:?}", res);
assert_eq!(a, res.unwrap());
}
}

#[test]
fn serial_many_dim_borsh() {
use borsh::from_slice as borsh_deserialize;
use borsh::to_vec as borsh_serialize;

{
let a = arr0::<f32>(2.72);
let a_s = borsh_serialize(&a).unwrap();
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
assert_eq!(a, a_de);
}

{
let a = arr1::<f32>(&[2.72, 1., 2.]);
let a_s = borsh_serialize(&a).unwrap();
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
assert_eq!(a, a_de);
}

{
let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]]);
let a_s = borsh_serialize(&a).unwrap();
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
assert_eq!(a, a_de);
}

{
// Test a sliced array.
let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
a.slice_collapse(s![..;-1, .., .., ..2]);
let a_s = borsh_serialize(&a).unwrap();
let a_de: ArcArray<f32, _> = borsh_deserialize(&a_s).unwrap();
assert_eq!(a, a_de);
}
}

0 comments on commit 847c12c

Please sign in to comment.