diff --git a/pkg/builder/webhook_test.go b/pkg/builder/webhook_test.go index 80703a38ff..eca8c5e3a6 100644 --- a/pkg/builder/webhook_test.go +++ b/pkg/builder/webhook_test.go @@ -539,11 +539,13 @@ func runTests(admissionReviewVersion string) { // TestDefaulter. var _ runtime.Object = &TestDefaulter{} +const testDefaulterKind = "TestDefaulter" + type TestDefaulter struct { Replica int `json:"replica,omitempty"` } -var testDefaulterGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: "TestDefaulter"} +var testDefaulterGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: testDefaulterKind} func (d *TestDefaulter) GetObjectKind() schema.ObjectKind { return d } func (d *TestDefaulter) DeepCopyObject() runtime.Object { @@ -574,11 +576,13 @@ func (d *TestDefaulter) Default() { // TestValidator. var _ runtime.Object = &TestValidator{} +const testValidatorKind = "TestValidator" + type TestValidator struct { Replica int `json:"replica,omitempty"` } -var testValidatorGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: "TestValidator"} +var testValidatorGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: testValidatorKind} func (v *TestValidator) GetObjectKind() schema.ObjectKind { return v } func (v *TestValidator) DeepCopyObject() runtime.Object { @@ -694,6 +698,14 @@ func (dv *TestDefaultValidator) ValidateDelete() error { type TestCustomDefaulter struct{} func (*TestCustomDefaulter) Default(ctx context.Context, obj runtime.Object) error { + req, err := admission.RequestFromContext(ctx) + if err != nil { + return fmt.Errorf("expected admission.Request in ctx: %w", err) + } + if req.Kind.Kind != testDefaulterKind { + return fmt.Errorf("expected Kind TestDefaulter got %q", req.Kind.Kind) + } + d := obj.(*TestDefaulter) //nolint:ifshort if d.Replica < 2 { d.Replica = 2 @@ -708,6 +720,14 @@ var _ admission.CustomDefaulter = &TestCustomDefaulter{} type TestCustomValidator struct{} func (*TestCustomValidator) ValidateCreate(ctx context.Context, obj runtime.Object) error { + req, err := admission.RequestFromContext(ctx) + if err != nil { + return fmt.Errorf("expected admission.Request in ctx: %w", err) + } + if req.Kind.Kind != testValidatorKind { + return fmt.Errorf("expected Kind TestValidator got %q", req.Kind.Kind) + } + v := obj.(*TestValidator) //nolint:ifshort if v.Replica < 0 { return errors.New("number of replica should be greater than or equal to 0") @@ -716,6 +736,14 @@ func (*TestCustomValidator) ValidateCreate(ctx context.Context, obj runtime.Obje } func (*TestCustomValidator) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) error { + req, err := admission.RequestFromContext(ctx) + if err != nil { + return fmt.Errorf("expected admission.Request in ctx: %w", err) + } + if req.Kind.Kind != testValidatorKind { + return fmt.Errorf("expected Kind TestValidator got %q", req.Kind.Kind) + } + v := newObj.(*TestValidator) old := oldObj.(*TestValidator) //nolint:ifshort if v.Replica < 0 { @@ -728,6 +756,14 @@ func (*TestCustomValidator) ValidateUpdate(ctx context.Context, oldObj, newObj r } func (*TestCustomValidator) ValidateDelete(ctx context.Context, obj runtime.Object) error { + req, err := admission.RequestFromContext(ctx) + if err != nil { + return fmt.Errorf("expected admission.Request in ctx: %w", err) + } + if req.Kind.Kind != testValidatorKind { + return fmt.Errorf("expected Kind TestValidator got %q", req.Kind.Kind) + } + v := obj.(*TestValidator) //nolint:ifshort if v.Replica > 0 { return errors.New("number of replica should be less than or equal to 0 to delete") diff --git a/pkg/webhook/admission/defaulter_custom.go b/pkg/webhook/admission/defaulter_custom.go index a012784e43..d65727e62c 100644 --- a/pkg/webhook/admission/defaulter_custom.go +++ b/pkg/webhook/admission/defaulter_custom.go @@ -19,7 +19,6 @@ package admission import ( "context" "encoding/json" - "errors" "net/http" @@ -61,6 +60,8 @@ func (h *defaulterForType) Handle(ctx context.Context, req Request) Response { panic("object should never be nil") } + ctx = NewContextWithRequest(ctx, req) + // Get the object in the request obj := h.object.DeepCopyObject() if err := h.decoder.Decode(req, obj); err != nil { diff --git a/pkg/webhook/admission/validator_custom.go b/pkg/webhook/admission/validator_custom.go index 38d5565111..33252f1134 100644 --- a/pkg/webhook/admission/validator_custom.go +++ b/pkg/webhook/admission/validator_custom.go @@ -64,6 +64,8 @@ func (h *validatorForType) Handle(ctx context.Context, req Request) Response { panic("object should never be nil") } + ctx = NewContextWithRequest(ctx, req) + // Get the object in the request obj := h.object.DeepCopyObject() diff --git a/pkg/webhook/admission/webhook.go b/pkg/webhook/admission/webhook.go index 3dcff5fadd..cfc46637c3 100644 --- a/pkg/webhook/admission/webhook.go +++ b/pkg/webhook/admission/webhook.go @@ -253,3 +253,21 @@ func StandaloneWebhook(hook *Webhook, opts StandaloneOptions) (http.Handler, err } return metrics.InstrumentedHook(opts.MetricsPath, hook), nil } + +// requestContextKey is how we find the admission.Request in a context.Context. +type requestContextKey struct{} + +// RequestFromContext returns an admission.Request from ctx. +func RequestFromContext(ctx context.Context) (Request, error) { + if v, ok := ctx.Value(requestContextKey{}).(Request); ok { + return v, nil + } + + return Request{}, errors.New("admission.Request not found in context") +} + +// NewContextWithRequest returns a new Context, derived from ctx, which carries the +// provided admission.Request. +func NewContextWithRequest(ctx context.Context, req Request) context.Context { + return context.WithValue(ctx, requestContextKey{}, req) +} diff --git a/pkg/webhook/admission/webhook_test.go b/pkg/webhook/admission/webhook_test.go index 73b0be1694..272d00e57a 100644 --- a/pkg/webhook/admission/webhook_test.go +++ b/pkg/webhook/admission/webhook_test.go @@ -194,6 +194,21 @@ var _ = Describe("Admission Webhooks", func() { }) }) +var _ = Describe("Should be able to write/read admission.Request to/from context", func() { + ctx := context.Background() + testRequest := Request{ + admissionv1.AdmissionRequest{ + UID: "test-uid", + }, + } + + ctx = NewContextWithRequest(ctx, testRequest) + + gotRequest, err := RequestFromContext(ctx) + Expect(err).To(Not(HaveOccurred())) + Expect(gotRequest).To(Equal(testRequest)) +}) + type stringInjector interface { InjectString(s string) error }