From f24c226211eb340c999d810013efbff35a49863f Mon Sep 17 00:00:00 2001 From: Lukas Rytz Date: Tue, 21 Jun 2022 13:57:50 +0200 Subject: [PATCH] Prevent Function0 execution during LazyList deserialization This PR ensures that LazyList deserialization will not execute an arbitrary Function0 when being passed a forged serialization stream. See the PR description for a detailed explanation. --- .../scala/collection/immutable/LazyList.scala | 11 ++- .../collection/immutable/LazyListTest.scala | 73 ++++++++++++++++++- 2 files changed, 79 insertions(+), 5 deletions(-) diff --git a/src/library/scala/collection/immutable/LazyList.scala b/src/library/scala/collection/immutable/LazyList.scala index dde413bd91ce..d9d52e8da44b 100644 --- a/src/library/scala/collection/immutable/LazyList.scala +++ b/src/library/scala/collection/immutable/LazyList.scala @@ -19,7 +19,7 @@ import java.lang.{StringBuilder => JStringBuilder} import scala.annotation.tailrec import scala.collection.generic.SerializeEnd -import scala.collection.mutable.{ArrayBuffer, Builder, ReusableBuilder, StringBuilder} +import scala.collection.mutable.{Builder, ReusableBuilder, StringBuilder} import scala.language.implicitConversions import scala.runtime.Statics @@ -1353,7 +1353,7 @@ object LazyList extends SeqFactory[LazyList] { private[this] def writeObject(out: ObjectOutputStream): Unit = { out.defaultWriteObject() var these = coll - while(these.knownNonEmpty) { + while (these.knownNonEmpty) { out.writeObject(these.head) these = these.tail } @@ -1363,14 +1363,17 @@ object LazyList extends SeqFactory[LazyList] { private[this] def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() - val init = new ArrayBuffer[A] + val init = new mutable.ListBuffer[A] var initRead = false while (!initRead) in.readObject match { case SerializeEnd => initRead = true case a => init += a.asInstanceOf[A] } val tail = in.readObject().asInstanceOf[LazyList[A]] - coll = init ++: tail + // scala/scala#10118: caution that no code path can evaluate `tail.state` + // before the resulting LazyList is returned + val it = init.toList.iterator + coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state)) } private[this] def readResolve(): Any = coll diff --git a/test/junit/scala/collection/immutable/LazyListTest.scala b/test/junit/scala/collection/immutable/LazyListTest.scala index 58798ec4cb9d..8ff9ebb72361 100644 --- a/test/junit/scala/collection/immutable/LazyListTest.scala +++ b/test/junit/scala/collection/immutable/LazyListTest.scala @@ -8,12 +8,79 @@ import org.junit.Assert._ import scala.annotation.unused import scala.collection.mutable.{Builder, ListBuffer} -import scala.tools.testkit.AssertUtil +import scala.tools.testkit.{AssertUtil, ReflectUtil} import scala.util.Try @RunWith(classOf[JUnit4]) class LazyListTest { + @Test + def serialization(): Unit = { + import java.io._ + + def serialize(obj: AnyRef): Array[Byte] = { + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + buffer.toByteArray + } + + def deserialize(a: Array[Byte]): AnyRef = { + val in = new ObjectInputStream(new ByteArrayInputStream(a)) + in.readObject + } + + def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T] + + val l = LazyList.from(10) + + val ld1 = serializeDeserialize(l) + assertEquals(l.take(10).toList, ld1.take(10).toList) + + l.tail.head + val ld2 = serializeDeserialize(l) + assertEquals(l.take(10).toList, ld2.take(10).toList) + + LazyListTest.serializationForceCount = 0 + val u = LazyList.from(10).map(x => { LazyListTest.serializationForceCount += 1; x }) + + @unused def printDiff(): Unit = { + val a = serialize(u) + ReflectUtil.getFieldAccessible[LazyList[_]]("scala$collection$immutable$LazyList$$stateEvaluated").setBoolean(u, true) + val b = serialize(u) + val i = a.zip(b).indexWhere(p => p._1 != p._2) + println("difference: ") + println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}") + println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}") + } + + // to update this test, comment-out `LazyList.writeReplace` and run `printDiff` + // printDiff() + + val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97, 118, 97, 46) + val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97, 118, 97, 46) + + assertEquals(LazyListTest.serializationForceCount, 0) + + u.head + assertEquals(LazyListTest.serializationForceCount, 1) + + val data = serialize(u) + var i = data.indexOfSlice(from) + to.foreach(x => {data(i) = x; i += 1}) + + val ud1 = deserialize(data).asInstanceOf[LazyList[Int]] + + // this check failed before scala/scala#10118, deserialization triggered evaluation + assertEquals(LazyListTest.serializationForceCount, 1) + + ud1.tail.head + assertEquals(LazyListTest.serializationForceCount, 2) + + u.tail.head + assertEquals(LazyListTest.serializationForceCount, 3) + } + @Test def t6727_and_t6440_and_8627(): Unit = { assertTrue(LazyList.continually(()).filter(_ => true).take(2) == Seq((), ())) @@ -378,3 +445,7 @@ class LazyListTest { assertEquals(1, count) } } + +object LazyListTest { + var serializationForceCount = 0 +} \ No newline at end of file