Skip to content

Commit

Permalink
Prepare internals for exposing context.Context in exported API
Browse files Browse the repository at this point in the history
As of today there is not a way for callers to control the timeout of requests on
a per-call basis, as we do not expose context.Context in the APIs of this
package.

This change updates some of the internal code paths to use context.Context,
making use of a context.TODO() for now (until we can expose context.Context) in
the API.

This change also includes some style changes to help make the code a bit less
dense / more idiomatic in some cases. This also adds a few TODO comments, as a
reminder to come back and look at concerning blocks of code that caught my eye.
  • Loading branch information
theckman committed Feb 6, 2021
1 parent eab6a10 commit bdf4bc5
Show file tree
Hide file tree
Showing 23 changed files with 199 additions and 161 deletions.
6 changes: 4 additions & 2 deletions ability.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package pagerduty

import "context"

// ListAbilityResponse is the response when calling the ListAbility API endpoint.
type ListAbilityResponse struct {
Abilities []string `json:"abilities"`
}

// ListAbilities lists all abilities on your account.
func (c *Client) ListAbilities() (*ListAbilityResponse, error) {
resp, err := c.get("/abilities")
resp, err := c.get(context.TODO(), "/abilities")
if err != nil {
return nil, err
}
Expand All @@ -17,6 +19,6 @@ func (c *Client) ListAbilities() (*ListAbilityResponse, error) {

// TestAbility Check if your account has the given ability.
func (c *Client) TestAbility(ability string) error {
_, err := c.get("/abilities/" + ability)
_, err := c.get(context.TODO(), "/abilities/"+ability)
return err
}
13 changes: 7 additions & 6 deletions addon.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pagerduty

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -35,7 +36,7 @@ func (c *Client) ListAddons(o ListAddonOptions) (*ListAddonResponse, error) {
if err != nil {
return nil, err
}
resp, err := c.get("/addons?" + v.Encode())
resp, err := c.get(context.TODO(), "/addons?"+v.Encode())
if err != nil {
return nil, err
}
Expand All @@ -47,8 +48,8 @@ func (c *Client) ListAddons(o ListAddonOptions) (*ListAddonResponse, error) {
func (c *Client) InstallAddon(a Addon) (*Addon, error) {
data := make(map[string]Addon)
data["addon"] = a
resp, err := c.post("/addons", data, nil)
defer resp.Body.Close()
resp, err := c.post(context.TODO(), "/addons", data, nil)
defer resp.Body.Close() // TODO(theckman): validate that this is safe
if err != nil {
return nil, err
}
Expand All @@ -60,13 +61,13 @@ func (c *Client) InstallAddon(a Addon) (*Addon, error) {

// DeleteAddon deletes an add-on from your account.
func (c *Client) DeleteAddon(id string) error {
_, err := c.delete("/addons/" + id)
_, err := c.delete(context.TODO(), "/addons/"+id)
return err
}

// GetAddon gets details about an existing add-on.
func (c *Client) GetAddon(id string) (*Addon, error) {
resp, err := c.get("/addons/" + id)
resp, err := c.get(context.TODO(), "/addons/"+id)
if err != nil {
return nil, err
}
Expand All @@ -77,7 +78,7 @@ func (c *Client) GetAddon(id string) (*Addon, error) {
func (c *Client) UpdateAddon(id string, a Addon) (*Addon, error) {
v := make(map[string]Addon)
v["addon"] = a
resp, err := c.put("/addons/"+id, v, nil)
resp, err := c.put(context.TODO(), "/addons/"+id, v, nil)
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions business_service.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pagerduty

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -75,7 +76,7 @@ func (c *Client) ListBusinessServices(o ListBusinessServiceOptions) (*ListBusine
}

// Make call to get all pages associated with the base endpoint.
if err := c.pagedGet("/business_services"+queryParms.Encode(), responseHandler); err != nil {
if err := c.pagedGet(context.TODO(), "/business_services"+queryParms.Encode(), responseHandler); err != nil {
return nil, err
}
businessServiceResponse.BusinessServices = businessServices
Expand All @@ -87,19 +88,19 @@ func (c *Client) ListBusinessServices(o ListBusinessServiceOptions) (*ListBusine
func (c *Client) CreateBusinessService(b *BusinessService) (*BusinessService, *http.Response, error) {
data := make(map[string]*BusinessService)
data["business_service"] = b
resp, err := c.post("/business_services", data, nil)
resp, err := c.post(context.TODO(), "/business_services", data, nil)
return getBusinessServiceFromResponse(c, resp, err)
}

// GetBusinessService gets details about a business service.
func (c *Client) GetBusinessService(ID string) (*BusinessService, *http.Response, error) {
resp, err := c.get("/business_services/" + ID)
resp, err := c.get(context.TODO(), "/business_services/"+ID)
return getBusinessServiceFromResponse(c, resp, err)
}

// DeleteBusinessService deletes a business_service.
func (c *Client) DeleteBusinessService(ID string) error {
_, err := c.delete("/business_services/" + ID)
_, err := c.delete(context.TODO(), "/business_services/"+ID)
return err
}

Expand All @@ -109,7 +110,7 @@ func (c *Client) UpdateBusinessService(b *BusinessService) (*BusinessService, *h
id := b.ID
b.ID = ""
v["business_service"] = b
resp, err := c.put("/business_services/"+id, v, nil)
resp, err := c.put(context.TODO(), "/business_services/"+id, v, nil)
return getBusinessServiceFromResponse(c, resp, err)
}

Expand Down
5 changes: 4 additions & 1 deletion change_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package pagerduty

import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
)

const changeEventPath = "/v2/change/enqueue"
Expand Down Expand Up @@ -55,8 +57,9 @@ func (c *Client) CreateChangeEvent(e ChangeEvent) (*ChangeEventResponse, error)
}

resp, err := c.doWithEndpoint(
context.TODO(),
c.v2EventsAPIEndpoint,
"POST",
http.MethodPost,
changeEventPath,
false,
bytes.NewBuffer(data),
Expand Down
36 changes: 20 additions & 16 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pagerduty

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -156,37 +157,40 @@ func WithOAuth() ClientOptions {
}
}

func (c *Client) delete(path string) (*http.Response, error) {
return c.do("DELETE", path, nil, nil)
func (c *Client) delete(ctx context.Context, path string) (*http.Response, error) {
return c.do(ctx, http.MethodDelete, path, nil, nil)
}

func (c *Client) put(path string, payload interface{}, headers *map[string]string) (*http.Response, error) {

func (c *Client) put(ctx context.Context, path string, payload interface{}, headers *map[string]string) (*http.Response, error) {
if payload != nil {
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
return c.do("PUT", path, bytes.NewBuffer(data), headers)
return c.do(ctx, http.MethodPut, path, bytes.NewBuffer(data), headers)
}
return c.do("PUT", path, nil, headers)
return c.do(ctx, http.MethodPut, path, nil, headers)
}

func (c *Client) post(path string, payload interface{}, headers *map[string]string) (*http.Response, error) {
func (c *Client) post(ctx context.Context, path string, payload interface{}, headers *map[string]string) (*http.Response, error) {
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
return c.do("POST", path, bytes.NewBuffer(data), headers)
return c.do(ctx, http.MethodPost, path, bytes.NewBuffer(data), headers)
}

func (c *Client) get(path string) (*http.Response, error) {
return c.do("GET", path, nil, nil)
func (c *Client) get(ctx context.Context, path string) (*http.Response, error) {
return c.do(ctx, http.MethodGet, path, nil, nil)
}

// needed where pagerduty use a different endpoint for certain actions (eg: v2 events)
func (c *Client) doWithEndpoint(endpoint, method, path string, authRequired bool, body io.Reader, headers *map[string]string) (*http.Response, error) {
req, _ := http.NewRequest(method, endpoint+path, body)
func (c *Client) doWithEndpoint(ctx context.Context, endpoint, method, path string, authRequired bool, body io.Reader, headers *map[string]string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, endpoint+path, body)
if err != nil {
return nil, fmt.Errorf("failed to build request: %w", err)
}

req.Header.Set("Accept", "application/vnd.pagerduty+json;version=2")
if headers != nil {
for k, v := range *headers {
Expand All @@ -210,8 +214,8 @@ func (c *Client) doWithEndpoint(endpoint, method, path string, authRequired bool
return c.checkResponse(resp, err)
}

func (c *Client) do(method, path string, body io.Reader, headers *map[string]string) (*http.Response, error) {
return c.doWithEndpoint(c.apiEndpoint, method, path, true, body, headers)
func (c *Client) do(ctx context.Context, method, path string, body io.Reader, headers *map[string]string) (*http.Response, error) {
return c.doWithEndpoint(ctx, c.apiEndpoint, method, path, true, body, headers)
}

func (c *Client) decodeJSON(resp *http.Response, payload interface{}) error {
Expand Down Expand Up @@ -254,7 +258,7 @@ func (c *Client) getErrorFromResponse(resp *http.Response) (*errorObject, error)
// a specific slice. The responseHandler is responsible for closing the response.
type responseHandler func(response *http.Response) (APIListObject, error)

func (c *Client) pagedGet(basePath string, handler responseHandler) error {
func (c *Client) pagedGet(ctx context.Context, basePath string, handler responseHandler) error {
// Indicates whether there are still additional pages associated with request.
var stillMore bool

Expand All @@ -263,7 +267,7 @@ func (c *Client) pagedGet(basePath string, handler responseHandler) error {

// While there are more pages, keep adjusting the offset to get all results.
for stillMore, nextOffset = true, 0; stillMore; {
response, err := c.do("GET", fmt.Sprintf("%s?offset=%d", basePath, nextOffset), nil, nil)
response, err := c.do(ctx, http.MethodGet, fmt.Sprintf("%s?offset=%d", basePath, nextOffset), nil, nil)
if err != nil {
return err
}
Expand Down
21 changes: 11 additions & 10 deletions escalation_policy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pagerduty

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -62,7 +63,7 @@ func (c *Client) ListEscalationPolicies(o ListEscalationPoliciesOptions) (*ListE
if err != nil {
return nil, err
}
resp, err := c.get(escPath + "?" + v.Encode())
resp, err := c.get(context.TODO(), escPath+"?"+v.Encode())
if err != nil {
return nil, err
}
Expand All @@ -74,13 +75,13 @@ func (c *Client) ListEscalationPolicies(o ListEscalationPoliciesOptions) (*ListE
func (c *Client) CreateEscalationPolicy(e EscalationPolicy) (*EscalationPolicy, error) {
data := make(map[string]EscalationPolicy)
data["escalation_policy"] = e
resp, err := c.post(escPath, data, nil)
resp, err := c.post(context.TODO(), escPath, data, nil)
return getEscalationPolicyFromResponse(c, resp, err)
}

// DeleteEscalationPolicy deletes an existing escalation policy and rules.
func (c *Client) DeleteEscalationPolicy(id string) error {
_, err := c.delete(escPath + "/" + id)
_, err := c.delete(context.TODO(), escPath+"/"+id)
return err
}

Expand All @@ -95,15 +96,15 @@ func (c *Client) GetEscalationPolicy(id string, o *GetEscalationPolicyOptions) (
if err != nil {
return nil, err
}
resp, err := c.get(escPath + "/" + id + "?" + v.Encode())
resp, err := c.get(context.TODO(), escPath+"/"+id+"?"+v.Encode())
return getEscalationPolicyFromResponse(c, resp, err)
}

// UpdateEscalationPolicy updates an existing escalation policy and its rules.
func (c *Client) UpdateEscalationPolicy(id string, e *EscalationPolicy) (*EscalationPolicy, error) {
data := make(map[string]EscalationPolicy)
data["escalation_policy"] = *e
resp, err := c.put(escPath+"/"+id, data, nil)
resp, err := c.put(context.TODO(), escPath+"/"+id, data, nil)
return getEscalationPolicyFromResponse(c, resp, err)
}

Expand All @@ -112,7 +113,7 @@ func (c *Client) UpdateEscalationPolicy(id string, e *EscalationPolicy) (*Escala
func (c *Client) CreateEscalationRule(escID string, e EscalationRule) (*EscalationRule, error) {
data := make(map[string]EscalationRule)
data["escalation_rule"] = e
resp, err := c.post(escPath+"/"+escID+"/escalation_rules", data, nil)
resp, err := c.post(context.TODO(), escPath+"/"+escID+"/escalation_rules", data, nil)
return getEscalationRuleFromResponse(c, resp, err)
}

Expand All @@ -122,27 +123,27 @@ func (c *Client) GetEscalationRule(escID string, id string, o *GetEscalationRule
if err != nil {
return nil, err
}
resp, err := c.get(escPath + "/" + escID + "/escalation_rules/" + id + "?" + v.Encode())
resp, err := c.get(context.TODO(), escPath+"/"+escID+"/escalation_rules/"+id+"?"+v.Encode())
return getEscalationRuleFromResponse(c, resp, err)
}

// DeleteEscalationRule deletes an existing escalation rule.
func (c *Client) DeleteEscalationRule(escID string, id string) error {
_, err := c.delete(escPath + "/" + escID + "/escalation_rules/" + id)
_, err := c.delete(context.TODO(), escPath+"/"+escID+"/escalation_rules/"+id)
return err
}

// UpdateEscalationRule updates an existing escalation rule.
func (c *Client) UpdateEscalationRule(escID string, id string, e *EscalationRule) (*EscalationRule, error) {
data := make(map[string]EscalationRule)
data["escalation_rule"] = *e
resp, err := c.put(escPath+"/"+escID+"/escalation_rules/"+id, data, nil)
resp, err := c.put(context.TODO(), escPath+"/"+escID+"/escalation_rules/"+id, data, nil)
return getEscalationRuleFromResponse(c, resp, err)
}

// ListEscalationRules lists all of the escalation rules for an existing escalation policy.
func (c *Client) ListEscalationRules(escID string) (*ListEscalationRulesResponse, error) {
resp, err := c.get(escPath + "/" + escID + "/escalation_rules")
resp, err := c.get(context.TODO(), escPath+"/"+escID+"/escalation_rules")
if err != nil {
return nil, err
}
Expand Down
13 changes: 11 additions & 2 deletions event_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pagerduty

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -48,9 +49,16 @@ func ManageEvent(e V2Event) (*V2EventResponse, error) {
if err != nil {
return nil, err
}
req, _ := http.NewRequest("POST", v2eventEndPoint, bytes.NewBuffer(data))

req, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, v2eventEndPoint, bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}

req.Header.Set("User-Agent", "go-pagerduty/"+Version)
req.Header.Set("Content-Type", "application/json")

// TODO(theckman): switch to a package-local default client
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -78,7 +86,8 @@ func (c *Client) ManageEvent(e *V2Event) (*V2EventResponse, error) {
if err != nil {
return nil, err
}
resp, err := c.doWithEndpoint(c.v2EventsAPIEndpoint, "POST", "/v2/enqueue", false, bytes.NewBuffer(data), &headers)

resp, err := c.doWithEndpoint(context.TODO(), c.v2EventsAPIEndpoint, http.MethodPost, "/v2/enqueue", false, bytes.NewBuffer(data), &headers)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit bdf4bc5

Please sign in to comment.