Skip to content
This repository has been archived by the owner on Sep 24, 2022. It is now read-only.

Commit

Permalink
Speed up map code too
Browse files Browse the repository at this point in the history
Also add regression test
  • Loading branch information
est31 committed Oct 25, 2019
1 parent dcc063f commit 9a38a82
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 11 deletions.
47 changes: 36 additions & 11 deletions src/de.rs
Expand Up @@ -216,6 +216,7 @@ impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> {
{
let mut tables = self.tables()?;
let table_indices = build_table_indices(&tables);
let table_pindices = build_table_pindices(&tables);

let res = visitor.visit_map(MapVisitor {
values: Vec::new().into_iter(),
Expand All @@ -225,6 +226,7 @@ impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> {
cur_parent: 0,
max: tables.len(),
table_indices: &table_indices,
table_pindices: &table_pindices,
tables: &mut tables,
array: false,
de: self,
Expand Down Expand Up @@ -330,6 +332,19 @@ fn build_table_indices<'de>(tables: &[Table<'de>]) -> HashMap<Vec<Cow<'de, str>>
res
}

fn build_table_pindices<'de>(tables: &[Table<'de>]) -> HashMap<Vec<Cow<'de, str>>, Vec<usize>> {
let mut res = HashMap::new();
for (i, table) in tables.iter().enumerate() {
let header = table.header.iter().map(|v| v.1.clone()).collect::<Vec<_>>();
for len in 0..=header.len() {
res.entry(header[..len].to_owned())
.or_insert(Vec::new())
.push(i);
}
}
res
}

fn headers_equal<'a, 'b>(hdr_a: &[(Span, Cow<'a, str>)], hdr_b: &[(Span, Cow<'b, str>)]) -> bool {
if hdr_a.len() != hdr_b.len() {
return false;
Expand All @@ -352,6 +367,7 @@ struct MapVisitor<'de, 'b> {
cur_parent: usize,
max: usize,
table_indices: &'b HashMap<Vec<Cow<'de, str>>, Vec<usize>>,
table_pindices: &'b HashMap<Vec<Cow<'de, str>>, Vec<usize>>,
tables: &'b mut [Table<'de>],
array: bool,
de: &'b mut Deserializer<'de>,
Expand All @@ -377,20 +393,27 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
}

let next_table = {
let prefix = &self.tables[self.cur_parent].header[..self.depth];
self.tables[self.cur..self.max]
let prefix_stripped = self.tables[self.cur_parent].header[..self.depth]
.iter()
.enumerate()
.find(|&(_, t)| {
if t.values.is_none() {
return false;
}
match t.header.get(..self.depth) {
Some(header) => headers_equal(&header, &prefix),
None => false,
.map(|v| v.1.clone())
.collect::<Vec<_>>();
self.table_pindices
.get(&prefix_stripped)
.and_then(|entries| {
let start = entries
.binary_search(&self.cur)
.unwrap_or_else(std::convert::identity);
if start == entries.len() || entries[start] < self.cur {
return None;
}
entries[start..]
.iter()
.copied()
.filter(|i| *i < self.max)
.map(|i| (i, &self.tables[i]))
.find(|(_, table)| table.values.is_some())
.map(|p| p.0)
})
.map(|(i, _)| i + self.cur)
};

let pos = match next_table {
Expand Down Expand Up @@ -483,6 +506,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
max: self.max,
array,
table_indices: &*self.table_indices,
table_pindices: &*self.table_pindices,
tables: &mut *self.tables,
de: &mut *self.de,
});
Expand Down Expand Up @@ -546,6 +570,7 @@ impl<'de, 'b> de::SeqAccess<'de> for MapVisitor<'de, 'b> {
cur: 0,
array: false,
table_indices: &*self.table_indices,
table_pindices: &*self.table_pindices,
tables: &mut self.tables,
de: &mut self.de,
})?;
Expand Down
37 changes: 37 additions & 0 deletions test-suite/tests/linear.rs
@@ -0,0 +1,37 @@
use std::time::{Duration, Instant};
use toml::Value;

const TOLERANCE: f64 = 2.0;

fn measure_time(entries: usize, f: impl Fn(usize) -> String) -> Duration {
let start = Instant::now();
let mut s = String::new();
for i in 0..entries {
s += &f(i);
s += "entry = 42\n"
}
s.parse::<Value>().unwrap();
Instant::now() - start
}

#[test]
fn linear_increase_map() {
let time_1 = measure_time(100, |i| format!("[header_no_{}]\n", i));
let time_4 = measure_time(400, |i| format!("[header_no_{}]\n", i));
dbg!(time_1, time_4);
// Now ensure that the deserialization time has increased linearly
// (within a tolerance interval) instead of, say, quadratically
assert!(time_4 > time_1.mul_f64(4.0 - TOLERANCE));
assert!(time_4 < time_1.mul_f64(4.0 + TOLERANCE));
}

#[test]
fn linear_increase_array() {
let time_1 = measure_time(100, |i| format!("[[header_no_{}]]\n", i));
let time_4 = measure_time(400, |i| format!("[[header_no_{}]]\n", i));
dbg!(time_1, time_4);
// Now ensure that the deserialization time has increased linearly
// (within a tolerance interval) instead of, say, quadratically
assert!(time_4 > time_1.mul_f64(4.0 - TOLERANCE));
assert!(time_4 < time_1.mul_f64(4.0 + TOLERANCE));
}

0 comments on commit 9a38a82

Please sign in to comment.