Skip to content

Commit

Permalink
Merge pull request #1034 from zchee/ctx-aware
Browse files Browse the repository at this point in the history
Support pass context.Context to all methods
  • Loading branch information
zchee committed Feb 24, 2022
2 parents bbab81d + 741a1f5 commit ff619ea
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 154 deletions.
6 changes: 5 additions & 1 deletion apps.go
Expand Up @@ -44,14 +44,18 @@ func (api *Client) ListEventAuthorizationsContext(ctx context.Context, eventCont
}

func (api *Client) UninstallApp(clientID, clientSecret string) error {
return api.UninstallAppContext(context.Background(), clientID, clientSecret)
}

func (api *Client) UninstallAppContext(ctx context.Context, clientID, clientSecret string) error {
values := url.Values{
"client_id": {clientID},
"client_secret": {clientSecret},
}

response := SlackResponse{}

err := api.getMethod(context.Background(), "apps.uninstall", api.token, values, &response)
err := api.getMethod(ctx, "apps.uninstall", api.token, values, &response)
if err != nil {
return err
}
Expand Down
60 changes: 29 additions & 31 deletions chat.go
Expand Up @@ -86,12 +86,7 @@ func NewPostMessageParameters() PostMessageParameters {

// DeleteMessage deletes a message in a channel
func (api *Client) DeleteMessage(channel, messageTimestamp string) (string, string, error) {
respChannel, respTimestamp, _, err := api.SendMessageContext(
context.Background(),
channel,
MsgOptionDelete(messageTimestamp),
)
return respChannel, respTimestamp, err
return api.DeleteMessageContext(context.Background(), channel, messageTimestamp)
}

// DeleteMessageContext deletes a message in a channel with a custom context
Expand All @@ -108,8 +103,15 @@ func (api *Client) DeleteMessageContext(ctx context.Context, channel, messageTim
// Message is escaped by default according to https://api.slack.com/docs/formatting
// Use http://davestevens.github.io/slack-message-builder/ to help crafting your message.
func (api *Client) ScheduleMessage(channelID, postAt string, options ...MsgOption) (string, string, error) {
return api.ScheduleMessageContext(context.Background(), channelID, postAt, options...)
}

// ScheduleMessageContext sends a message to a channel with a custom context
//
// For more details, see ScheduleMessage documentation.
func (api *Client) ScheduleMessageContext(ctx context.Context, channelID, postAt string, options ...MsgOption) (string, string, error) {
respChannel, respTimestamp, _, err := api.SendMessageContext(
context.Background(),
ctx,
channelID,
MsgOptionSchedule(postAt),
MsgOptionCompose(options...),
Expand All @@ -121,13 +123,7 @@ func (api *Client) ScheduleMessage(channelID, postAt string, options ...MsgOptio
// Message is escaped by default according to https://api.slack.com/docs/formatting
// Use http://davestevens.github.io/slack-message-builder/ to help crafting your message.
func (api *Client) PostMessage(channelID string, options ...MsgOption) (string, string, error) {
respChannel, respTimestamp, _, err := api.SendMessageContext(
context.Background(),
channelID,
MsgOptionPost(),
MsgOptionCompose(options...),
)
return respChannel, respTimestamp, err
return api.PostMessageContext(context.Background(), channelID, options...)
}

// PostMessageContext sends a message to a channel with a custom context
Expand All @@ -146,12 +142,7 @@ func (api *Client) PostMessageContext(ctx context.Context, channelID string, opt
// Message is escaped by default according to https://api.slack.com/docs/formatting
// Use http://davestevens.github.io/slack-message-builder/ to help crafting your message.
func (api *Client) PostEphemeral(channelID, userID string, options ...MsgOption) (string, error) {
return api.PostEphemeralContext(
context.Background(),
channelID,
userID,
options...,
)
return api.PostEphemeralContext(context.Background(), channelID, userID, options...)
}

// PostEphemeralContext sends an ephemeal message to a user in a channel with a custom context
Expand All @@ -168,12 +159,7 @@ func (api *Client) PostEphemeralContext(ctx context.Context, channelID, userID s

// UpdateMessage updates a message in a channel
func (api *Client) UpdateMessage(channelID, timestamp string, options ...MsgOption) (string, string, string, error) {
return api.SendMessageContext(
context.Background(),
channelID,
MsgOptionUpdate(timestamp),
MsgOptionCompose(options...),
)
return api.UpdateMessageContext(context.Background(), channelID, timestamp, options...)
}

// UpdateMessageContext updates a message in a channel
Expand Down Expand Up @@ -225,7 +211,7 @@ func (api *Client) SendMessageContext(ctx context.Context, channelID string, opt
response chatResponseFull
)

if req, parser, err = buildSender(api.endpoint, options...).BuildRequest(api.token, channelID); err != nil {
if req, parser, err = buildSender(api.endpoint, options...).BuildRequestContext(ctx, api.token, channelID); err != nil {
return "", "", "", err
}

Expand Down Expand Up @@ -306,6 +292,10 @@ type sendConfig struct {
}

func (t sendConfig) BuildRequest(token, channelID string) (req *http.Request, _ func(*chatResponseFull) responseParser, err error) {
return t.BuildRequestContext(context.Background(), token, channelID)
}

func (t sendConfig) BuildRequestContext(ctx context.Context, token, channelID string) (req *http.Request, _ func(*chatResponseFull) responseParser, err error) {
if t, err = applyMsgOptions(token, channelID, t.apiurl, t.options...); err != nil {
return nil, nil, err
}
Expand All @@ -320,9 +310,9 @@ func (t sendConfig) BuildRequest(token, channelID string) (req *http.Request, _
responseType: t.responseType,
replaceOriginal: t.replaceOriginal,
deleteOriginal: t.deleteOriginal,
}.BuildRequest()
}.BuildRequestContext(ctx)
default:
return formSender{endpoint: t.endpoint, values: t.values}.BuildRequest()
return formSender{endpoint: t.endpoint, values: t.values}.BuildRequestContext(ctx)
}
}

Expand All @@ -332,7 +322,11 @@ type formSender struct {
}

func (t formSender) BuildRequest() (*http.Request, func(*chatResponseFull) responseParser, error) {
req, err := formReq(t.endpoint, t.values)
return t.BuildRequestContext(context.Background())
}

func (t formSender) BuildRequestContext(ctx context.Context) (*http.Request, func(*chatResponseFull) responseParser, error) {
req, err := formReq(ctx, t.endpoint, t.values)
return req, func(resp *chatResponseFull) responseParser {
return newJSONParser(resp)
}, err
Expand All @@ -349,7 +343,11 @@ type responseURLSender struct {
}

func (t responseURLSender) BuildRequest() (*http.Request, func(*chatResponseFull) responseParser, error) {
req, err := jsonReq(t.endpoint, Msg{
return t.BuildRequestContext(context.Background())
}

func (t responseURLSender) BuildRequestContext(ctx context.Context) (*http.Request, func(*chatResponseFull) responseParser, error) {
req, err := jsonReq(ctx, t.endpoint, Msg{
Text: t.values.Get("text"),
Timestamp: t.values.Get("ts"),
Attachments: t.attachments,
Expand Down
79 changes: 44 additions & 35 deletions files.go
Expand Up @@ -202,48 +202,21 @@ func (api *Client) GetFileInfoContext(ctx context.Context, fileID string, count,

// GetFile retreives a given file from its private download URL
func (api *Client) GetFile(downloadURL string, writer io.Writer) error {
return downloadFile(api.httpclient, api.token, downloadURL, writer, api)
return api.GetFileContext(context.Background(), downloadURL, writer)
}

// GetFileContext retreives a given file from its private download URL with a custom context
//
// For more details, see GetFile documentation.
func (api *Client) GetFileContext(ctx context.Context, downloadURL string, writer io.Writer) error {
return downloadFile(ctx, api.httpclient, api.token, downloadURL, writer, api)
}

// GetFiles retrieves all files according to the parameters given
func (api *Client) GetFiles(params GetFilesParameters) ([]File, *Paging, error) {
return api.GetFilesContext(context.Background(), params)
}

// ListFiles retrieves all files according to the parameters given. Uses cursor based pagination.
func (api *Client) ListFiles(params ListFilesParameters) ([]File, *ListFilesParameters, error) {
return api.ListFilesContext(context.Background(), params)
}

// ListFilesContext retrieves all files according to the parameters given with a custom context. Uses cursor based pagination.
func (api *Client) ListFilesContext(ctx context.Context, params ListFilesParameters) ([]File, *ListFilesParameters, error) {
values := url.Values{
"token": {api.token},
}

if params.User != DEFAULT_FILES_USER {
values.Add("user", params.User)
}
if params.Channel != DEFAULT_FILES_CHANNEL {
values.Add("channel", params.Channel)
}
if params.Limit != DEFAULT_FILES_COUNT {
values.Add("limit", strconv.Itoa(params.Limit))
}
if params.Cursor != "" {
values.Add("cursor", params.Cursor)
}

response, err := api.fileRequest(ctx, "files.list", values)
if err != nil {
return nil, nil, err
}

params.Cursor = response.Metadata.Cursor

return response.Files, &params, nil
}

// GetFilesContext retrieves all files according to the parameters given with a custom context
func (api *Client) GetFilesContext(ctx context.Context, params GetFilesParameters) ([]File, *Paging, error) {
values := url.Values{
Expand Down Expand Up @@ -281,6 +254,42 @@ func (api *Client) GetFilesContext(ctx context.Context, params GetFilesParameter
return response.Files, &response.Paging, nil
}

// ListFiles retrieves all files according to the parameters given. Uses cursor based pagination.
func (api *Client) ListFiles(params ListFilesParameters) ([]File, *ListFilesParameters, error) {
return api.ListFilesContext(context.Background(), params)
}

// ListFilesContext retrieves all files according to the parameters given with a custom context.
//
// For more details, see ListFiles documentation.
func (api *Client) ListFilesContext(ctx context.Context, params ListFilesParameters) ([]File, *ListFilesParameters, error) {
values := url.Values{
"token": {api.token},
}

if params.User != DEFAULT_FILES_USER {
values.Add("user", params.User)
}
if params.Channel != DEFAULT_FILES_CHANNEL {
values.Add("channel", params.Channel)
}
if params.Limit != DEFAULT_FILES_COUNT {
values.Add("limit", strconv.Itoa(params.Limit))
}
if params.Cursor != "" {
values.Add("cursor", params.Cursor)
}

response, err := api.fileRequest(ctx, "files.list", values)
if err != nil {
return nil, nil, err
}

params.Cursor = response.Metadata.Cursor

return response.Files, &params, nil
}

// UploadFile uploads a file
func (api *Client) UploadFile(params FileUploadParameters) (file *File, err error) {
return api.UploadFileContext(context.Background(), params)
Expand Down
6 changes: 5 additions & 1 deletion info.go
Expand Up @@ -321,9 +321,13 @@ type UserPrefs struct {
}

func (api *Client) GetUserPrefs() (*UserPrefsCarrier, error) {
return api.GetUserPrefsContext(context.Background())
}

func (api *Client) GetUserPrefsContext(ctx context.Context) (*UserPrefsCarrier, error) {
response := UserPrefsCarrier{}

err := api.getMethod(context.Background(), "users.prefs.get", api.token, url.Values{}, &response)
err := api.getMethod(ctx, "users.prefs.get", api.token, url.Values{}, &response)
if err != nil {
return nil, err
}
Expand Down
26 changes: 11 additions & 15 deletions misc.go
Expand Up @@ -66,29 +66,27 @@ func (e *RateLimitedError) Retryable() bool {
}

func fileUploadReq(ctx context.Context, path string, values url.Values, r io.Reader) (*http.Request, error) {
req, err := http.NewRequest("POST", path, r)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, path, r)
if err != nil {
return nil, err
}

req = req.WithContext(ctx)
req.URL.RawQuery = (values).Encode()
req.URL.RawQuery = values.Encode()
return req, nil
}

func downloadFile(client httpClient, token string, downloadURL string, writer io.Writer, d Debug) error {
func downloadFile(ctx context.Context, client httpClient, token string, downloadURL string, writer io.Writer, d Debug) error {
if downloadURL == "" {
return fmt.Errorf("received empty download URL")
}

req, err := http.NewRequest("GET", downloadURL, &bytes.Buffer{})
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, &bytes.Buffer{})
if err != nil {
return err
}

var bearer = "Bearer " + token
req.Header.Add("Authorization", bearer)
req.WithContext(context.Background())

resp, err := client.Do(req)
if err != nil {
Expand All @@ -107,22 +105,22 @@ func downloadFile(client httpClient, token string, downloadURL string, writer io
return err
}

func formReq(endpoint string, values url.Values) (req *http.Request, err error) {
if req, err = http.NewRequest("POST", endpoint, strings.NewReader(values.Encode())); err != nil {
func formReq(ctx context.Context, endpoint string, values url.Values) (req *http.Request, err error) {
if req, err = http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(values.Encode())); err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req, nil
}

func jsonReq(endpoint string, body interface{}) (req *http.Request, err error) {
func jsonReq(ctx context.Context, endpoint string, body interface{}) (req *http.Request, err error) {
buffer := bytes.NewBuffer([]byte{})
if err = json.NewEncoder(buffer).Encode(body); err != nil {
return nil, err
}

if req, err = http.NewRequest("POST", endpoint, buffer); err != nil {
if req, err = http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer); err != nil {
return nil, err
}

Expand Down Expand Up @@ -184,7 +182,6 @@ func postWithMultipartResponse(ctx context.Context, client httpClient, path, nam
}
req.Header.Add("Content-Type", wr.FormDataContentType())
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req = req.WithContext(ctx)
resp, err := client.Do(req)

if err != nil {
Expand All @@ -206,7 +203,6 @@ func postWithMultipartResponse(ctx context.Context, client httpClient, path, nam
}

func doPost(ctx context.Context, client httpClient, req *http.Request, parser responseParser, d Debug) error {
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
return err
Expand All @@ -224,7 +220,7 @@ func doPost(ctx context.Context, client httpClient, req *http.Request, parser re
// post JSON.
func postJSON(ctx context.Context, client httpClient, endpoint, token string, json []byte, intf interface{}, d Debug) error {
reqBody := bytes.NewBuffer(json)
req, err := http.NewRequest("POST", endpoint, reqBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, reqBody)
if err != nil {
return err
}
Expand All @@ -237,7 +233,7 @@ func postJSON(ctx context.Context, client httpClient, endpoint, token string, js
// post a url encoded form.
func postForm(ctx context.Context, client httpClient, endpoint string, values url.Values, intf interface{}, d Debug) error {
reqBody := strings.NewReader(values.Encode())
req, err := http.NewRequest("POST", endpoint, reqBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, reqBody)
if err != nil {
return err
}
Expand All @@ -246,7 +242,7 @@ func postForm(ctx context.Context, client httpClient, endpoint string, values ur
}

func getResource(ctx context.Context, client httpClient, endpoint, token string, values url.Values, intf interface{}, d Debug) error {
req, err := http.NewRequest("GET", endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return err
}
Expand Down

0 comments on commit ff619ea

Please sign in to comment.