Skip to content

Commit

Permalink
Prevent Function0 execution during LazyList deserialization
Browse files Browse the repository at this point in the history
This PR ensures that LazyList deserialization will not execute an
arbitrary Function0 when being passed a forged serialization stream.

See the PR description for a detailed explanation.
  • Loading branch information
lrytz committed Aug 24, 2022
1 parent c46fd04 commit 064f5c9
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/library/scala/collection/immutable/LazyList.scala
Expand Up @@ -249,6 +249,14 @@ final class LazyList[+A] private(private[this] var lazyState: () => LazyList.Sta
@inline private def stateDefined: Boolean = stateEvaluated
private[this] var midEvaluation = false

// see scala/scala#10118
private def withNullLazyState[T](f: => T): T = {
val saved = lazyState
lazyState = null
try f
finally lazyState = saved
}

private lazy val state: State[A] = {
// if it's already mid-evaluation, we're stuck in an infinite
// self-referential loop (also it's empty)
Expand Down Expand Up @@ -1370,7 +1378,7 @@ object LazyList extends SeqFactory[LazyList] {
case a => init += a.asInstanceOf[A]
}
val tail = in.readObject().asInstanceOf[LazyList[A]]
coll = init ++: tail
coll = tail.withNullLazyState(tail.prependedAll(init))
}

private[this] def readResolve(): Any = coll
Expand Down
28 changes: 28 additions & 0 deletions test/junit/scala/collection/immutable/LazyListTest.scala
Expand Up @@ -14,6 +14,34 @@ import scala.util.Try
@RunWith(classOf[JUnit4])
class LazyListTest {

@Test
def serialization(): Unit = {
import java.io._

def serialize(obj: AnyRef): Array[Byte] = {
val buffer = new ByteArrayOutputStream
val out = new ObjectOutputStream(buffer)
out.writeObject(obj)
buffer.toByteArray
}
def deserialize(a: Array[Byte]): AnyRef = {
val in = new ObjectInputStream(new ByteArrayInputStream(a))
in.readObject
}

def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]

val l = LazyList.from(10)
val ld = serializeDeserialize(l)
assertEquals(ld.toString, "LazyList(<not computed>)")
ld.head
assertEquals(ld.toString, "LazyList(10, <not computed>)")
ld.tail.head
assertEquals(ld.toString, "LazyList(10, 11, <not computed>)")
ld.tail.tail.head
assertEquals(ld.toString, "LazyList(10, 11, 12, <not computed>)")
}

@Test
def t6727_and_t6440_and_8627(): Unit = {
assertTrue(LazyList.continually(()).filter(_ => true).take(2) == Seq((), ()))
Expand Down

0 comments on commit 064f5c9

Please sign in to comment.