Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pass context.Context to all methods #1034

Merged
merged 6 commits into from Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion apps.go
Expand Up @@ -44,14 +44,18 @@ func (api *Client) ListEventAuthorizationsContext(ctx context.Context, eventCont
}

func (api *Client) UninstallApp(clientID, clientSecret string) error {
return api.UninstallAppContext(context.Background(), clientID, clientSecret)
}

func (api *Client) UninstallAppContext(ctx context.Context, clientID, clientSecret string) error {
values := url.Values{
"client_id": {clientID},
"client_secret": {clientSecret},
}

response := SlackResponse{}

err := api.getMethod(context.Background(), "apps.uninstall", api.token, values, &response)
err := api.getMethod(ctx, "apps.uninstall", api.token, values, &response)
if err != nil {
return err
}
Expand Down
60 changes: 29 additions & 31 deletions chat.go
Expand Up @@ -86,12 +86,7 @@ func NewPostMessageParameters() PostMessageParameters {

// DeleteMessage deletes a message in a channel
func (api *Client) DeleteMessage(channel, messageTimestamp string) (string, string, error) {
respChannel, respTimestamp, _, err := api.SendMessageContext(
context.Background(),
channel,
MsgOptionDelete(messageTimestamp),
)
return respChannel, respTimestamp, err
return api.DeleteMessageContext(context.Background(), channel, messageTimestamp)
}

// DeleteMessageContext deletes a message in a channel with a custom context
Expand All @@ -108,8 +103,15 @@ func (api *Client) DeleteMessageContext(ctx context.Context, channel, messageTim
// Message is escaped by default according to https://api.slack.com/docs/formatting
// Use http://davestevens.github.io/slack-message-builder/ to help crafting your message.
func (api *Client) ScheduleMessage(channelID, postAt string, options ...MsgOption) (string, string, error) {
return api.ScheduleMessageContext(context.Background(), channelID, postAt, options...)
}

// ScheduleMessageContext sends a message to a channel with a custom context
//
// For more details, see ScheduleMessage documentation.
func (api *Client) ScheduleMessageContext(ctx context.Context, channelID, postAt string, options ...MsgOption) (string, string, error) {
respChannel, respTimestamp, _, err := api.SendMessageContext(
context.Background(),
ctx,
channelID,
MsgOptionSchedule(postAt),
MsgOptionCompose(options...),
Expand All @@ -121,13 +123,7 @@ func (api *Client) ScheduleMessage(channelID, postAt string, options ...MsgOptio
// Message is escaped by default according to https://api.slack.com/docs/formatting
// Use http://davestevens.github.io/slack-message-builder/ to help crafting your message.
func (api *Client) PostMessage(channelID string, options ...MsgOption) (string, string, error) {
respChannel, respTimestamp, _, err := api.SendMessageContext(
context.Background(),
channelID,
MsgOptionPost(),
MsgOptionCompose(options...),
)
return respChannel, respTimestamp, err
return api.PostMessageContext(context.Background(), channelID, options...)
}

// PostMessageContext sends a message to a channel with a custom context
Expand All @@ -146,12 +142,7 @@ func (api *Client) PostMessageContext(ctx context.Context, channelID string, opt
// Message is escaped by default according to https://api.slack.com/docs/formatting
// Use http://davestevens.github.io/slack-message-builder/ to help crafting your message.
func (api *Client) PostEphemeral(channelID, userID string, options ...MsgOption) (string, error) {
return api.PostEphemeralContext(
context.Background(),
channelID,
userID,
options...,
)
return api.PostEphemeralContext(context.Background(), channelID, userID, options...)
}

// PostEphemeralContext sends an ephemeal message to a user in a channel with a custom context
Expand All @@ -168,12 +159,7 @@ func (api *Client) PostEphemeralContext(ctx context.Context, channelID, userID s

// UpdateMessage updates a message in a channel
func (api *Client) UpdateMessage(channelID, timestamp string, options ...MsgOption) (string, string, string, error) {
return api.SendMessageContext(
context.Background(),
channelID,
MsgOptionUpdate(timestamp),
MsgOptionCompose(options...),
)
return api.UpdateMessageContext(context.Background(), channelID, timestamp, options...)
}

// UpdateMessageContext updates a message in a channel
Expand Down Expand Up @@ -225,7 +211,7 @@ func (api *Client) SendMessageContext(ctx context.Context, channelID string, opt
response chatResponseFull
)

if req, parser, err = buildSender(api.endpoint, options...).BuildRequest(api.token, channelID); err != nil {
if req, parser, err = buildSender(api.endpoint, options...).BuildRequestContext(ctx, api.token, channelID); err != nil {
return "", "", "", err
}

Expand Down Expand Up @@ -306,6 +292,10 @@ type sendConfig struct {
}

func (t sendConfig) BuildRequest(token, channelID string) (req *http.Request, _ func(*chatResponseFull) responseParser, err error) {
return t.BuildRequestContext(context.Background(), token, channelID)
}

func (t sendConfig) BuildRequestContext(ctx context.Context, token, channelID string) (req *http.Request, _ func(*chatResponseFull) responseParser, err error) {
if t, err = applyMsgOptions(token, channelID, t.apiurl, t.options...); err != nil {
return nil, nil, err
}
Expand All @@ -320,9 +310,9 @@ func (t sendConfig) BuildRequest(token, channelID string) (req *http.Request, _
responseType: t.responseType,
replaceOriginal: t.replaceOriginal,
deleteOriginal: t.deleteOriginal,
}.BuildRequest()
}.BuildRequestContext(ctx)
default:
return formSender{endpoint: t.endpoint, values: t.values}.BuildRequest()
return formSender{endpoint: t.endpoint, values: t.values}.BuildRequestContext(ctx)
}
}

Expand All @@ -332,7 +322,11 @@ type formSender struct {
}

func (t formSender) BuildRequest() (*http.Request, func(*chatResponseFull) responseParser, error) {
req, err := formReq(t.endpoint, t.values)
return t.BuildRequestContext(context.Background())
}

func (t formSender) BuildRequestContext(ctx context.Context) (*http.Request, func(*chatResponseFull) responseParser, error) {
req, err := formReq(ctx, t.endpoint, t.values)
return req, func(resp *chatResponseFull) responseParser {
return newJSONParser(resp)
}, err
Expand All @@ -349,7 +343,11 @@ type responseURLSender struct {
}

func (t responseURLSender) BuildRequest() (*http.Request, func(*chatResponseFull) responseParser, error) {
req, err := jsonReq(t.endpoint, Msg{
return t.BuildRequestContext(context.Background())
}

func (t responseURLSender) BuildRequestContext(ctx context.Context) (*http.Request, func(*chatResponseFull) responseParser, error) {
req, err := jsonReq(ctx, t.endpoint, Msg{
Text: t.values.Get("text"),
Timestamp: t.values.Get("ts"),
Attachments: t.attachments,
Expand Down
79 changes: 44 additions & 35 deletions files.go
Expand Up @@ -202,48 +202,21 @@ func (api *Client) GetFileInfoContext(ctx context.Context, fileID string, count,

// GetFile retreives a given file from its private download URL
func (api *Client) GetFile(downloadURL string, writer io.Writer) error {
return downloadFile(api.httpclient, api.token, downloadURL, writer, api)
return api.GetFileContext(context.Background(), downloadURL, writer)
}

// GetFileContext retreives a given file from its private download URL with a custom context
//
// For more details, see GetFile documentation.
func (api *Client) GetFileContext(ctx context.Context, downloadURL string, writer io.Writer) error {
return downloadFile(ctx, api.httpclient, api.token, downloadURL, writer, api)
}

// GetFiles retrieves all files according to the parameters given
func (api *Client) GetFiles(params GetFilesParameters) ([]File, *Paging, error) {
return api.GetFilesContext(context.Background(), params)
}

// ListFiles retrieves all files according to the parameters given. Uses cursor based pagination.
func (api *Client) ListFiles(params ListFilesParameters) ([]File, *ListFilesParameters, error) {
return api.ListFilesContext(context.Background(), params)
}

// ListFilesContext retrieves all files according to the parameters given with a custom context. Uses cursor based pagination.
func (api *Client) ListFilesContext(ctx context.Context, params ListFilesParameters) ([]File, *ListFilesParameters, error) {
values := url.Values{
"token": {api.token},
}

if params.User != DEFAULT_FILES_USER {
values.Add("user", params.User)
}
if params.Channel != DEFAULT_FILES_CHANNEL {
values.Add("channel", params.Channel)
}
if params.Limit != DEFAULT_FILES_COUNT {
values.Add("limit", strconv.Itoa(params.Limit))
}
if params.Cursor != "" {
values.Add("cursor", params.Cursor)
}

response, err := api.fileRequest(ctx, "files.list", values)
if err != nil {
return nil, nil, err
}

params.Cursor = response.Metadata.Cursor

return response.Files, &params, nil
}

// GetFilesContext retrieves all files according to the parameters given with a custom context
func (api *Client) GetFilesContext(ctx context.Context, params GetFilesParameters) ([]File, *Paging, error) {
values := url.Values{
Expand Down Expand Up @@ -281,6 +254,42 @@ func (api *Client) GetFilesContext(ctx context.Context, params GetFilesParameter
return response.Files, &response.Paging, nil
}

// ListFiles retrieves all files according to the parameters given. Uses cursor based pagination.
func (api *Client) ListFiles(params ListFilesParameters) ([]File, *ListFilesParameters, error) {
return api.ListFilesContext(context.Background(), params)
}

// ListFilesContext retrieves all files according to the parameters given with a custom context.
//
// For more details, see ListFiles documentation.
func (api *Client) ListFilesContext(ctx context.Context, params ListFilesParameters) ([]File, *ListFilesParameters, error) {
values := url.Values{
"token": {api.token},
}

if params.User != DEFAULT_FILES_USER {
values.Add("user", params.User)
}
if params.Channel != DEFAULT_FILES_CHANNEL {
values.Add("channel", params.Channel)
}
if params.Limit != DEFAULT_FILES_COUNT {
values.Add("limit", strconv.Itoa(params.Limit))
}
if params.Cursor != "" {
values.Add("cursor", params.Cursor)
}

response, err := api.fileRequest(ctx, "files.list", values)
if err != nil {
return nil, nil, err
}

params.Cursor = response.Metadata.Cursor

return response.Files, &params, nil
}

// UploadFile uploads a file
func (api *Client) UploadFile(params FileUploadParameters) (file *File, err error) {
return api.UploadFileContext(context.Background(), params)
Expand Down
6 changes: 5 additions & 1 deletion info.go
Expand Up @@ -321,9 +321,13 @@ type UserPrefs struct {
}

func (api *Client) GetUserPrefs() (*UserPrefsCarrier, error) {
return api.GetUserPrefsContext(context.Background())
}

func (api *Client) GetUserPrefsContext(ctx context.Context) (*UserPrefsCarrier, error) {
response := UserPrefsCarrier{}

err := api.getMethod(context.Background(), "users.prefs.get", api.token, url.Values{}, &response)
err := api.getMethod(ctx, "users.prefs.get", api.token, url.Values{}, &response)
if err != nil {
return nil, err
}
Expand Down
26 changes: 11 additions & 15 deletions misc.go
Expand Up @@ -66,29 +66,27 @@ func (e *RateLimitedError) Retryable() bool {
}

func fileUploadReq(ctx context.Context, path string, values url.Values, r io.Reader) (*http.Request, error) {
req, err := http.NewRequest("POST", path, r)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, path, r)
if err != nil {
return nil, err
}

req = req.WithContext(ctx)
req.URL.RawQuery = (values).Encode()
req.URL.RawQuery = values.Encode()
return req, nil
}

func downloadFile(client httpClient, token string, downloadURL string, writer io.Writer, d Debug) error {
func downloadFile(ctx context.Context, client httpClient, token string, downloadURL string, writer io.Writer, d Debug) error {
if downloadURL == "" {
return fmt.Errorf("received empty download URL")
}

req, err := http.NewRequest("GET", downloadURL, &bytes.Buffer{})
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, &bytes.Buffer{})
if err != nil {
return err
}

var bearer = "Bearer " + token
req.Header.Add("Authorization", bearer)
req.WithContext(context.Background())

resp, err := client.Do(req)
if err != nil {
Expand All @@ -107,22 +105,22 @@ func downloadFile(client httpClient, token string, downloadURL string, writer io
return err
}

func formReq(endpoint string, values url.Values) (req *http.Request, err error) {
if req, err = http.NewRequest("POST", endpoint, strings.NewReader(values.Encode())); err != nil {
func formReq(ctx context.Context, endpoint string, values url.Values) (req *http.Request, err error) {
if req, err = http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(values.Encode())); err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req, nil
}

func jsonReq(endpoint string, body interface{}) (req *http.Request, err error) {
func jsonReq(ctx context.Context, endpoint string, body interface{}) (req *http.Request, err error) {
buffer := bytes.NewBuffer([]byte{})
if err = json.NewEncoder(buffer).Encode(body); err != nil {
return nil, err
}

if req, err = http.NewRequest("POST", endpoint, buffer); err != nil {
if req, err = http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer); err != nil {
return nil, err
}

Expand Down Expand Up @@ -184,7 +182,6 @@ func postWithMultipartResponse(ctx context.Context, client httpClient, path, nam
}
req.Header.Add("Content-Type", wr.FormDataContentType())
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req = req.WithContext(ctx)
resp, err := client.Do(req)

if err != nil {
Expand All @@ -206,7 +203,6 @@ func postWithMultipartResponse(ctx context.Context, client httpClient, path, nam
}

func doPost(ctx context.Context, client httpClient, req *http.Request, parser responseParser, d Debug) error {
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
return err
Expand All @@ -224,7 +220,7 @@ func doPost(ctx context.Context, client httpClient, req *http.Request, parser re
// post JSON.
func postJSON(ctx context.Context, client httpClient, endpoint, token string, json []byte, intf interface{}, d Debug) error {
reqBody := bytes.NewBuffer(json)
req, err := http.NewRequest("POST", endpoint, reqBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, reqBody)
if err != nil {
return err
}
Expand All @@ -237,7 +233,7 @@ func postJSON(ctx context.Context, client httpClient, endpoint, token string, js
// post a url encoded form.
func postForm(ctx context.Context, client httpClient, endpoint string, values url.Values, intf interface{}, d Debug) error {
reqBody := strings.NewReader(values.Encode())
req, err := http.NewRequest("POST", endpoint, reqBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, reqBody)
if err != nil {
return err
}
Expand All @@ -246,7 +242,7 @@ func postForm(ctx context.Context, client httpClient, endpoint string, values ur
}

func getResource(ctx context.Context, client httpClient, endpoint, token string, values url.Values, intf interface{}, d Debug) error {
req, err := http.NewRequest("GET", endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return err
}
Expand Down