Skip to content

Commit

Permalink
Preserve null policy in wrapped Java Map
Browse files Browse the repository at this point in the history
When using `compute`, which has "remove" semantics for `null` value,
try alternative means when user wants to put a `null`. In particular,
if the underlying map wants to throw NPE, the fallback should do so.
  • Loading branch information
som-snytt authored and lrytz committed Sep 13, 2022
1 parent d578a02 commit b824b84
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/library/scala/collection/concurrent/Map.scala
Expand Up @@ -97,7 +97,7 @@ trait Map[K, V] extends scala.collection.mutable.Map[K, V] {
case None =>
val v = op
putIfAbsent(key, v) match {
case Some(nv) => nv
case Some(ov) => ov
case None => v
}
}
Expand Down
43 changes: 37 additions & 6 deletions src/library/scala/collection/convert/JavaCollectionWrappers.scala
Expand Up @@ -21,6 +21,7 @@ import java.{lang => jl, util => ju}
import scala.jdk.CollectionConverters._
import scala.util.Try
import scala.util.chaining._
import scala.util.control.ControlThrowable

/** Wrappers for exposing Scala collections as Java collections and vice-versa */
@SerialVersionUID(3L)
Expand Down Expand Up @@ -332,7 +333,12 @@ private[collection] object JavaCollectionWrappers extends Serializable {
else
None
}
override def getOrElseUpdate(key: K, op: => V): V = underlying.computeIfAbsent(key, _ => op)

override def getOrElseUpdate(key: K, op: => V): V =
underlying.computeIfAbsent(key, _ => op) match {
case null => update(key, null.asInstanceOf[V]); null.asInstanceOf[V]
case v => v
}

def addOne(kv: (K, V)): this.type = { underlying.put(kv._1, kv._2); this }
def subtractOne(key: K): this.type = { underlying remove key; this }
Expand All @@ -355,8 +361,17 @@ private[collection] object JavaCollectionWrappers extends Serializable {

override def update(k: K, v: V): Unit = underlying.put(k, v)

override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = Option {
underlying.compute(key, (_, v) => remappingFunction(Option(v)).getOrElse(null.asInstanceOf[V]))
override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = {
def remap(k: K, v: V): V =
remappingFunction(Option(v)) match {
case Some(null) => throw PutNull
case Some(x) => x
case None => null.asInstanceOf[V]
}
try Option(underlying.compute(key, remap))
catch {
case PutNull => update(key, null.asInstanceOf[V]); Some(null.asInstanceOf[V])
}
}

// support Some(null) if currently bound to null
Expand Down Expand Up @@ -441,7 +456,11 @@ private[collection] object JavaCollectionWrappers extends Serializable {

override def get(k: K) = Option(underlying get k)

override def getOrElseUpdate(key: K, op: => V): V = underlying.computeIfAbsent(key, _ => op)
override def getOrElseUpdate(key: K, op: => V): V =
underlying.computeIfAbsent(key, _ => op) match {
case null => super/*[concurrent.Map]*/.getOrElseUpdate(key, op)
case v => v
}

override def isEmpty: Boolean = underlying.isEmpty
override def knownSize: Int = if (underlying.isEmpty) 0 else super.knownSize
Expand All @@ -462,8 +481,17 @@ private[collection] object JavaCollectionWrappers extends Serializable {
case _ => Try(last).toOption
}

override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = Option {
underlying.compute(key, (_, v) => remappingFunction(Option(v)).getOrElse(null.asInstanceOf[V]))
override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = {
def remap(k: K, v: V): V =
remappingFunction(Option(v)) match {
case Some(null) => throw PutNull // see scala/scala#10129
case Some(x) => x
case None => null.asInstanceOf[V]
}
try Option(underlying.compute(key, remap))
catch {
case PutNull => super/*[concurrent.Map]*/.updateWith(key)(remappingFunction)
}
}
}

Expand Down Expand Up @@ -572,4 +600,7 @@ private[collection] object JavaCollectionWrappers extends Serializable {

override def mapFactory = mutable.HashMap
}

/** Thrown when certain Map operations attempt to put a null value. */
private val PutNull = new ControlThrowable {}
}
69 changes: 68 additions & 1 deletion test/junit/scala/collection/convert/MapWrapperTest.scala
Expand Up @@ -2,12 +2,13 @@ package scala.collection.convert

import java.{util => jutil}

import org.junit.Assert._
import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

import scala.jdk.CollectionConverters._
import scala.tools.testkit.AssertUtil.assertThrows
import scala.util.chaining._

@RunWith(classOf[JUnit4])
Expand Down Expand Up @@ -107,4 +108,70 @@ class MapWrapperTest {
loki.done = true
runner.join()
}
@Test def `updateWith and getOrElseUpdate should reflect null policy of update`: Unit = {
val jmap = new jutil.concurrent.ConcurrentHashMap[String, String]()
val wrapped = jmap.asScala
assertThrows[NullPointerException](jmap.put("K", null))
assertThrows[NullPointerException](jmap.putIfAbsent("K", null))
assertThrows[NullPointerException](wrapped.put("K", null))
assertThrows[NullPointerException](wrapped.update("K", null))
assertThrows[NullPointerException](wrapped.updateWith("K")(_ => Some(null)))
assertThrows[NullPointerException](wrapped.getOrElseUpdate("K", null))

var count = 0
def v = {
count += 1
null
}
assertThrows[NullPointerException](wrapped.update("K", v))
assertEquals(1, count)
assertThrows[NullPointerException](wrapped.updateWith("K")(_ => Some(v)))
assertEquals(3, count) // extra count in retry
}
@Test def `more updateWith and getOrElseUpdate should reflect null policy of update`: Unit = {
val jmap = new jutil.HashMap[String, String]()
val wrapped = jmap.asScala
wrapped.put("K", null)
assertEquals(1, wrapped.size)
wrapped.remove("K")
assertEquals(0, wrapped.size)
wrapped.update("K", null)
assertEquals(1, wrapped.size)
wrapped.remove("K")
wrapped.updateWith("K")(_ => Some(null))
assertEquals(1, wrapped.size)
wrapped.remove("K")
wrapped.getOrElseUpdate("K", null)
assertEquals(1, wrapped.size)

var count = 0
def v = {
count += 1
null
}
wrapped.update("K", v)
assertEquals(1, count)
wrapped.remove("K")
wrapped.updateWith("K")(_ => Some(v))
assertEquals(2, count)
}

@Test def `getOrElseUpdate / updateWith support should insert null`: Unit = {
val jmap = new jutil.HashMap[String, String]()
val wrapped = jmap.asScala

wrapped.getOrElseUpdate("a", null)
assertTrue(jmap.containsKey("a"))

wrapped.getOrElseUpdate(null, "x")
assertTrue(jmap.containsKey(null))

jmap.clear()

wrapped.updateWith("b")(_ => Some(null))
assertTrue(jmap.containsKey("b"))

wrapped.updateWith(null)(_ => Some("x"))
assertTrue(jmap.containsKey(null))
}
}

0 comments on commit b824b84

Please sign in to comment.