From 749bf0a9e228769e6dfbfa2139d08ee044ddb730 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Sat, 13 Aug 2022 11:31:29 -0700 Subject: [PATCH] Deserialize empty plain scalar to an empty map or seq --- src/de.rs | 69 ++++++++++++++++++++++++++++++++++----------- src/value/de.rs | 18 ++++++++---- tests/test_de.rs | 19 +++++++++++++ tests/test_error.rs | 4 +-- 4 files changed, 87 insertions(+), 23 deletions(-) diff --git a/src/de.rs b/src/de.rs index 65be8749..37c72ddf 100644 --- a/src/de.rs +++ b/src/de.rs @@ -536,7 +536,11 @@ impl<'de, 'document> DeserializerFromEvents<'de, 'document> { V: Visitor<'de>, { let (value, len) = self.recursion_check(mark, |de| { - let mut seq = SeqAccess { de, len: 0 }; + let mut seq = SeqAccess { + empty: false, + de, + len: 0, + }; let value = visitor.visit_seq(&mut seq)?; Ok((value, seq.len)) })?; @@ -550,6 +554,7 @@ impl<'de, 'document> DeserializerFromEvents<'de, 'document> { { let (value, len) = self.recursion_check(mark, |de| { let mut map = MapAccess { + empty: false, de, len: 0, key: None, @@ -563,7 +568,11 @@ impl<'de, 'document> DeserializerFromEvents<'de, 'document> { fn end_sequence(&mut self, len: usize) -> Result<()> { let total = { - let mut seq = SeqAccess { de: self, len }; + let mut seq = SeqAccess { + empty: false, + de: self, + len, + }; while de::SeqAccess::next_element::(&mut seq)?.is_some() {} seq.len }; @@ -591,6 +600,7 @@ impl<'de, 'document> DeserializerFromEvents<'de, 'document> { fn end_mapping(&mut self, len: usize) -> Result<()> { let total = { let mut map = MapAccess { + empty: false, de: self, len, key: None, @@ -636,6 +646,7 @@ impl<'de, 'document> DeserializerFromEvents<'de, 'document> { } struct SeqAccess<'de, 'document, 'seq> { + empty: bool, de: &'seq mut DeserializerFromEvents<'de, 'document>, len: usize, } @@ -647,6 +658,9 @@ impl<'de, 'document, 'seq> de::SeqAccess<'de> for SeqAccess<'de, 'document, 'seq where T: DeserializeSeed<'de>, { + if self.empty { + return Ok(None); + } match self.de.peek_event()? { Event::SequenceEnd | Event::Void => Ok(None), _ => { @@ -669,6 +683,7 @@ impl<'de, 'document, 'seq> de::SeqAccess<'de> for SeqAccess<'de, 'document, 'seq } struct MapAccess<'de, 'document, 'map> { + empty: bool, de: &'map mut DeserializerFromEvents<'de, 'document>, len: usize, key: Option<&'document [u8]>, @@ -681,6 +696,9 @@ impl<'de, 'document, 'map> de::MapAccess<'de> for MapAccess<'de, 'document, 'map where K: DeserializeSeed<'de>, { + if self.empty { + return Ok(None); + } match self.de.peek_event()? { Event::MappingEnd | Event::Void => Ok(None), Event::Scalar(scalar) => { @@ -1564,12 +1582,23 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, match next { Event::Alias(mut pos) => self.jump(&mut pos)?.deserialize_seq(visitor), Event::SequenceStart(_) => self.visit_sequence(visitor, mark), - Event::Void => { - *self.pos -= 1; - let mut seq = SeqAccess { de: self, len: 0 }; - visitor.visit_seq(&mut seq) + other => { + if match other { + Event::Void => true, + Event::Scalar(scalar) => { + scalar.value.is_empty() && scalar.style == ScalarStyle::Plain + } + _ => false, + } { + visitor.visit_seq(SeqAccess { + empty: true, + de: self, + len: 0, + }) + } else { + Err(invalid_type(other, &visitor)) + } } - other => Err(invalid_type(other, &visitor)), } .map_err(|err| error::fix_mark(err, mark, self.path)) } @@ -1601,16 +1630,24 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, match next { Event::Alias(mut pos) => self.jump(&mut pos)?.deserialize_map(visitor), Event::MappingStart(_) => self.visit_mapping(visitor, mark), - Event::Void => { - *self.pos -= 1; - let mut map = MapAccess { - de: self, - len: 0, - key: None, - }; - visitor.visit_map(&mut map) + other => { + if match other { + Event::Void => true, + Event::Scalar(scalar) => { + scalar.value.is_empty() && scalar.style == ScalarStyle::Plain + } + _ => false, + } { + visitor.visit_map(MapAccess { + empty: true, + de: self, + len: 0, + key: None, + }) + } else { + Err(invalid_type(other, &visitor)) + } } - other => Err(invalid_type(other, &visitor)), } .map_err(|err| error::fix_mark(err, mark, self.path)) } diff --git a/src/value/de.rs b/src/value/de.rs index 72bcff38..ee137211 100644 --- a/src/value/de.rs +++ b/src/value/de.rs @@ -184,7 +184,7 @@ where let len = mapping.len(); let mut deserializer = MapRefDeserializer::new(mapping); let map = visitor.visit_map(&mut deserializer)?; - let remaining = deserializer.iter.len(); + let remaining = deserializer.iter.unwrap().len(); if remaining == 0 { Ok(map) } else { @@ -390,6 +390,7 @@ impl<'de> Deserializer<'de> for Value { { match self.untag() { Value::Sequence(v) => visit_sequence(v, visitor), + Value::Null => visit_sequence(Sequence::new(), visitor), other => Err(other.invalid_type(&visitor)), } } @@ -419,6 +420,7 @@ impl<'de> Deserializer<'de> for Value { { match self.untag() { Value::Mapping(v) => visit_mapping(v, visitor), + Value::Null => visit_mapping(Mapping::new(), visitor), other => Err(other.invalid_type(&visitor)), } } @@ -903,8 +905,10 @@ impl<'de> Deserializer<'de> for &'de Value { where V: Visitor<'de>, { + static EMPTY: Sequence = Sequence::new(); match self.untag_ref() { Value::Sequence(v) => visit_sequence_ref(v, visitor), + Value::Null => visit_sequence_ref(&EMPTY, visitor), other => Err(other.invalid_type(&visitor)), } } @@ -934,6 +938,10 @@ impl<'de> Deserializer<'de> for &'de Value { { match self.untag_ref() { Value::Mapping(v) => visit_mapping_ref(v, visitor), + Value::Null => visitor.visit_map(&mut MapRefDeserializer { + iter: None, + value: None, + }), other => Err(other.invalid_type(&visitor)), } } @@ -1138,14 +1146,14 @@ impl<'de> SeqAccess<'de> for SeqRefDeserializer<'de> { } pub(crate) struct MapRefDeserializer<'de> { - iter: <&'de Mapping as IntoIterator>::IntoIter, + iter: Option<<&'de Mapping as IntoIterator>::IntoIter>, value: Option<&'de Value>, } impl<'de> MapRefDeserializer<'de> { pub(crate) fn new(map: &'de Mapping) -> Self { MapRefDeserializer { - iter: map.iter(), + iter: Some(map.iter()), value: None, } } @@ -1158,7 +1166,7 @@ impl<'de> MapAccess<'de> for MapRefDeserializer<'de> { where T: DeserializeSeed<'de>, { - match self.iter.next() { + match self.iter.as_mut().and_then(Iterator::next) { Some((key, value)) => { self.value = Some(value); seed.deserialize(key).map(Some) @@ -1178,7 +1186,7 @@ impl<'de> MapAccess<'de> for MapRefDeserializer<'de> { } fn size_hint(&self) -> Option { - match self.iter.size_hint() { + match self.iter.as_ref()?.size_hint() { (lower, Some(upper)) if lower == upper => Some(upper), _ => None, } diff --git a/tests/test_de.rs b/tests/test_de.rs index d6e17d7a..0d16d5e7 100644 --- a/tests/test_de.rs +++ b/tests/test_de.rs @@ -550,3 +550,22 @@ fn test_no_required_fields() { assert_eq!(expected, deserialized); } } + +#[test] +fn test_empty_scalar() { + #[derive(Deserialize, PartialEq, Debug)] + struct Struct { + thing: T, + } + + let yaml = "thing:\n"; + let expected = Struct { + thing: serde_yaml::Sequence::new(), + }; + test_de(yaml, &expected); + + let expected = Struct { + thing: serde_yaml::Mapping::new(), + }; + test_de(yaml, &expected); +} diff --git a/tests/test_error.rs b/tests/test_error.rs index a90c2dbc..67a9e315 100644 --- a/tests/test_error.rs +++ b/tests/test_error.rs @@ -328,8 +328,8 @@ fn test_invalid_scalar_type() { x: [i32; 1], } - let yaml = "x:\n"; - let expected = "x: invalid type: unit value, expected an array of length 1 at line 1 column 3"; + let yaml = "x: ''\n"; + let expected = "x: invalid type: string \"\", expected an array of length 1 at line 1 column 4"; test_error::(yaml, expected); }