Skip to content

Commit

Permalink
fix: prevent goroutine leak and CPU spinning at websocket transport (#…
Browse files Browse the repository at this point in the history
…2209)

* Added goroutine leak test for chat example

* Improved chat example with proper concurrency

* Revert "fix: prevents goroutine leak at websocket transport (#2168)"

This reverts commit eef7bfa.

* Improved subscription channel usage

* Regenerated examples and codegen

* Add support for subscription keepalives in websocket client

* Update chat example test

* if else chain to switch

Signed-off-by: Steve Coffman <steve@khanacademy.org>

* Revert "Add support for subscription keepalives in websocket client"

This reverts commits 64b882c and 670cf22.

* Fixed chat example race condition

* Fixed chatroom#Messages type

Co-authored-by: Steve Coffman <steve@khanacademy.org>
  • Loading branch information
moofMonkey and StevenACoffman committed May 26, 2022
1 parent 5f5bfcb commit 6855b72
Show file tree
Hide file tree
Showing 7 changed files with 336 additions and 272 deletions.
90 changes: 58 additions & 32 deletions _examples/chat/chat_test.go
Original file line number Diff line number Diff line change
@@ -1,49 +1,75 @@
package chat

import (
"testing"
"time"

"fmt"
"github.com/99designs/gqlgen/client"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"runtime"
"sync"
"testing"
)

