Skip to content

Commit

Permalink
feat: deprecate WatchOverallProgress and WatchProgress function
Browse files Browse the repository at this point in the history
  • Loading branch information
jooola committed Apr 30, 2024
1 parent 5131c8d commit 704af1e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 128 deletions.
120 changes: 39 additions & 81 deletions hcloud/action_watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package hcloud
import (
"context"
"fmt"
"time"
)

// WatchOverallProgress watches several actions' progress until they complete
Expand All @@ -24,6 +23,8 @@ import (
//
// WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait
// until sending the next request.
//
// Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead.
func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Action) (<-chan int, <-chan error) {
errCh := make(chan error, len(actions))
progressCh := make(chan int)
Expand All @@ -32,66 +33,37 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti
defer close(errCh)
defer close(progressCh)

completedIDs := make([]int64, 0, len(actions))
watchIDs := make(map[int64]struct{}, len(actions))
for _, action := range actions {
watchIDs[action.ID] = struct{}{}
}

retries := 0
previousProgress := 0

for {
select {
case <-ctx.Done():
errCh <- ctx.Err()
return
case <-time.After(c.action.client.pollBackoffFunc(retries)):
retries++
}

opts := ActionListOpts{}
for watchID := range watchIDs {
opts.ID = append(opts.ID, watchID)
previousGlobalProgress := 0
progressByAction := make(map[int64]int, len(actions))
err := c.WaitForFunc(ctx, func(update *Action) error {
switch update.Status {
case ActionStatusRunning:
progressByAction[update.ID] = update.Progress
case ActionStatusSuccess:
progressByAction[update.ID] = 100
case ActionStatusError:
progressByAction[update.ID] = 100
errCh <- fmt.Errorf("action %d failed: %w", update.ID, update.Error())
}

as, err := c.AllWithOpts(ctx, opts)
if err != nil {
errCh <- err
return
}
if len(as) == 0 {
// No actions returned for the provided IDs, they do not exist in the API.
// We need to catch and fail early for this, otherwise the loop will continue
// indefinitely.
errCh <- fmt.Errorf("failed to wait for actions: remaining actions (%v) are not returned from API", opts.ID)
return
// Compute global progress
progressSum := 0
for _, value := range progressByAction {
progressSum += value
}
globalProgress := progressSum / len(actions)

progress := 0
for _, a := range as {
switch a.Status {
case ActionStatusRunning:
progress += a.Progress
case ActionStatusSuccess:
delete(watchIDs, a.ID)
completedIDs = append(completedIDs, a.ID)
case ActionStatusError:
delete(watchIDs, a.ID)
completedIDs = append(completedIDs, a.ID)
errCh <- fmt.Errorf("action %d failed: %w", a.ID, a.Error())
}
// Only send progress when it changed
if globalProgress != 0 && globalProgress != previousGlobalProgress {
sendProgress(progressCh, globalProgress)
previousGlobalProgress = globalProgress
}

progress += len(completedIDs) * 100
if progress != 0 && progress != previousProgress {
sendProgress(progressCh, progress/len(actions))
previousProgress = progress
}
return nil
}, actions...)

if len(watchIDs) == 0 {
return
}
if err != nil {
errCh <- err
}
}()

Expand All @@ -116,6 +88,8 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti
//
// WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until
// sending the next request.
//
// Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead.
func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error) {
errCh := make(chan error, 1)
progressCh := make(chan int)
Expand All @@ -124,38 +98,22 @@ func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-cha
defer close(errCh)
defer close(progressCh)

retries := 0

for {
select {
case <-ctx.Done():
errCh <- ctx.Err()
return
case <-time.After(c.action.client.pollBackoffFunc(retries)):
retries++
}

a, _, err := c.GetByID(ctx, action.ID)
if err != nil {
errCh <- err
return
}
if a == nil {
errCh <- fmt.Errorf("failed to wait for action %d: action not returned from API", action.ID)
return
}

switch a.Status {
err := c.WaitForFunc(ctx, func(update *Action) error {
switch update.Status {
case ActionStatusRunning:
sendProgress(progressCh, a.Progress)
sendProgress(progressCh, update.Progress)
case ActionStatusSuccess:
sendProgress(progressCh, 100)
errCh <- nil
return
case ActionStatusError:
errCh <- a.Error()
return
// Do not wrap the action error
return update.Error()
}

return nil
}, action)

if err != nil {
errCh <- err
}
}()

Expand Down
83 changes: 36 additions & 47 deletions hcloud/action_watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (
"errors"
"net/http"
"reflect"
"strings"
"testing"

"github.com/stretchr/testify/assert"

"github.com/hetznercloud/hcloud-go/v2/hcloud/schema"
)

Expand Down Expand Up @@ -127,7 +128,7 @@ func TestActionClientWatchOverallProgress(t *testing.T) {
t.Fatalf("expected hcloud.Error, but got: %#v", err)
}

expectedProgressUpdates := []int{50, 100}
expectedProgressUpdates := []int{25, 62, 100}
if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) {
t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates)
}
Expand Down Expand Up @@ -202,9 +203,7 @@ func TestActionClientWatchOverallProgressInvalidID(t *testing.T) {

err := errs[0]

if !strings.HasPrefix(err.Error(), "failed to wait for actions") {
t.Fatalf("expected failed to wait for actions error, but got: %#v", err)
}
assert.Equal(t, "actions not found: [1]", err.Error())

expectedProgressUpdates := []int{}
if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) {
Expand All @@ -218,39 +217,36 @@ func TestActionClientWatchProgress(t *testing.T) {

callCount := 0

env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) {
env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
switch callCount {
case 1:
_ = json.NewEncoder(w).Encode(schema.ActionGetResponse{
Action: schema.Action{
ID: 1,
Status: "running",
Progress: 50,
},
})
_, _ = w.Write([]byte(`{
"actions": [
{ "id": 1, "status": "running", "progress": 50 }
],
"meta": { "pagination": { "page": 1 }}
}`))
case 2:
w.WriteHeader(http.StatusConflict)
_ = json.NewEncoder(w).Encode(schema.ErrorResponse{
Error: schema.Error{
Code: string(ErrorCodeConflict),
Message: "conflict",
},
})
_, _ = w.Write([]byte(`{
"error": {
"code": "conflict",
"message": "conflict"
}
}`))
return
case 3:
_ = json.NewEncoder(w).Encode(schema.ActionGetResponse{
Action: schema.Action{
ID: 1,
Status: "error",
Progress: 100,
Error: &schema.ActionError{
Code: "action_failed",
Message: "action failed",
},
},
})
_, _ = w.Write([]byte(`{
"actions": [
{ "id": 1, "status": "error", "progress": 100, "error": {
"code": "action_failed",
"message": "action failed"
} }
],
"meta": { "pagination": { "page": 1 }}
}`))
default:
t.Errorf("unexpected number of calls to the test server: %v", callCount)
}
Expand Down Expand Up @@ -293,7 +289,7 @@ func TestActionClientWatchProgressError(t *testing.T) {
env := newTestEnv()
defer env.Teardown()

env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) {
env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnprocessableEntity)
_ = json.NewEncoder(w).Encode(schema.ErrorResponse{
Expand All @@ -304,7 +300,7 @@ func TestActionClientWatchProgressError(t *testing.T) {
})
})

action := &Action{ID: 1}
action := &Action{ID: 1, Status: ActionStatusRunning}
ctx := context.Background()
_, errCh := env.Client.Action.WatchProgress(ctx, action)
if err := <-errCh; err == nil {
Expand All @@ -318,26 +314,20 @@ func TestActionClientWatchProgressInvalidID(t *testing.T) {

callCount := 0

env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) {
env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
switch callCount {
case 1:
_ = json.NewEncoder(w).Encode(schema.ErrorResponse{
Error: schema.Error{
Code: string(ErrorCodeNotFound),
Message: "action with ID '1' not found",
Details: nil,
},
})
_, _ = w.Write([]byte(`{
"actions": [],
"meta": { "pagination": { "page": 1 }}
}`))
default:
t.Errorf("unexpected number of calls to the test server: %v", callCount)
}
})
action := &Action{
ID: 1,
}
action := &Action{ID: 1, Status: ActionStatusRunning}

ctx := context.Background()
progressCh, errCh := env.Client.Action.WatchProgress(ctx, action)
Expand All @@ -356,9 +346,8 @@ loop:
}
}

if !strings.HasPrefix(err.Error(), "failed to wait for action") {
t.Fatalf("expected failed to wait for action error, but got: %#v", err)
}
assert.Equal(t, "actions not found: [1]", err.Error())

if len(progressUpdates) != 0 {
t.Fatalf("unexpected progress updates: %v", progressUpdates)
}
Expand Down

0 comments on commit 704af1e

Please sign in to comment.