diff --git a/backend/data.go b/backend/data.go index d747de17a..2421e6628 100644 --- a/backend/data.go +++ b/backend/data.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "net/http" "time" "github.com/grafana/grafana-plugin-sdk-go/data" @@ -37,9 +38,50 @@ func (fn QueryDataHandlerFunc) QueryData(ctx context.Context, req *QueryDataRequ // QueryDataRequest contains a single request which contains multiple queries. // It is the input type for a QueryData call. type QueryDataRequest struct { + // PluginContext the contextual information for the request. PluginContext PluginContext - Headers map[string]string - Queries []DataQuery + + // Headers the environment/metadata information for the request. + // + // To access forwarded HTTP headers please use + // GetHTTPHeaders or GetHTTPHeader. + Headers map[string]string + + // Queries the data queries for the request. + Queries []DataQuery +} + +// SetHTTPHeader sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +func (req *QueryDataRequest) SetHTTPHeader(key, value string) { + if req.Headers == nil { + req.Headers = map[string]string{} + } + + setHTTPHeaderInStringMap(req.Headers, key, value) +} + +// DeleteHTTPHeader deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (req *QueryDataRequest) DeleteHTTPHeader(key string) { + deleteHTTPHeaderInStringMap(req.Headers, key) +} + +// GetHTTPHeader gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. +func (req *QueryDataRequest) GetHTTPHeader(key string) string { + return req.GetHTTPHeaders().Get(key) +} + +// GetHTTPHeaders returns HTTP headers. +func (req *QueryDataRequest) GetHTTPHeaders() http.Header { + return getHTTPHeadersFromStringMap(req.Headers) } // DataQuery represents a single query as sent from the frontend. @@ -119,6 +161,7 @@ type DataResponse struct { Status Status } +// ErrDataResponse returns an error DataResponse given status and message. func ErrDataResponse(status Status, message string) DataResponse { return DataResponse{ Error: errors.New(message), @@ -149,3 +192,5 @@ type TimeRange struct { func (tr TimeRange) Duration() time.Duration { return tr.To.Sub(tr.From) } + +var _ ForwardHTTPHeaders = (*QueryDataRequest)(nil) diff --git a/backend/data_test.go b/backend/data_test.go new file mode 100644 index 000000000..140b26c19 --- /dev/null +++ b/backend/data_test.go @@ -0,0 +1,105 @@ +package backend + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestQueryDataRequest(t *testing.T) { + req := &QueryDataRequest{} + const customHeaderName = "X-Custom" + + t.Run("Legacy headers", func(t *testing.T) { + req.Headers = map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + customHeaderName: "d", + } + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Empty(t, headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Empty(t, req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader canonical form", func(t *testing.T) { + req.SetHTTPHeader(OAuthIdentityTokenHeaderName, "a") + req.SetHTTPHeader(OAuthIdentityIDTokenHeaderName, "b") + req.SetHTTPHeader(CookiesHeaderName, "c") + req.SetHTTPHeader(customHeaderName, "d") + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Equal(t, "d", headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Equal(t, "d", req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader non-canonical form", func(t *testing.T) { + req.SetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName), "a") + req.SetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName), "b") + req.SetHTTPHeader(strings.ToLower(CookiesHeaderName), "c") + req.SetHTTPHeader(strings.ToLower(customHeaderName), "d") + + t.Run("GetHTTPHeaders non-canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", headers.Get(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", headers.Get(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", headers.Get(strings.ToLower(customHeaderName))) + }) + + t.Run("GetHTTPHeader non-canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", req.GetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", req.GetHTTPHeader(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", req.GetHTTPHeader(strings.ToLower(customHeaderName))) + }) + + t.Run("DeleteHTTPHeader non-canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(CookiesHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(customHeaderName)) + require.Empty(t, req.Headers) + }) + }) +} diff --git a/backend/diagnostics.go b/backend/diagnostics.go index 5631aad12..49d48fe20 100644 --- a/backend/diagnostics.go +++ b/backend/diagnostics.go @@ -2,6 +2,7 @@ package backend import ( "context" + "net/http" "strconv" ) @@ -53,14 +54,58 @@ func (hs HealthStatus) String() string { // CheckHealthRequest contains the healthcheck request type CheckHealthRequest struct { + // PluginContext the contextual information for the request. PluginContext PluginContext - Headers map[string]string + + // Headers the environment/metadata information for the request. + // + // To access forwarded HTTP headers please use + // GetHTTPHeaders or GetHTTPHeader. + Headers map[string]string +} + +// SetHTTPHeader sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +func (req *CheckHealthRequest) SetHTTPHeader(key, value string) { + if req.Headers == nil { + req.Headers = map[string]string{} + } + + setHTTPHeaderInStringMap(req.Headers, key, value) +} + +// DeleteHTTPHeader deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (req *CheckHealthRequest) DeleteHTTPHeader(key string) { + deleteHTTPHeaderInStringMap(req.Headers, key) +} + +// GetHTTPHeader gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. +func (req *CheckHealthRequest) GetHTTPHeader(key string) string { + return req.GetHTTPHeaders().Get(key) +} + +// GetHTTPHeaders returns HTTP headers. +func (req *CheckHealthRequest) GetHTTPHeaders() http.Header { + return getHTTPHeadersFromStringMap(req.Headers) } // CheckHealthResult contains the healthcheck response type CheckHealthResult struct { - Status HealthStatus - Message string + // Status the HealthStatus of the healthcheck. + Status HealthStatus + + // Message the message of the healthcheck, if any. + Message string + + // JSONDetails the details of the healthcheck, if any, encoded as JSON bytes. JSONDetails []byte } @@ -82,10 +127,14 @@ func (fn CollectMetricsHandlerFunc) CollectMetrics(ctx context.Context, req *Col // CollectMetricsRequest contains the metrics request type CollectMetricsRequest struct { + // PluginContext the contextual information for the request. PluginContext PluginContext } // CollectMetricsResult collect metrics result. type CollectMetricsResult struct { + // PrometheusMetrics the Prometheus metrics encoded as bytes. PrometheusMetrics []byte } + +var _ ForwardHTTPHeaders = (*CheckHealthRequest)(nil) diff --git a/backend/diagnostics_test.go b/backend/diagnostics_test.go new file mode 100644 index 000000000..56b0057e3 --- /dev/null +++ b/backend/diagnostics_test.go @@ -0,0 +1,105 @@ +package backend + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckHealthRequest(t *testing.T) { + req := &CheckHealthRequest{} + const customHeaderName = "X-Custom" + + t.Run("Legacy headers", func(t *testing.T) { + req.Headers = map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + customHeaderName: "d", + } + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Empty(t, headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Empty(t, req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader canonical form", func(t *testing.T) { + req.SetHTTPHeader(OAuthIdentityTokenHeaderName, "a") + req.SetHTTPHeader(OAuthIdentityIDTokenHeaderName, "b") + req.SetHTTPHeader(CookiesHeaderName, "c") + req.SetHTTPHeader(customHeaderName, "d") + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Equal(t, "d", headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Equal(t, "d", req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader non-canonical form", func(t *testing.T) { + req.SetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName), "a") + req.SetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName), "b") + req.SetHTTPHeader(strings.ToLower(CookiesHeaderName), "c") + req.SetHTTPHeader(strings.ToLower(customHeaderName), "d") + + t.Run("GetHTTPHeaders non-canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", headers.Get(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", headers.Get(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", headers.Get(strings.ToLower(customHeaderName))) + }) + + t.Run("GetHTTPHeader non-canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", req.GetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", req.GetHTTPHeader(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", req.GetHTTPHeader(strings.ToLower(customHeaderName))) + }) + + t.Run("DeleteHTTPHeader non-canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(CookiesHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(customHeaderName)) + require.Empty(t, req.Headers) + }) + }) +} diff --git a/backend/http_headers.go b/backend/http_headers.go new file mode 100644 index 000000000..689d8ae59 --- /dev/null +++ b/backend/http_headers.go @@ -0,0 +1,91 @@ +package backend + +import ( + "fmt" + "net/http" + "net/textproto" + "strings" +) + +const ( + // OAuthIdentityTokenHeaderName the header name used for forwarding + // OAuth Identity access token. + OAuthIdentityTokenHeaderName = "Authorization" + + // OAuthIdentityIDTokenHeaderName the header name used for forwarding + // OAuth Identity ID token. + OAuthIdentityIDTokenHeaderName = "X-Id-Token" + + // CookiesHeaderName the header name used for forwarding + // cookies. + CookiesHeaderName = "Cookie" + + httpHeaderPrefix = "http_" +) + +// ForwardHTTPHeaders interface marking that forward of HTTP headers is supported. +type ForwardHTTPHeaders interface { + // SetHTTPHeader sets the header entries associated with key to the + // single element value. It replaces any existing values + // associated with key. The key is case insensitive; it is + // canonicalized by textproto.CanonicalMIMEHeaderKey. + SetHTTPHeader(key, value string) + + // DeleteHTTPHeader deletes the values associated with key. + // The key is case insensitive; it is canonicalized by + // CanonicalHeaderKey. + DeleteHTTPHeader(key string) + + // GetHTTPHeader gets the first value associated with the given key. If + // there are no values associated with the key, Get returns "". + // It is case insensitive; textproto.CanonicalMIMEHeaderKey is + // used to canonicalize the provided key. Get assumes that all + // keys are stored in canonical form. + GetHTTPHeader(key string) string + + // GetHTTPHeaders returns HTTP headers. + GetHTTPHeaders() http.Header +} + +func setHTTPHeaderInStringMap(headers map[string]string, key string, value string) { + if headers == nil { + headers = map[string]string{} + } + + headers[fmt.Sprintf("%s%s", httpHeaderPrefix, key)] = value +} + +func getHTTPHeadersFromStringMap(headers map[string]string) http.Header { + httpHeaders := http.Header{} + + for k, v := range headers { + if textproto.CanonicalMIMEHeaderKey(k) == OAuthIdentityTokenHeaderName { + httpHeaders.Set(k, v) + } + + if textproto.CanonicalMIMEHeaderKey(k) == OAuthIdentityIDTokenHeaderName { + httpHeaders.Set(k, v) + } + + if textproto.CanonicalMIMEHeaderKey(k) == CookiesHeaderName { + httpHeaders.Set(k, v) + } + + if strings.HasPrefix(k, httpHeaderPrefix) { + hKey := strings.TrimPrefix(k, httpHeaderPrefix) + httpHeaders.Set(hKey, v) + } + } + + return httpHeaders +} + +func deleteHTTPHeaderInStringMap(headers map[string]string, key string) { + for k := range headers { + if textproto.CanonicalMIMEHeaderKey(k) == textproto.CanonicalMIMEHeaderKey(key) || + textproto.CanonicalMIMEHeaderKey(k) == textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", httpHeaderPrefix, key)) { + delete(headers, k) + break + } + } +} diff --git a/backend/http_headers_test.go b/backend/http_headers_test.go new file mode 100644 index 000000000..7b4070a6c --- /dev/null +++ b/backend/http_headers_test.go @@ -0,0 +1,210 @@ +package backend + +import ( + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/require" +) + +func TestSetHTTPHeaderInStringMap(t *testing.T) { + tcs := []struct { + input map[string]string + expected map[string]string + }{ + { + expected: map[string]string{ + "": "", + "a": "", + }, + }, + { + input: map[string]string{ + "authorization": "a", + "x-id-token": "b", + "cookie": "c", + "x-custom": "d", + }, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "a", + "Authorization": "a", + "x-id-token": "b", + "X-Id-Token": "b", + "cookie": "c", + "Cookie": "c", + "x-custom": "d", + "X-Custom": "d", + }, + }, + { + input: map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + "X-Custom": "d", + }, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "a", + "Authorization": "a", + "x-id-token": "b", + "X-Id-Token": "b", + "cookie": "c", + "Cookie": "c", + "x-custom": "d", + "X-Custom": "d", + }, + }, + } + + for _, tc := range tcs { + headerMap := map[string]string{} + for k, v := range tc.input { + setHTTPHeaderInStringMap(headerMap, k, v) + } + headers := getHTTPHeadersFromStringMap(headerMap) + spew.Dump(headers) + + for k, v := range tc.expected { + require.Equal(t, v, headers.Get(k)) + } + } +} + +func TestGetHTTPHeadersFromStringMap(t *testing.T) { + tcs := []struct { + input map[string]string + expected map[string]string + }{ + { + expected: map[string]string{ + "": "", + "a": "", + }, + }, + { + input: map[string]string{ + "authorization": "a", + "x-id-token": "b", + "cookie": "c", + httpHeaderPrefix + "x-custom": "d", + }, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "a", + "Authorization": "a", + "x-id-token": "b", + "X-Id-Token": "b", + "cookie": "c", + "Cookie": "c", + "x-custom": "d", + "X-Custom": "d", + }, + }, + { + input: map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + httpHeaderPrefix + "X-Custom": "d", + }, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "a", + "Authorization": "a", + "x-id-token": "b", + "X-Id-Token": "b", + "cookie": "c", + "Cookie": "c", + "x-custom": "d", + "X-Custom": "d", + }, + }, + } + + for _, tc := range tcs { + headers := getHTTPHeadersFromStringMap(tc.input) + + for k, v := range tc.expected { + require.Equal(t, v, headers.Get(k)) + } + } +} + +func TestDeleteHTTPHeaderInStringMap(t *testing.T) { + tcs := []struct { + input map[string]string + deleteKeys []string + expected map[string]string + }{ + { + expected: map[string]string{ + "": "", + "a": "", + }, + }, + { + input: map[string]string{ + "authorization": "a", + "x-id-token": "b", + "cookie": "c", + httpHeaderPrefix + "x-custom": "d", + }, + deleteKeys: []string{"authorization", "x-id-token", "cookie", "x-custom"}, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "", + "Authorization": "", + "x-id-token": "", + "X-Id-Token": "", + "cookie": "", + "Cookie": "", + "x-custom": "", + "X-Custom": "", + }, + }, + { + input: map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + httpHeaderPrefix + "X-Custom": "d", + }, + deleteKeys: []string{"Authorization", "X-Id-Token", "Cookie", "X-Custom"}, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "", + "Authorization": "", + "x-id-token": "", + "X-Id-Token": "", + "cookie": "", + "Cookie": "", + "x-custom": "", + "X-Custom": "", + }, + }, + } + + for _, tc := range tcs { + headerMap := make(map[string]string, len(tc.input)) + for k, v := range tc.input { + headerMap[k] = v + } + + for _, key := range tc.deleteKeys { + deleteHTTPHeaderInStringMap(headerMap, key) + } + headers := getHTTPHeadersFromStringMap(headerMap) + + for k, v := range tc.expected { + require.Equal(t, v, headers.Get(k)) + } + } +} diff --git a/backend/resource.go b/backend/resource.go index 248f91670..83198dcd5 100644 --- a/backend/resource.go +++ b/backend/resource.go @@ -2,23 +2,95 @@ package backend import ( "context" + "net/http" + "net/textproto" ) // CallResourceRequest represents a request for a resource call. type CallResourceRequest struct { + // PluginContext the contextual information for the request. PluginContext PluginContext - Path string - Method string - URL string - Headers map[string][]string - Body []byte + + // Path the forwarded HTTP path for the request. + Path string + + // Method the forwarded HTTP method for the request. + Method string + + // URL the forwarded HTTP URL for the request. + URL string + + // Headers the forwarded HTTP headers for the request, if any. + // + // Recommended to use GetHTTPHeaders or GetHTTPHeader + // since it automatically handles canonicalization of + // HTTP header keys. + Headers map[string][]string + + // Body the forwarded HTTP body for the request, if any. + Body []byte +} + +// SetHTTPHeader sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +func (req *CallResourceRequest) SetHTTPHeader(key, value string) { + if req.Headers == nil { + req.Headers = map[string][]string{} + } + + req.Headers[key] = []string{value} +} + +// DeleteHTTPHeader deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (req *CallResourceRequest) DeleteHTTPHeader(key string) { + if req.Headers == nil { + return + } + + for k := range req.Headers { + if textproto.CanonicalMIMEHeaderKey(k) == textproto.CanonicalMIMEHeaderKey(key) { + delete(req.Headers, k) + break + } + } +} + +// GetHTTPHeader gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. +func (req *CallResourceRequest) GetHTTPHeader(key string) string { + return req.GetHTTPHeaders().Get(key) +} + +// GetHTTPHeaders returns HTTP headers. +func (req *CallResourceRequest) GetHTTPHeaders() http.Header { + httpHeaders := http.Header{} + + for k, v := range req.Headers { + for _, strVal := range v { + httpHeaders.Add(k, strVal) + } + } + + return httpHeaders } // CallResourceResponse represents a response from a resource call. type CallResourceResponse struct { - Status int + // Status the HTTP response status. + Status int + + // Headers the HTTP response headers. Headers map[string][]string - Body []byte + + // Body the HTTP response body. + Body []byte } // CallResourceResponseSender is used for sending resource call responses. @@ -41,3 +113,5 @@ type CallResourceHandlerFunc func(ctx context.Context, req *CallResourceRequest, func (fn CallResourceHandlerFunc) CallResource(ctx context.Context, req *CallResourceRequest, sender CallResourceResponseSender) error { return fn(ctx, req, sender) } + +var _ ForwardHTTPHeaders = (*CallResourceRequest)(nil) diff --git a/backend/resource_test.go b/backend/resource_test.go new file mode 100644 index 000000000..5dfe2420e --- /dev/null +++ b/backend/resource_test.go @@ -0,0 +1,105 @@ +package backend + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCallResourceRequest(t *testing.T) { + req := &CallResourceRequest{} + const customHeaderName = "X-Custom" + + t.Run("Legacy headers", func(t *testing.T) { + req.Headers = map[string][]string{ + "Authorization": {"a"}, + "X-ID-Token": {"b"}, + "Cookie": {"c"}, + customHeaderName: {"d"}, + } + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Equal(t, "d", headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Equal(t, "d", req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader canonical form", func(t *testing.T) { + req.SetHTTPHeader(OAuthIdentityTokenHeaderName, "a") + req.SetHTTPHeader(OAuthIdentityIDTokenHeaderName, "b") + req.SetHTTPHeader(CookiesHeaderName, "c") + req.SetHTTPHeader(customHeaderName, "d") + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Equal(t, "d", headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Equal(t, "d", req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader non-canonical form", func(t *testing.T) { + req.SetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName), "a") + req.SetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName), "b") + req.SetHTTPHeader(strings.ToLower(CookiesHeaderName), "c") + req.SetHTTPHeader(strings.ToLower(customHeaderName), "d") + + t.Run("GetHTTPHeaders non-canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", headers.Get(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", headers.Get(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", headers.Get(strings.ToLower(customHeaderName))) + }) + + t.Run("GetHTTPHeader non-canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", req.GetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", req.GetHTTPHeader(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", req.GetHTTPHeader(strings.ToLower(customHeaderName))) + }) + + t.Run("DeleteHTTPHeader non-canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(CookiesHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(customHeaderName)) + require.Empty(t, req.Headers) + }) + }) +}