diff --git a/pkg/kubelet/server/portforward/httpstream.go b/pkg/kubelet/server/portforward/httpstream.go index 4b5e66d6607b..b9d58a7e403c 100644 --- a/pkg/kubelet/server/portforward/httpstream.go +++ b/pkg/kubelet/server/portforward/httpstream.go @@ -163,6 +163,10 @@ func (h *httpStreamHandler) removeStreamPair(requestID string) { h.streamPairsLock.Lock() defer h.streamPairsLock.Unlock() + if h.conn != nil { + pair := h.streamPairs[requestID] + h.conn.RemoveStreams(pair.dataStream, pair.errorStream) + } delete(h.streamPairs, requestID) } diff --git a/pkg/kubelet/server/portforward/httpstream_test.go b/pkg/kubelet/server/portforward/httpstream_test.go index 26e6905bbbc0..37e0ce8f9070 100644 --- a/pkg/kubelet/server/portforward/httpstream_test.go +++ b/pkg/kubelet/server/portforward/httpstream_test.go @@ -92,11 +92,23 @@ func TestHTTPStreamReceived(t *testing.T) { } } +type fakeConn struct { + removeStreamsCalled bool +} + +func (*fakeConn) CreateStream(headers http.Header) (httpstream.Stream, error) { return nil, nil } +func (*fakeConn) Close() error { return nil } +func (*fakeConn) CloseChan() <-chan bool { return nil } +func (*fakeConn) SetIdleTimeout(timeout time.Duration) {} +func (f *fakeConn) RemoveStreams(streams ...httpstream.Stream) { f.removeStreamsCalled = true } + func TestGetStreamPair(t *testing.T) { timeout := make(chan time.Time) + conn := &fakeConn{} h := &httpStreamHandler{ streamPairs: make(map[string]*httpStreamPair), + conn: conn, } // test adding a new entry @@ -158,6 +170,11 @@ func TestGetStreamPair(t *testing.T) { // make sure monitorStreamPair completed <-monitorDone + if !conn.removeStreamsCalled { + t.Fatalf("connection remove stream not called") + } + conn.removeStreamsCalled = false + // make sure the pair was removed if h.hasStreamPair("1") { t.Fatal("expected removal of pair after both data and error streams received") @@ -171,6 +188,7 @@ func TestGetStreamPair(t *testing.T) { if p == nil { t.Fatal("expected p not to be nil") } + monitorDone = make(chan struct{}) go func() { h.monitorStreamPair(p, timeout) @@ -183,6 +201,9 @@ func TestGetStreamPair(t *testing.T) { if h.hasStreamPair("2") { t.Fatal("expected stream pair to be removed") } + if !conn.removeStreamsCalled { + t.Fatalf("connection remove stream not called") + } } func TestRequestID(t *testing.T) { diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go index 00ce5f785c8b..32f075782a9a 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go @@ -78,6 +78,8 @@ type Connection interface { // SetIdleTimeout sets the amount of time the connection may remain idle before // it is automatically closed. SetIdleTimeout(timeout time.Duration) + // RemoveStreams can be used to remove a set of streams from the Connection. + RemoveStreams(streams ...Stream) } // Stream represents a bidirectional communications channel that is part of an diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/connection.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/connection.go index 9d222faa898f..b6903c527641 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/connection.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/connection.go @@ -31,7 +31,7 @@ import ( // streams. type connection struct { conn *spdystream.Connection - streams []httpstream.Stream + streams map[uint32]httpstream.Stream streamLock sync.Mutex newStreamHandler httpstream.NewStreamHandler } @@ -64,7 +64,11 @@ func NewServerConnection(conn net.Conn, newStreamHandler httpstream.NewStreamHan // will be invoked when the server receives a newly created stream from the // client. func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection { - c := &connection{conn: conn, newStreamHandler: newStreamHandler} + c := &connection{ + conn: conn, + newStreamHandler: newStreamHandler, + streams: make(map[uint32]httpstream.Stream), + } go conn.Serve(c.newSpdyStream) return c } @@ -81,7 +85,7 @@ func (c *connection) Close() error { // calling Reset instead of Close ensures that all streams are fully torn down s.Reset() } - c.streams = make([]httpstream.Stream, 0) + c.streams = make(map[uint32]httpstream.Stream, 0) c.streamLock.Unlock() // now that all streams are fully torn down, it's safe to call close on the underlying connection, @@ -90,6 +94,15 @@ func (c *connection) Close() error { return c.conn.Close() } +// RemoveStreams can be used to removes a set of streams from the Connection. +func (c *connection) RemoveStreams(streams ...httpstream.Stream) { + c.streamLock.Lock() + for _, stream := range streams { + delete(c.streams, stream.Identifier()) + } + c.streamLock.Unlock() +} + // CreateStream creates a new stream with the specified headers and registers // it with the connection. func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) { @@ -109,7 +122,7 @@ func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error // it owns. func (c *connection) registerStream(s httpstream.Stream) { c.streamLock.Lock() - c.streams = append(c.streams, s) + c.streams[s.Identifier()] = s c.streamLock.Unlock() } diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/connection_test.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/connection_test.go index e00b29c461e1..cfeef2c9075b 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/connection_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/connection_test.go @@ -178,3 +178,41 @@ func TestConnectionCloseIsImmediateThroughAProxy(t *testing.T) { } } } + +type fakeStream struct{ id uint32 } + +func (*fakeStream) Read(p []byte) (int, error) { return 0, nil } +func (*fakeStream) Write(p []byte) (int, error) { return 0, nil } +func (*fakeStream) Close() error { return nil } +func (*fakeStream) Reset() error { return nil } +func (*fakeStream) Headers() http.Header { return nil } +func (f *fakeStream) Identifier() uint32 { return f.id } + +func TestConnectionRemoveStreams(t *testing.T) { + c := &connection{streams: make(map[uint32]httpstream.Stream)} + stream0 := &fakeStream{id: 0} + stream1 := &fakeStream{id: 1} + stream2 := &fakeStream{id: 2} + + c.registerStream(stream0) + c.registerStream(stream1) + + if len(c.streams) != 2 { + t.Fatalf("should have two streams, has %d", len(c.streams)) + } + + // not exists + c.RemoveStreams(stream2) + + if len(c.streams) != 2 { + t.Fatalf("should have two streams, has %d", len(c.streams)) + } + + // remove all existing + c.RemoveStreams(stream0, stream1) + + if len(c.streams) != 0 { + t.Fatalf("should not have any streams, has %d", len(c.streams)) + } + +} diff --git a/staging/src/k8s.io/client-go/tools/portforward/portforward_test.go b/staging/src/k8s.io/client-go/tools/portforward/portforward_test.go index 5b9afabeaab1..034be748fe2a 100644 --- a/staging/src/k8s.io/client-go/tools/portforward/portforward_test.go +++ b/staging/src/k8s.io/client-go/tools/portforward/portforward_test.go @@ -68,6 +68,9 @@ func (c *fakeConnection) CloseChan() <-chan bool { return c.closeChan } +func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) { +} + func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) { // no-op }