Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated cherry pick of #99839: Cleanup portforward streams after their usage #100954

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/kubelet/server/portforward/httpstream.go
Expand Up @@ -163,6 +163,10 @@ func (h *httpStreamHandler) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()

if h.conn != nil {
pair := h.streamPairs[requestID]
h.conn.RemoveStreams(pair.dataStream, pair.errorStream)
}
delete(h.streamPairs, requestID)
}

Expand Down
21 changes: 21 additions & 0 deletions pkg/kubelet/server/portforward/httpstream_test.go
Expand Up @@ -92,11 +92,23 @@ func TestHTTPStreamReceived(t *testing.T) {
}
}

type fakeConn struct {
removeStreamsCalled bool
}

func (*fakeConn) CreateStream(headers http.Header) (httpstream.Stream, error) { return nil, nil }
func (*fakeConn) Close() error { return nil }
func (*fakeConn) CloseChan() <-chan bool { return nil }
func (*fakeConn) SetIdleTimeout(timeout time.Duration) {}
func (f *fakeConn) RemoveStreams(streams ...httpstream.Stream) { f.removeStreamsCalled = true }

func TestGetStreamPair(t *testing.T) {
timeout := make(chan time.Time)

conn := &fakeConn{}
h := &httpStreamHandler{
streamPairs: make(map[string]*httpStreamPair),
conn: conn,
}

// test adding a new entry
Expand Down Expand Up @@ -158,6 +170,11 @@ func TestGetStreamPair(t *testing.T) {
// make sure monitorStreamPair completed
<-monitorDone

if !conn.removeStreamsCalled {
t.Fatalf("connection remove stream not called")
}
conn.removeStreamsCalled = false

// make sure the pair was removed
if h.hasStreamPair("1") {
t.Fatal("expected removal of pair after both data and error streams received")
Expand All @@ -171,6 +188,7 @@ func TestGetStreamPair(t *testing.T) {
if p == nil {
t.Fatal("expected p not to be nil")
}

monitorDone = make(chan struct{})
go func() {
h.monitorStreamPair(p, timeout)
Expand All @@ -183,6 +201,9 @@ func TestGetStreamPair(t *testing.T) {
if h.hasStreamPair("2") {
t.Fatal("expected stream pair to be removed")
}
if !conn.removeStreamsCalled {
t.Fatalf("connection remove stream not called")
}
}

func TestRequestID(t *testing.T) {
Expand Down
Expand Up @@ -78,6 +78,8 @@ type Connection interface {
// SetIdleTimeout sets the amount of time the connection may remain idle before
// it is automatically closed.
SetIdleTimeout(timeout time.Duration)
// RemoveStreams can be used to remove a set of streams from the Connection.
RemoveStreams(streams ...Stream)
}

// Stream represents a bidirectional communications channel that is part of an
Expand Down
Expand Up @@ -31,7 +31,7 @@ import (
// streams.
type connection struct {
conn *spdystream.Connection
streams []httpstream.Stream
streams map[uint32]httpstream.Stream
streamLock sync.Mutex
newStreamHandler httpstream.NewStreamHandler
}
Expand Down Expand Up @@ -64,7 +64,11 @@ func NewServerConnection(conn net.Conn, newStreamHandler httpstream.NewStreamHan
// will be invoked when the server receives a newly created stream from the
// client.
func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
c := &connection{conn: conn, newStreamHandler: newStreamHandler}
c := &connection{
conn: conn,
newStreamHandler: newStreamHandler,
streams: make(map[uint32]httpstream.Stream),
}
go conn.Serve(c.newSpdyStream)
return c
}
Expand All @@ -81,7 +85,7 @@ func (c *connection) Close() error {
// calling Reset instead of Close ensures that all streams are fully torn down
s.Reset()
}
c.streams = make([]httpstream.Stream, 0)
c.streams = make(map[uint32]httpstream.Stream, 0)
c.streamLock.Unlock()

// now that all streams are fully torn down, it's safe to call close on the underlying connection,
Expand All @@ -90,6 +94,15 @@ func (c *connection) Close() error {
return c.conn.Close()
}

// RemoveStreams can be used to removes a set of streams from the Connection.
func (c *connection) RemoveStreams(streams ...httpstream.Stream) {
c.streamLock.Lock()
for _, stream := range streams {
delete(c.streams, stream.Identifier())
}
c.streamLock.Unlock()
}

// CreateStream creates a new stream with the specified headers and registers
// it with the connection.
func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) {
Expand All @@ -109,7 +122,7 @@ func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error
// it owns.
func (c *connection) registerStream(s httpstream.Stream) {
c.streamLock.Lock()
c.streams = append(c.streams, s)
c.streams[s.Identifier()] = s
c.streamLock.Unlock()
}

Expand Down
Expand Up @@ -178,3 +178,41 @@ func TestConnectionCloseIsImmediateThroughAProxy(t *testing.T) {
}
}
}

type fakeStream struct{ id uint32 }

func (*fakeStream) Read(p []byte) (int, error) { return 0, nil }
func (*fakeStream) Write(p []byte) (int, error) { return 0, nil }
func (*fakeStream) Close() error { return nil }
func (*fakeStream) Reset() error { return nil }
func (*fakeStream) Headers() http.Header { return nil }
func (f *fakeStream) Identifier() uint32 { return f.id }

func TestConnectionRemoveStreams(t *testing.T) {
c := &connection{streams: make(map[uint32]httpstream.Stream)}
stream0 := &fakeStream{id: 0}
stream1 := &fakeStream{id: 1}
stream2 := &fakeStream{id: 2}

c.registerStream(stream0)
c.registerStream(stream1)

if len(c.streams) != 2 {
t.Fatalf("should have two streams, has %d", len(c.streams))
}

// not exists
c.RemoveStreams(stream2)

if len(c.streams) != 2 {
t.Fatalf("should have two streams, has %d", len(c.streams))
}

// remove all existing
c.RemoveStreams(stream0, stream1)

if len(c.streams) != 0 {
t.Fatalf("should not have any streams, has %d", len(c.streams))
}

}
Expand Up @@ -68,6 +68,9 @@ func (c *fakeConnection) CloseChan() <-chan bool {
return c.closeChan
}

func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) {
}

func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
// no-op
}
Expand Down