Skip to content

Commit

Permalink
Improve polling updater.stop() call when long polling (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulSonOfLars committed Feb 19, 2024
1 parent 954c160 commit e2bc46c
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 24 deletions.
12 changes: 7 additions & 5 deletions ext/botmapping.go
Expand Up @@ -163,9 +163,15 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit
w.WriteHeader(http.StatusNotFound)
return
}

b.updateWriterControl.Add(1)
defer b.updateWriterControl.Done()

if b.shouldStopUpdates() {
w.WriteHeader(http.StatusServiceUnavailable)
return
}

headerSecret := r.Header.Get("X-Telegram-Bot-Api-Secret-Token")
if b.webhookSecret != "" && b.webhookSecret != headerSecret {
// Drop any updates from invalid secret tokens.
Expand All @@ -184,10 +190,6 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit
return
}

if b.isUpdateChannelStopped() {
return
}

b.updateChan <- bytes
}
}
Expand All @@ -213,7 +215,7 @@ func (b *botData) stop() {
close(b.updateChan)
}

func (b *botData) isUpdateChannelStopped() bool {
func (b *botData) shouldStopUpdates() bool {
select {
case <-b.stopUpdates:
// if anything comes in on the closing channel, we know the channel is closed.
Expand Down
4 changes: 2 additions & 2 deletions ext/botmapping_test.go
Expand Up @@ -89,13 +89,13 @@ func Test_botData_isUpdateChannelStopped(t *testing.T) {
t.Errorf("bot with token %s should not have failed to be added", b.Token)
return
}
if bData.isUpdateChannelStopped() {
if bData.shouldStopUpdates() {
t.Errorf("bot with token %s should not be stopped yet", b.Token)
return
}

bData.stop()
if !bData.isUpdateChannelStopped() {
if !bData.shouldStopUpdates() {
t.Errorf("bot with token %s should be stopped", b.Token)
return
}
Expand Down
12 changes: 8 additions & 4 deletions ext/updater.go
Expand Up @@ -173,6 +173,11 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s
defer bData.updateWriterControl.Done()

for {
// Check if updater loop has been terminated.
if bData.shouldStopUpdates() {
return
}

// Manually craft the getUpdate calls to improve memory management, reduce json parsing overheads, and
// unnecessary reallocation of url.Values in the polling loop.
r, err := bData.bot.Request("getUpdates", v, nil, opts)
Expand Down Expand Up @@ -219,10 +224,6 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s

v["offset"] = strconv.FormatInt(lastUpdate.UpdateId+1, 10)

if bData.isUpdateChannelStopped() {
return
}

for _, updData := range rawUpdates {
temp := updData // use new mem address to avoid loop conflicts
bData.updateChan <- temp
Expand All @@ -240,6 +241,9 @@ func (u *Updater) Idle() {
}

// Stop stops the current updater and dispatcher instances.
//
// When using long polling, Stop() will wait for the getUpdates call to return, which may cause a delay due to the
// request timeout.
func (u *Updater) Stop() error {
// Stop any running servers.
if u.webhookServer != nil {
Expand Down
32 changes: 24 additions & 8 deletions ext/updater_test.go
Expand Up @@ -10,6 +10,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -98,8 +99,14 @@ func concurrentTest(t *testing.T) {
t.Parallel()

delay := time.Second
server := basicTestServer(t, map[string]testEndpoint{
"getUpdates": {delay: delay, reply: `{"ok": true, "result": [{"message": {"text": "stop"}}]}`},
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {
delay: delay,
replies: []string{
`{"ok": true, "result": [{"message": {"text": "stop"}}]}`,
},
reply: `{"ok": true, "result": []}`,
},
"deleteWebhook": {reply: `{"ok": true, "result": true}`},
})
defer server.Close()
Expand Down Expand Up @@ -290,7 +297,7 @@ func TestUpdater_GetHandlerFunc(t *testing.T) {
}

func TestUpdaterAllowsWebhookDeletion(t *testing.T) {
server := basicTestServer(t, map[string]testEndpoint{
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {reply: `{"ok": true}`},
"deleteWebhook": {reply: `{"ok": true, "result": true}`},
})
Expand Down Expand Up @@ -329,7 +336,7 @@ func TestUpdaterAllowsWebhookDeletion(t *testing.T) {
}

func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
server := basicTestServer(t, map[string]testEndpoint{
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {reply: `{"ok": true, "result": []}`},
})
defer server.Close()
Expand Down Expand Up @@ -384,7 +391,7 @@ func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
}

func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) {
server := basicTestServer(t, map[string]testEndpoint{
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {reply: `{"ok": true, "result": []}`},
})
defer server.Close()
Expand Down Expand Up @@ -432,7 +439,7 @@ func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) {
}

func TestUpdaterSupportsLongPollReAdding(t *testing.T) {
server := basicTestServer(t, map[string]testEndpoint{
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {reply: `{"ok": true, "result": []}`},
})
defer server.Close()
Expand Down Expand Up @@ -484,10 +491,14 @@ func TestUpdaterSupportsLongPollReAdding(t *testing.T) {

type testEndpoint struct {
delay time.Duration
// Will reply these until we run out of replies, at which point we repeat "reply"
replies []string
idx atomic.Int32
// default reply
reply string
}

func basicTestServer(t *testing.T, methods map[string]testEndpoint) *httptest.Server {
func basicTestServer(t *testing.T, methods map[string]*testEndpoint) *httptest.Server {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pathItems := strings.Split(r.URL.Path, "/")
lastItem := pathItems[len(pathItems)-1]
Expand All @@ -498,7 +509,12 @@ func basicTestServer(t *testing.T, methods map[string]testEndpoint) *httptest.Se
if out.delay != 0 {
time.Sleep(out.delay)
}
fmt.Fprint(w, out.reply)
count := int(out.idx.Add(1) - 1)
if len(out.replies) != 0 && len(out.replies) > count {
fmt.Fprint(w, out.replies[count])
} else {
fmt.Fprint(w, out.reply)
}
return
}

Expand Down
17 changes: 12 additions & 5 deletions samples/echoMultiBot/main.go
Expand Up @@ -86,7 +86,10 @@ func main() {

// If we get here, the updater.Idle() has ended.
// This means that updater.Stop() has been called, stopping all bots gracefully.
log.Println("Updater is no longer idling; all bots have been stopped gracefully.")
log.Println("Updater is no longer idling; all bots have been stopped gracefully. Exiting in 1s.")

// We sleep one last second to allow for the "stopall" goroutine to send the shutdown message.
time.Sleep(time.Second)
}

// startLongPollingBots demonstrates how to start multiple bots with long-polling.
Expand Down Expand Up @@ -159,11 +162,14 @@ func stop(b *gotgbot.Bot, ctx *ext.Context, updater *ext.Updater) error {
return fmt.Errorf("failed to echo message: %w", err)
}

if !updater.StopBot(b.Token) {
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Unable to find bot %d; was it already stopped?", b.Id), nil)
return nil
}
go func() {
if !updater.StopBot(b.Token) {
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Unable to find bot %d; was it already stopped?", b.Id), nil)
return
}

ctx.EffectiveMessage.Reply(b, "Stopped @"+b.Username, nil)
}()
return nil
}

Expand All @@ -181,6 +187,7 @@ func stopAll(b *gotgbot.Bot, ctx *ext.Context, updater *ext.Updater) error {
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Failed to stop updater: %s", err.Error()), nil)
return
}
ctx.EffectiveMessage.Reply(b, "All bots have been stopped.", nil)
}()

return nil
Expand Down

0 comments on commit e2bc46c

Please sign in to comment.