Skip to content

Commit

Permalink
download/download.go: fix switch between long polling and regular pol…
Browse files Browse the repository at this point in the history
…ling

download/download_test.go: testing about oneShot and download method

When 304 content-type should not be send. Fix the download code for switch to regular polling and keep the previous value.

Fixes: open-policy-agent#3923
Signed-off-by: Gasc Florian <florian.gasc@gmail.com>
  • Loading branch information
floriangasc committed Nov 21, 2021
1 parent 2b7df1c commit afcb2a8
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 62 deletions.
71 changes: 34 additions & 37 deletions download/download.go
Expand Up @@ -48,21 +48,22 @@ type Update struct {
// updates from the remote HTTP endpoint that the client is configured to
// connect to.
type Downloader struct {
config Config // downloader configuration for tuning polling and other downloader behaviour
client rest.Client // HTTP client to use for bundle downloading
path string // path to use in bundle download request
trigger chan chan struct{} // channel to signal downloads when manual triggering is enabled
stop chan chan struct{} // used to signal plugin to stop running
f func(context.Context, Update) // callback function invoked when download updates occur
etag string // HTTP Etag for caching purposes
sizeLimitBytes *int64 // max bundle file size in bytes (passed to reader)
bvc *bundle.VerificationConfig
respHdrTimeoutSec int64
wg sync.WaitGroup
logger logging.Logger
mtx sync.Mutex
stopped bool
persist bool
config Config // downloader configuration for tuning polling and other downloader behaviour
client rest.Client // HTTP client to use for bundle downloading
path string // path to use in bundle download request
trigger chan chan struct{} // channel to signal downloads when manual triggering is enabled
stop chan chan struct{} // used to signal plugin to stop running
f func(context.Context, Update) // callback function invoked when download updates occur
etag string // HTTP Etag for caching purposes
sizeLimitBytes *int64 // max bundle file size in bytes (passed to reader)
bvc *bundle.VerificationConfig
respHdrTimeoutSec int64
wg sync.WaitGroup
logger logging.Logger
mtx sync.Mutex
stopped bool
persist bool
longPollingEnabled bool
}

type downloaderResponse struct {
Expand All @@ -75,12 +76,13 @@ type downloaderResponse struct {
// New returns a new Downloader that can be started.
func New(config Config, client rest.Client, path string) *Downloader {
return &Downloader{
config: config,
client: client,
path: path,
trigger: make(chan chan struct{}),
stop: make(chan chan struct{}),
logger: client.Logger(),
config: config,
client: client,
path: path,
trigger: make(chan chan struct{}),
stop: make(chan chan struct{}),
logger: client.Logger(),
longPollingEnabled: config.Polling.LongPollingTimeoutSeconds != nil,
}
}

Expand Down Expand Up @@ -131,7 +133,7 @@ func (d *Downloader) Trigger(ctx context.Context) error {
done := make(chan error)

go func() {
_, err := d.oneShot(ctx)
err := d.oneShot(ctx)
if err != nil {
d.logger.Error("Bundle download failed: %v.", err)
if ctx.Err() == nil {
Expand Down Expand Up @@ -197,7 +199,7 @@ func (d *Downloader) loop(ctx context.Context) {

var delay time.Duration

longPoll, err := d.oneShot(ctx)
err := d.oneShot(ctx)

if ctx.Err() != nil {
return
Expand All @@ -206,16 +208,11 @@ func (d *Downloader) loop(ctx context.Context) {
if err != nil {
delay = util.DefaultBackoff(float64(minRetryDelay), float64(*d.config.Polling.MaxDelaySeconds), retry)
} else {
if !longPoll {
if d.config.Polling.LongPollingTimeoutSeconds != nil {
d.config.Polling.LongPollingTimeoutSeconds = nil
}

if !d.longPollingEnabled {
// revert the response header timeout value on the http client's transport
if *d.client.Config().ResponseHeaderTimeoutSeconds == 0 {
d.client = d.client.SetResponseHeaderTimeout(&d.respHdrTimeoutSec)
}

min := float64(*d.config.Polling.MinDelaySeconds)
max := float64(*d.config.Polling.MaxDelaySeconds)
delay = time.Duration(((max - min) * rand.Float64()) + min)
Expand All @@ -237,7 +234,7 @@ func (d *Downloader) loop(ctx context.Context) {
}
}

func (d *Downloader) oneShot(ctx context.Context) (bool, error) {
func (d *Downloader) oneShot(ctx context.Context) error {
m := metrics.New()
resp, err := d.download(ctx, m)

Expand All @@ -247,25 +244,23 @@ func (d *Downloader) oneShot(ctx context.Context) (bool, error) {
if d.f != nil {
d.f(ctx, Update{ETag: "", Bundle: nil, Error: err, Metrics: m, Raw: nil})
}

return false, err
return err
}

d.etag = resp.etag
d.longPollingEnabled = resp.longPoll

if d.f != nil {
d.f(ctx, Update{ETag: resp.etag, Bundle: resp.b, Error: nil, Metrics: m, Raw: resp.raw})
}

return resp.longPoll, nil
return nil
}

func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*downloaderResponse, error) {
d.logger.Debug("Download starting.")

d.client = d.client.WithHeader("If-None-Match", d.etag)

if d.config.Polling.LongPollingTimeoutSeconds != nil {
if d.longPollingEnabled && d.config.Polling.LongPollingTimeoutSeconds != nil {
d.client = d.client.WithHeader("Prefer", fmt.Sprintf("wait=%s", strconv.FormatInt(*d.config.Polling.LongPollingTimeoutSeconds, 10)))

// fetch existing response header timeout value on the http client's transport and
Expand All @@ -276,6 +271,8 @@ func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*download
t := int64(0)
d.client = d.client.SetResponseHeaderTimeout(&t)
}
} else {
d.client = d.client.WithHeader("Prefer", "wait=0")
}

resp, err := d.client.Do(ctx, "GET", d.path)
Expand Down Expand Up @@ -335,7 +332,7 @@ func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*download
b: nil,
raw: nil,
etag: etag,
longPoll: isLongPollSupported(resp.Header),
longPoll: d.longPollingEnabled,
}, nil
case http.StatusNotFound:
return nil, fmt.Errorf("server replied with not found")
Expand Down

0 comments on commit afcb2a8

Please sign in to comment.