func TestChatSubscriptions(t *testing.T) {
c := client.New(handler.NewDefaultServer(NewExecutableSchema(New())))

sub := c.Websocket(`subscription @user(username:"vektah") { messageAdded(roomName:"#gophers") { text createdBy } }`)
defer sub.Close()

go func() {
var resp interface{}
time.Sleep(10 * time.Millisecond)
err := c.Post(`mutation {
a:post(text:"Hello!", roomName:"#gophers", username:"vektah") { id }
b:post(text:"Hello Vektah!", roomName:"#gophers", username:"andrey") { id }
c:post(text:"Whats up?", roomName:"#gophers", username:"vektah") { id }
}`, &resp)
assert.NoError(t, err)
}()

var msg struct {
resp struct {
MessageAdded struct {
Text string
CreatedBy string
const batchSize = 128
var wg sync.WaitGroup
for i := 0; i < batchSize*8; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
sub := c.Websocket(fmt.Sprintf(
`subscription @user(username:"vektah") { messageAdded(roomName:"#gophers%d") { text createdBy } }`,
i,
))
defer sub.Close()

var msg struct {
resp struct {
MessageAdded struct {
Text string
CreatedBy string
}
}
err error
}

msg.err = sub.Next(&msg.resp)
require.NoError(t, msg.err, "sub.Next")
require.Equal(t, "You've joined the room", msg.resp.MessageAdded.Text)
require.Equal(t, "system", msg.resp.MessageAdded.CreatedBy)

go func() {
var resp interface{}
err := c.Post(fmt.Sprintf(`mutation {
a:post(text:"Hello!", roomName:"#gophers%d", username:"vektah") { id }
b:post(text:"Hello Vektah!", roomName:"#gophers%d", username:"andrey") { id }
c:post(text:"Whats up?", roomName:"#gophers%d", username:"vektah") { id }
}`, i, i, i), &resp)
assert.NoError(t, err)
}()

msg.err = sub.Next(&msg.resp)
require.NoError(t, msg.err, "sub.Next")
require.Equal(t, "Hello!", msg.resp.MessageAdded.Text)
require.Equal(t, "vektah", msg.resp.MessageAdded.CreatedBy)

msg.err = sub.Next(&msg.resp)
require.NoError(t, msg.err, "sub.Next")
require.Equal(t, "Whats up?", msg.resp.MessageAdded.Text)
require.Equal(t, "vektah", msg.resp.MessageAdded.CreatedBy)
}(i)
// wait for goroutines to finish every N tests to not starve on CPU
if (i+1)%batchSize == 0 {
wg.Wait()
}
err error
}
wg.Wait()

msg.err = sub.Next(&msg.resp)
require.NoError(t, msg.err, "sub.Next")
require.Equal(t, "Hello!", msg.resp.MessageAdded.Text)
require.Equal(t, "vektah", msg.resp.MessageAdded.CreatedBy)

msg.err = sub.Next(&msg.resp)
require.NoError(t, msg.err, "sub.Next")
require.Equal(t, "Whats up?", msg.resp.MessageAdded.Text)
require.Equal(t, "vektah", msg.resp.MessageAdded.CreatedBy)
// 1 for the main thread, 1 for the testing package and remainder is reserved for the HTTP server threads
// TODO: use something like runtime.Stack to filter out HTTP server threads,
// TODO: which is required for proper concurrency and leaks testing
require.Less(t, runtime.NumGoroutine(), 1+1+batchSize*2, "goroutine leak")
}
22 changes: 13 additions & 9 deletions _examples/chat/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 39 additions & 65 deletions _examples/chat/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ import (
type ckey string

type resolver struct {
Rooms map[string]*Chatroom
mu sync.Mutex // nolint: structcheck
Rooms sync.Map
}

func (r *resolver) Mutation() MutationResolver {
Expand All @@ -33,7 +32,7 @@ func (r *resolver) Subscription() SubscriptionResolver {
func New() Config {
return Config{
Resolvers: &resolver{
Rooms: map[string]*Chatroom{},
Rooms: sync.Map{},
},
Directives: DirectiveRoot{
User: func(ctx context.Context, obj interface{}, next graphql.Resolver, username string) (res interface{}, err error) {
Expand All @@ -50,103 +49,78 @@ func getUsername(ctx context.Context) string {
return ""
}

type Observer struct {
Username string
Message chan *Message
}

type Chatroom struct {
Name string
Messages []Message
Observers map[string]struct {
Username string
Message chan *Message
}
Observers sync.Map
}

type mutationResolver struct{ *resolver }

func (r *mutationResolver) Post(ctx context.Context, text string, username string, roomName string) (*Message, error) {
r.mu.Lock()
room := r.Rooms[roomName]
if room == nil {
room = &Chatroom{
Name: roomName,
Observers: map[string]struct {
Username string
Message chan *Message
}{},
}
r.Rooms[roomName] = room
}
r.mu.Unlock()
room := r.getRoom(roomName)

message := Message{
message := &Message{
ID: randString(8),
CreatedAt: time.Now(),
Text: text,
CreatedBy: username,
}

room.Messages = append(room.Messages, message)
r.mu.Lock()
for _, observer := range room.Observers {
room.Messages = append(room.Messages, *message)
room.Observers.Range(func(_, v interface{}) bool {
observer := v.(*Observer)
if observer.Username == "" || observer.Username == message.CreatedBy {
observer.Message <- &message
observer.Message <- message
}
}
r.mu.Unlock()
return &message, nil
return true
})
return message, nil
}

type queryResolver struct{ *resolver }

func (r *queryResolver) Room(ctx context.Context, name string) (*Chatroom, error) {
r.mu.Lock()
room := r.Rooms[name]
if room == nil {
room = &Chatroom{
Name: name,
Observers: map[string]struct {
Username string
Message chan *Message
}{},
}
r.Rooms[name] = room
}
r.mu.Unlock()
func (r *resolver) getRoom(name string) *Chatroom {
room, _ := r.Rooms.LoadOrStore(name, &Chatroom{
Name: name,
Observers: sync.Map{},
})
return room.(*Chatroom)
}

return room, nil
func (r *queryResolver) Room(ctx context.Context, name string) (*Chatroom, error) {
return r.getRoom(name), nil
}

type subscriptionResolver struct{ *resolver }

func (r *subscriptionResolver) MessageAdded(ctx context.Context, roomName string) (<-chan *Message, error) {
r.mu.Lock()
room := r.Rooms[roomName]
if room == nil {
room = &Chatroom{
Name: roomName,
Observers: map[string]struct {
Username string
Message chan *Message
}{},
}
r.Rooms[roomName] = room
}
r.mu.Unlock()
room := r.getRoom(roomName)

id := randString(8)
events := make(chan *Message, 1)

go func() {
<-ctx.Done()
r.mu.Lock()
delete(room.Observers, id)
r.mu.Unlock()
room.Observers.Delete(id)
}()

r.mu.Lock()
room.Observers[id] = struct {
Username string
Message chan *Message
}{Username: getUsername(ctx), Message: events}
r.mu.Unlock()
room.Observers.Store(id, &Observer{
Username: getUsername(ctx),
Message: events,
})

events <- &Message{
ID: randString(8),
CreatedAt: time.Now(),
Text: "You've joined the room",
CreatedBy: "system",
}

return events, nil
}
Expand Down
22 changes: 13 additions & 9 deletions codegen/field.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,21 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex
}
{{- if $object.Stream }}
return func(ctx context.Context) graphql.Marshaler {
res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}})
if !ok {
select {
case res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}}):
if !ok {
return nil
}
return graphql.WriterFunc(func(w io.Writer) {
w.Write([]byte{'{'})
graphql.MarshalString(field.Alias).MarshalGQL(w)
w.Write([]byte{':'})
ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res).MarshalGQL(w)
w.Write([]byte{'}'})
})
case <-ctx.Done():
return nil
}
return graphql.WriterFunc(func(w io.Writer) {
w.Write([]byte{'{'})
graphql.MarshalString(field.Alias).MarshalGQL(w)
w.Write([]byte{':'})
ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res).MarshalGQL(w)
w.Write([]byte{'}'})
})
}
{{- else }}
res := resTmp.({{$field.TypeReference.GO | ref}})
Expand Down

0 comments on commit 6855b72

Please sign in to comment.