From 8ade083a053b853c21bb84973ffb0a2541dfc67e Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Fri, 11 Nov 2022 11:20:18 +0000 Subject: [PATCH] Filter out null WebSocketSession attributes Closes gh-29315 --- .../socket/adapter/AbstractWebSocketSession.java | 7 +++++-- .../socket/adapter/AbstractWebSocketSession.java | 6 ++++-- .../standard/StandardWebSocketSessionTests.java | 16 ++++++++++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractWebSocketSession.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractWebSocketSession.java index 1fc285962c3d..68a268062779 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractWebSocketSession.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractWebSocketSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,9 +73,12 @@ protected AbstractWebSocketSession(T delegate, String id, HandshakeInfo info, Da this.id = id; this.handshakeInfo = info; this.bufferFactory = bufferFactory; - this.attributes.putAll(info.getAttributes()); this.logPrefix = initLogPrefix(info, id); + info.getAttributes().entrySet().stream() + .filter(entry -> (entry.getKey() != null && entry.getValue() != null)) + .forEach(entry -> this.attributes.put(entry.getKey(), entry.getValue())); + if (logger.isDebugEnabled()) { logger.debug(getLogPrefix() + "Session id \"" + getId() + "\" for " + getHandshakeInfo().getUri()); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java index 93dd65275384..63a403e46cf1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,7 +62,9 @@ public abstract class AbstractWebSocketSession implements NativeWebSocketSess */ public AbstractWebSocketSession(@Nullable Map attributes) { if (attributes != null) { - this.attributes.putAll(attributes); + attributes.entrySet().stream() + .filter(entry -> (entry.getKey() != null && entry.getValue() != null)) + .forEach(entry -> this.attributes.put(entry.getKey(), entry.getValue())); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java index a16948cb4b1c..8caca767a027 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ * * @author Rossen Stoyanchev */ +@SuppressWarnings("resource") public class StandardWebSocketSessionTests { private final HttpHeaders headers = new HttpHeaders(); @@ -54,7 +55,6 @@ public void getPrincipalWithConstructorArg() { } @Test - @SuppressWarnings("resource") public void getPrincipalWithNativeSession() { TestPrincipal user = new TestPrincipal("joe"); @@ -68,7 +68,6 @@ public void getPrincipalWithNativeSession() { } @Test - @SuppressWarnings("resource") public void getPrincipalNone() { Session nativeSession = Mockito.mock(Session.class); given(nativeSession.getUserPrincipal()).willReturn(null); @@ -83,7 +82,6 @@ public void getPrincipalNone() { } @Test - @SuppressWarnings("resource") public void getAcceptedProtocol() { String protocol = "foo"; @@ -99,4 +97,14 @@ public void getAcceptedProtocol() { verifyNoMoreInteractions(nativeSession); } + @Test // gh-29315 + public void addAttributesWithNullKeyOrValue() { + this.attributes.put(null, "value"); + this.attributes.put("key", null); + this.attributes.put("foo", "bar"); + + assertThat(new StandardWebSocketSession(this.headers, this.attributes, null, null).getAttributes()) + .hasSize(1).containsEntry("foo", "bar"); + } + }