diff --git a/download/download.go b/download/download.go index 24e60ba6b9..7fe0c4db9f 100644 --- a/download/download.go +++ b/download/download.go @@ -63,6 +63,7 @@ type Downloader struct { mtx sync.Mutex stopped bool persist bool + longPollingEnabled bool handleHttpResponseOk func(response *http.Response, m metrics.Metrics) (*bundle.Bundle, *bytes.Buffer, error) } @@ -76,12 +77,13 @@ type downloaderResponse struct { // New returns a new Downloader that can be started. func New(config Config, client rest.Client, path string) *Downloader { d := &Downloader{ - config: config, - client: client, - path: path, - trigger: make(chan chan struct{}), - stop: make(chan chan struct{}), - logger: client.Logger(), + config: config, + client: client, + path: path, + trigger: make(chan chan struct{}), + stop: make(chan chan struct{}), + logger: client.Logger(), + longPollingEnabled: config.Polling.LongPollingTimeoutSeconds != nil, } d.handleHttpResponseOk = d.defaultHandleHttpResponseOk return d @@ -139,7 +141,7 @@ func (d *Downloader) Trigger(ctx context.Context) error { done := make(chan error) go func() { - _, err := d.oneShot(ctx) + err := d.oneShot(ctx) if err != nil { d.logger.Error("Bundle download failed: %v.", err) if ctx.Err() == nil { @@ -205,7 +207,7 @@ func (d *Downloader) loop(ctx context.Context) { var delay time.Duration - longPoll, err := d.oneShot(ctx) + err := d.oneShot(ctx) if ctx.Err() != nil { return @@ -214,16 +216,11 @@ func (d *Downloader) loop(ctx context.Context) { if err != nil { delay = util.DefaultBackoff(float64(minRetryDelay), float64(*d.config.Polling.MaxDelaySeconds), retry) } else { - if !longPoll { - if d.config.Polling.LongPollingTimeoutSeconds != nil { - d.config.Polling.LongPollingTimeoutSeconds = nil - } - + if !d.longPollingEnabled { // revert the response header timeout value on the http client's transport if *d.client.Config().ResponseHeaderTimeoutSeconds == 0 { d.client = d.client.SetResponseHeaderTimeout(&d.respHdrTimeoutSec) } - min := float64(*d.config.Polling.MinDelaySeconds) max := float64(*d.config.Polling.MaxDelaySeconds) delay = time.Duration(((max - min) * rand.Float64()) + min) @@ -245,7 +242,7 @@ func (d *Downloader) loop(ctx context.Context) { } } -func (d *Downloader) oneShot(ctx context.Context) (bool, error) { +func (d *Downloader) oneShot(ctx context.Context) error { m := metrics.New() resp, err := d.download(ctx, m) @@ -255,25 +252,23 @@ func (d *Downloader) oneShot(ctx context.Context) (bool, error) { if d.f != nil { d.f(ctx, Update{ETag: "", Bundle: nil, Error: err, Metrics: m, Raw: nil}) } - - return false, err + return err } d.etag = resp.etag + d.longPollingEnabled = resp.longPoll if d.f != nil { d.f(ctx, Update{ETag: resp.etag, Bundle: resp.b, Error: nil, Metrics: m, Raw: resp.raw}) } - - return resp.longPoll, nil + return nil } func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*downloaderResponse, error) { d.logger.Debug("Download starting.") d.client = d.client.WithHeader("If-None-Match", d.etag) - - if d.config.Polling.LongPollingTimeoutSeconds != nil { + if d.longPollingEnabled && d.config.Polling.LongPollingTimeoutSeconds != nil { d.client = d.client.WithHeader("Prefer", fmt.Sprintf("wait=%s", strconv.FormatInt(*d.config.Polling.LongPollingTimeoutSeconds, 10))) // fetch existing response header timeout value on the http client's transport and @@ -284,6 +279,8 @@ func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*download t := int64(0) d.client = d.client.SetResponseHeaderTimeout(&t) } + } else { + d.client = d.client.WithHeader("Prefer", "wait=0") } resp, err := d.client.Do(ctx, "GET", d.path) @@ -311,7 +308,7 @@ func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*download b: nil, raw: nil, etag: etag, - longPoll: isLongPollSupported(resp.Header), + longPoll: d.longPollingEnabled, }, nil case http.StatusNotFound: return nil, fmt.Errorf("server replied with not found") diff --git a/download/download_test.go b/download/download_test.go index cda52da5e8..e5db61b54a 100644 --- a/download/download_test.go +++ b/download/download_test.go @@ -2,6 +2,7 @@ // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. +//go:build slow // +build slow package download @@ -19,6 +20,7 @@ import ( "testing" "time" + "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/plugins" "github.com/open-policy-agent/opa/bundle" @@ -258,7 +260,7 @@ func TestEtagCachingLifecycle(t *testing.T) { // simulate downloader error on first bundle download fixture.server.expCode = 500 fixture.server.expEtag = "some etag value" - _, err := fixture.d.oneShot(ctx) + err := fixture.d.oneShot(ctx) if err == nil { t.Fatal("Expected error but got nil") } else if len(fixture.updates) != 1 { @@ -269,7 +271,7 @@ func TestEtagCachingLifecycle(t *testing.T) { // simulate successful bundle activation and check updated etag on the downloader fixture.server.expCode = 0 - _, err = fixture.d.oneShot(ctx) + err = fixture.d.oneShot(ctx) if err != nil { t.Fatal("Unexpected:", err) } else if len(fixture.updates) != 2 { @@ -280,7 +282,7 @@ func TestEtagCachingLifecycle(t *testing.T) { // simulate another successful bundle activation and check updated etag on the downloader fixture.server.expEtag = "some etag value - 2" - _, err = fixture.d.oneShot(ctx) + err = fixture.d.oneShot(ctx) if err != nil { t.Fatal("Unexpected:", err) } else if len(fixture.updates) != 3 { @@ -292,7 +294,7 @@ func TestEtagCachingLifecycle(t *testing.T) { // simulate bundle activation error and check etag is set from the last successful activation fixture.mockBundleActivationError = true fixture.server.expEtag = "some newer etag value - 3" - _, err = fixture.d.oneShot(ctx) + err = fixture.d.oneShot(ctx) if err != nil { t.Fatal("Unexpected:", err) } else if len(fixture.updates) != 4 { @@ -304,7 +306,7 @@ func TestEtagCachingLifecycle(t *testing.T) { // simulate successful bundle activation and check updated etag on the downloader fixture.server.expCode = 0 fixture.mockBundleActivationError = false - _, err = fixture.d.oneShot(ctx) + err = fixture.d.oneShot(ctx) if err != nil { t.Fatal("Unexpected:", err) } else if len(fixture.updates) != 5 { @@ -315,7 +317,7 @@ func TestEtagCachingLifecycle(t *testing.T) { // simulate downloader error and check etag is set from the last successful activation fixture.server.expCode = 500 - _, err = fixture.d.oneShot(ctx) + err = fixture.d.oneShot(ctx) if err == nil { t.Fatal("Expected error but got nil") } else if len(fixture.updates) != 6 { @@ -327,7 +329,7 @@ func TestEtagCachingLifecycle(t *testing.T) { // simulate bundle activation error and check etag is set from the last successful activation fixture.mockBundleActivationError = true fixture.server.expCode = 0 - _, err = fixture.d.oneShot(ctx) + err = fixture.d.oneShot(ctx) if err != nil { t.Fatal("Unexpected:", err) } else if len(fixture.updates) != 7 { @@ -346,7 +348,7 @@ func TestFailureAuthn(t *testing.T) { d := New(Config{}, fixture.client, "/bundles/test/bundle1") - _, err := d.oneShot(ctx) + err := d.oneShot(ctx) if err == nil { t.Fatal("expected error") } @@ -361,7 +363,7 @@ func TestFailureNotFound(t *testing.T) { d := New(Config{}, fixture.client, "/bundles/test/non-existent") - _, err := d.oneShot(ctx) + err := d.oneShot(ctx) if err == nil { t.Fatal("expected error") } @@ -376,7 +378,7 @@ func TestFailureUnexpected(t *testing.T) { d := New(Config{}, fixture.client, "/bundles/test/bundle1") - _, err := d.oneShot(ctx) + err := d.oneShot(ctx) if err == nil { t.Fatal("expected error") } @@ -395,7 +397,7 @@ func TestEtagInResponse(t *testing.T) { fixture.server.expEtag = "some etag value" - _, err := fixture.d.oneShot(ctx) + err := fixture.d.oneShot(ctx) if err != nil { t.Fatal("Unexpected:", err) } else if len(fixture.updates) != 1 { @@ -409,7 +411,7 @@ func TestEtagInResponse(t *testing.T) { t.Errorf("Expected bundle in response") } - _, err = fixture.d.oneShot(ctx) + err = fixture.d.oneShot(ctx) if err != nil { t.Fatal("Unexpected:", err) } else if len(fixture.updates) != 2 { @@ -518,6 +520,99 @@ func TestTriggerManualWithTimeout(t *testing.T) { d.Stop(ctx) } +func TestDownloadLongPollNotModifiedOn304(t *testing.T) { + + ctx := context.Background() + config := Config{} + timeout := int64(3) // this will result in the test server sleeping for 3 seconds + config.Polling.LongPollingTimeoutSeconds = &timeout + + if err := config.ValidateAndInjectDefaults(); err != nil { + t.Fatal(err) + } + + fixture := newTestFixture(t) + fixture.d = New(config, fixture.client, "/bundles/test/bundle1").WithCallback(fixture.oneShot) + fixture.server.longPoll = true + fixture.server.expEtag = "foo" + fixture.d.etag = fixture.server.expEtag // not modified + fixture.server.expCode = 0 + defer fixture.server.stop() + + resp, err := fixture.d.download(ctx, metrics.New()) + if err != nil { + t.Fatal("Unexpected:", err) + } + if resp.longPoll != fixture.d.longPollingEnabled { + t.Fatalf("Expected same value for longPoll and longPollingEnabled") + } + +} + +func TestOneShotLongPollingSwitch(t *testing.T) { + ctx := context.Background() + config := Config{} + timeout := int64(3) // this will result in the test server sleeping for 3 seconds + config.Polling.LongPollingTimeoutSeconds = &timeout + if err := config.ValidateAndInjectDefaults(); err != nil { + t.Fatal(err) + } + fixture := newTestFixture(t) + fixture.d = New(config, fixture.client, "/bundles/test/bundle1").WithCallback(fixture.oneShot) + fixture.server.expCode = 0 + defer fixture.server.stop() + + fixture.server.longPoll = true + err := fixture.d.oneShot(ctx) + if err != nil { + t.Fatal("Unexpected:", err) + } + if fixture.d.longPollingEnabled != fixture.server.longPoll { + t.Fatalf("Expected same value for longPoll and longPollingEnabled") + } + + fixture.server.longPoll = false + err = fixture.d.oneShot(ctx) + if err != nil { + t.Fatal("Unexpected:", err) + } + if fixture.d.longPollingEnabled != fixture.server.longPoll { + t.Fatalf("Expected same value for longPollingEnabled and longPoll") + } +} + +func TestOneShotNotLongPollingSwitch(t *testing.T) { + ctx := context.Background() + config := Config{} + config.Polling.LongPollingTimeoutSeconds = nil + if err := config.ValidateAndInjectDefaults(); err != nil { + t.Fatal(err) + } + fixture := newTestFixture(t) + fixture.d = New(config, fixture.client, "/bundles/test/bundle1").WithCallback(fixture.oneShot) + fixture.server.expCode = 0 + + defer fixture.server.stop() + + fixture.server.longPoll = true + err := fixture.d.oneShot(ctx) + if err != nil { + t.Fatal("Unexpected:", err) + } + if fixture.d.longPollingEnabled != true { + t.Fatal("Expected long polling to be enabled") + } + + fixture.server.longPoll = false + err = fixture.d.oneShot(ctx) + if err != nil { + t.Fatal("Unexpected:", err) + } + if fixture.d.longPollingEnabled { + t.Fatal("Expected long polling to be disabled") + } +} + type testFixture struct { d *Downloader client rest.Client @@ -601,14 +696,15 @@ func (t *testFixture) oneShot(ctx context.Context, u Update) { } type testServer struct { - t *testing.T - expCode int - expEtag string - expAuth string - bundles map[string]bundle.Bundle - server *httptest.Server - etagInResponse bool - longPoll bool + t *testing.T + expCode int + expEtag string + expAuth string + bundles map[string]bundle.Bundle + server *httptest.Server + etagInResponse bool + longPoll bool + opaVendorMediaTypeEnabled bool } func (t *testServer) handle(w http.ResponseWriter, r *http.Request) { @@ -624,13 +720,8 @@ func (t *testServer) handle(w http.ResponseWriter, r *http.Request) { panic(err) } - // indicate server supports long polling - w.Header().Add("Content-Type", "application/vnd.openpolicyagent.bundles") - // simulate long operation time.Sleep(time.Duration(timeout) * time.Second) - } else { - w.Header().Add("Content-Type", "application/gzip") } if t.expCode != 0 { @@ -652,9 +743,11 @@ func (t *testServer) handle(w http.ResponseWriter, r *http.Request) { return } + contentTypeShouldBeSend := true if t.expEtag != "" { etag := r.Header.Get("If-None-Match") if etag == t.expEtag { + contentTypeShouldBeSend = false if t.etagInResponse { w.Header().Add("Etag", t.expEtag) } @@ -663,6 +756,13 @@ func (t *testServer) handle(w http.ResponseWriter, r *http.Request) { } } + if t.longPoll && contentTypeShouldBeSend { + // in 304 Content-Type is not send according https://datatracker.ietf.org/doc/html/rfc7232#section-4.1 + w.Header().Add("Content-Type", "application/vnd.openpolicyagent.bundles") + } else { + w.Header().Add("Content-Type", "application/gzip") + } + if t.expEtag != "" { w.Header().Add("Etag", t.expEtag) }