Skip to content

Commit

Permalink
download/download.go: fix switch between long polling and regular pol…
Browse files Browse the repository at this point in the history
…ling

download/download_test.go: testing about oneShot and download method

When 304 content-type should not be send. Fix the download code for switch to regular polling and keep the previous value.

Fixes: open-policy-agent#3923
Signed-off-by: Gasc Florian <florian.gasc@gmail.com>
  • Loading branch information
floriangasc authored and floriangascsimplifia committed Nov 22, 2021
1 parent 52e82e5 commit c265009
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 47 deletions.
41 changes: 19 additions & 22 deletions download/download.go
Expand Up @@ -63,6 +63,7 @@ type Downloader struct {
mtx sync.Mutex
stopped bool
persist bool
longPollingEnabled bool
handleHttpResponseOk func(response *http.Response, m metrics.Metrics) (*bundle.Bundle, *bytes.Buffer, error)
}

Expand All @@ -76,12 +77,13 @@ type downloaderResponse struct {
// New returns a new Downloader that can be started.
func New(config Config, client rest.Client, path string) *Downloader {
d := &Downloader{
config: config,
client: client,
path: path,
trigger: make(chan chan struct{}),
stop: make(chan chan struct{}),
logger: client.Logger(),
config: config,
client: client,
path: path,
trigger: make(chan chan struct{}),
stop: make(chan chan struct{}),
logger: client.Logger(),
longPollingEnabled: config.Polling.LongPollingTimeoutSeconds != nil,
}
d.handleHttpResponseOk = d.defaultHandleHttpResponseOk
return d
Expand Down Expand Up @@ -139,7 +141,7 @@ func (d *Downloader) Trigger(ctx context.Context) error {
done := make(chan error)

go func() {
_, err := d.oneShot(ctx)
err := d.oneShot(ctx)
if err != nil {
d.logger.Error("Bundle download failed: %v.", err)
if ctx.Err() == nil {
Expand Down Expand Up @@ -205,7 +207,7 @@ func (d *Downloader) loop(ctx context.Context) {

var delay time.Duration

longPoll, err := d.oneShot(ctx)
err := d.oneShot(ctx)

if ctx.Err() != nil {
return
Expand All @@ -214,16 +216,11 @@ func (d *Downloader) loop(ctx context.Context) {
if err != nil {
delay = util.DefaultBackoff(float64(minRetryDelay), float64(*d.config.Polling.MaxDelaySeconds), retry)
} else {
if !longPoll {
if d.config.Polling.LongPollingTimeoutSeconds != nil {
d.config.Polling.LongPollingTimeoutSeconds = nil
}

if !d.longPollingEnabled {
// revert the response header timeout value on the http client's transport
if *d.client.Config().ResponseHeaderTimeoutSeconds == 0 {
d.client = d.client.SetResponseHeaderTimeout(&d.respHdrTimeoutSec)
}

min := float64(*d.config.Polling.MinDelaySeconds)
max := float64(*d.config.Polling.MaxDelaySeconds)
delay = time.Duration(((max - min) * rand.Float64()) + min)
Expand All @@ -245,7 +242,7 @@ func (d *Downloader) loop(ctx context.Context) {
}
}

func (d *Downloader) oneShot(ctx context.Context) (bool, error) {
func (d *Downloader) oneShot(ctx context.Context) error {
m := metrics.New()
resp, err := d.download(ctx, m)

Expand All @@ -255,25 +252,23 @@ func (d *Downloader) oneShot(ctx context.Context) (bool, error) {
if d.f != nil {
d.f(ctx, Update{ETag: "", Bundle: nil, Error: err, Metrics: m, Raw: nil})
}

return false, err
return err
}

d.etag = resp.etag
d.longPollingEnabled = resp.longPoll

if d.f != nil {
d.f(ctx, Update{ETag: resp.etag, Bundle: resp.b, Error: nil, Metrics: m, Raw: resp.raw})
}

return resp.longPoll, nil
return nil
}

func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*downloaderResponse, error) {
d.logger.Debug("Download starting.")

d.client = d.client.WithHeader("If-None-Match", d.etag)

if d.config.Polling.LongPollingTimeoutSeconds != nil {
if d.longPollingEnabled && d.config.Polling.LongPollingTimeoutSeconds != nil {
d.client = d.client.WithHeader("Prefer", fmt.Sprintf("wait=%s", strconv.FormatInt(*d.config.Polling.LongPollingTimeoutSeconds, 10)))

// fetch existing response header timeout value on the http client's transport and
Expand All @@ -284,6 +279,8 @@ func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*download
t := int64(0)
d.client = d.client.SetResponseHeaderTimeout(&t)
}
} else {
d.client = d.client.WithHeader("Prefer", "wait=0")
}

resp, err := d.client.Do(ctx, "GET", d.path)
Expand Down Expand Up @@ -311,7 +308,7 @@ func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*download
b: nil,
raw: nil,
etag: etag,
longPoll: isLongPollSupported(resp.Header),
longPoll: d.longPollingEnabled,
}, nil
case http.StatusNotFound:
return nil, fmt.Errorf("server replied with not found")
Expand Down

0 comments on commit c265009

Please sign in to comment.