-
Notifications
You must be signed in to change notification settings - Fork 560
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add reddit provider and session (#523)
* add reddit provider and session * fix: wechat failing test
- Loading branch information
Showing
5 changed files
with
393 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
package reddit | ||
|
||
import ( | ||
"encoding/json" | ||
"fmt" | ||
"github.com/markbates/goth" | ||
"golang.org/x/oauth2" | ||
"io" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
const ( | ||
authURL = "https://www.reddit.com/api/v1/authorize" | ||
) | ||
|
||
type Provider struct { | ||
providerName string | ||
duration string | ||
config oauth2.Config | ||
client http.Client | ||
// TODO: userURL should be a constant | ||
userURL string | ||
} | ||
|
||
func New(clientID string, clientSecret string, redirectURI string, duration string, tokenEndpoint string, userURL string, scopes ...string) Provider { | ||
return Provider{ | ||
providerName: "reddit", | ||
duration: duration, | ||
config: oauth2.Config{ | ||
ClientID: clientID, | ||
ClientSecret: clientSecret, | ||
Endpoint: oauth2.Endpoint{ | ||
AuthURL: authURL, | ||
TokenURL: tokenEndpoint, | ||
AuthStyle: 0, | ||
}, | ||
RedirectURL: redirectURI, | ||
Scopes: scopes, | ||
}, | ||
client: http.Client{}, | ||
userURL: userURL, | ||
} | ||
} | ||
|
||
func (p *Provider) Name() string { | ||
return p.providerName | ||
} | ||
|
||
func (p *Provider) SetName(name string) { | ||
p.providerName = name | ||
} | ||
|
||
func (p *Provider) UnmarshalSession(s string) (goth.Session, error) { | ||
session := &Session{} | ||
err := json.Unmarshal([]byte(s), session) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return session, nil | ||
} | ||
|
||
func (p *Provider) Debug(b bool) {} | ||
|
||
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { | ||
return nil, nil | ||
} | ||
|
||
func (p *Provider) RefreshTokenAvailable() bool { | ||
return true | ||
} | ||
|
||
func (p *Provider) BeginAuth(state string) (goth.Session, error) { | ||
authCodeOption := oauth2.SetAuthURLParam("duration", p.duration) | ||
return &Session{AuthURL: p.config.AuthCodeURL(state, authCodeOption)}, nil | ||
} | ||
|
||
type redditResponse struct { | ||
Id string `json:"id"` | ||
Name string `json:"name"` | ||
} | ||
|
||
func (p *Provider) FetchUser(s goth.Session) (goth.User, error) { | ||
session := s.(*Session) | ||
request, err := http.NewRequest("GET", p.userURL, nil) | ||
if err != nil { | ||
return goth.User{}, err | ||
} | ||
|
||
bearer := "Bearer " + session.AccessToken | ||
request.Header.Add("Authorization", bearer) | ||
|
||
res, err := p.client.Do(request) | ||
if err != nil { | ||
return goth.User{}, err | ||
} | ||
|
||
defer res.Body.Close() | ||
|
||
if res.StatusCode != http.StatusOK { | ||
if res.StatusCode == http.StatusForbidden { | ||
return goth.User{}, fmt.Errorf("%s responded with a %s because you did not provide the identity scope which is required to fetch user profile", p.providerName, res.Status) | ||
} | ||
return goth.User{}, fmt.Errorf("%s responded with a %d trying to fetch user profile", p.providerName, res.StatusCode) | ||
} | ||
|
||
bits, err := io.ReadAll(res.Body) | ||
if err != nil { | ||
return goth.User{}, err | ||
} | ||
|
||
var r redditResponse | ||
|
||
err = json.Unmarshal(bits, &r) | ||
if err != nil { | ||
return goth.User{}, err | ||
} | ||
|
||
gothUser := goth.User{ | ||
RawData: nil, | ||
Provider: p.Name(), | ||
Name: r.Name, | ||
UserID: r.Id, | ||
AccessToken: session.AccessToken, | ||
RefreshToken: session.RefreshToken, | ||
ExpiresAt: time.Time{}, | ||
} | ||
|
||
err = json.Unmarshal(bits, &gothUser.RawData) | ||
if err != nil { | ||
return goth.User{}, err | ||
} | ||
|
||
return gothUser, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package reddit | ||
|
||
import ( | ||
"encoding/json" | ||
"github.com/markbates/goth" | ||
"golang.org/x/oauth2" | ||
"net/http" | ||
"net/http/httptest" | ||
"reflect" | ||
"testing" | ||
"time" | ||
) | ||
|
||
var response = redditResponse{ | ||
Id: "invader21", | ||
Name: "JohnDoe", | ||
} | ||
|
||
func TestProvider(t *testing.T) { | ||
t.Run("create a new provider", func(t *testing.T) { | ||
got := New("client id", "client secret", "redirect uri", "duration", "example.com", "userURL", "scope1", "scope2", "scope 3") | ||
want := Provider{ | ||
providerName: "reddit", | ||
duration: "duration", | ||
config: oauth2.Config{ | ||
ClientID: "client id", | ||
ClientSecret: "client secret", | ||
Endpoint: oauth2.Endpoint{ | ||
AuthURL: authURL, | ||
TokenURL: "example.com", | ||
AuthStyle: 0, | ||
}, | ||
RedirectURL: "redirect uri", | ||
Scopes: []string{"scope1", "scope2", "scope 3"}, | ||
}, | ||
userURL: "userURL", | ||
} | ||
|
||
if !reflect.DeepEqual(got, want) { | ||
t.Errorf("\033[31;1;4mgot\033[0m %+v, \n\t \033[31;1;4mwant\033[0m %+v", got, want) | ||
} | ||
}) | ||
|
||
t.Run("fetch reddit user that created the given session", func(t *testing.T) { | ||
redditServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { | ||
b, err := json.Marshal(response) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
writer.Header().Add("Content-Type", "application/json") | ||
writer.Write(b) | ||
})) | ||
|
||
defer redditServer.Close() | ||
|
||
userURL := redditServer.URL | ||
p := New("client id", "client secret", "redirect uri", "duration", "example.com", userURL, "scope1", "scope2", "scope 3") | ||
s := &Session{ | ||
AuthURL: "", | ||
AccessToken: "i am a token", | ||
TokenType: "bearer", | ||
RefreshToken: "your refresh token", | ||
Expiry: time.Time{}, | ||
} | ||
|
||
got, err := p.FetchUser(s) | ||
if err != nil { | ||
t.Errorf("did not expect an error: %s", err) | ||
} | ||
|
||
want := goth.User{ | ||
RawData: map[string]interface{}{ | ||
"id": "invader21", | ||
"name": "JohnDoe", | ||
}, | ||
Provider: "reddit", | ||
Name: "JohnDoe", | ||
UserID: "invader21", | ||
AccessToken: "i am a token", | ||
RefreshToken: "your refresh token", | ||
ExpiresAt: time.Time{}, | ||
} | ||
|
||
if !reflect.DeepEqual(got, want) { | ||
t.Errorf("\033[31;1;4mgot\033[0m %+v, \n\t\t \033[31;1;4mwant\033[0m %+v", got, want) | ||
} | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package reddit | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"github.com/markbates/goth" | ||
"golang.org/x/oauth2" | ||
"time" | ||
) | ||
|
||
type Session struct { | ||
AuthURL string | ||
AccessToken string `json:"access_token"` | ||
TokenType string `json:"token_type,omitempty"` | ||
RefreshToken string `json:"refresh_token,omitempty"` | ||
Expiry time.Time `json:"expiry,omitempty"` | ||
} | ||
|
||
func (s *Session) GetAuthURL() (string, error) { | ||
return s.AuthURL, nil | ||
} | ||
|
||
func (s *Session) Marshal() string { | ||
b, _ := json.Marshal(s) | ||
return string(b) | ||
} | ||
|
||
func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { | ||
p := provider.(*Provider) | ||
t, err := p.config.Exchange(context.WithValue(context.Background(), oauth2.HTTPClient, p.client), params.Get("code")) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
if !t.Valid() { | ||
return "", errors.New("invalid token received from provider") | ||
} | ||
|
||
s.AccessToken = t.AccessToken | ||
s.TokenType = t.TokenType | ||
s.RefreshToken = t.RefreshToken | ||
s.Expiry = t.Expiry | ||
|
||
return s.AccessToken, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package reddit | ||
|
||
import ( | ||
"encoding/json" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"testing" | ||
) | ||
|
||
var validAuthResponseTestData = struct { | ||
AccessToken string `json:"access_token"` | ||
TokenType string `json:"token_type"` | ||
ExpiresIn int `json:"expires_in"` | ||
Scope string `json:"scope"` | ||
RefreshToken string `json:"refresh_token"` | ||
}{ | ||
AccessToken: "i am a token", | ||
TokenType: "type", | ||
ExpiresIn: 120, | ||
Scope: "identity", | ||
RefreshToken: "your refresh token", | ||
} | ||
|
||
var invalidAuthResponseTestData = struct { | ||
AccessToken string `json:"access_token"` | ||
TokenType string `json:"token_type"` | ||
ExpiresIn int `json:"expires_in"` | ||
Scope string `json:"scope"` | ||
RefreshToken string `json:"refresh_token"` | ||
}{ | ||
AccessToken: "", | ||
TokenType: "type", | ||
ExpiresIn: 120, | ||
Scope: "identity", | ||
RefreshToken: "Your refresh token", | ||
} | ||
|
||
func TestSession(t *testing.T) { | ||
t.Run("gets the URL for the authentication end-point for the provider", func(t *testing.T) { | ||
s := Session{AuthURL: "example.com"} | ||
got, err := s.GetAuthURL() | ||
if err != nil { | ||
t.Fatal("should return a url string") | ||
} | ||
|
||
want := "example.com" | ||
|
||
if got != want { | ||
t.Errorf("got %q want %q", got, want) | ||
} | ||
}) | ||
|
||
t.Run("generates a string representation of the session", func(t *testing.T) { | ||
s := Session{ | ||
AuthURL: "example", | ||
} | ||
got := s.Marshal() | ||
want := `{"AuthURL":"example","access_token":"","expiry":"0001-01-01T00:00:00Z"}` | ||
|
||
if got != want { | ||
t.Errorf("got %q want %q", got, want) | ||
} | ||
}) | ||
|
||
t.Run("return an access token", func(t *testing.T) { | ||
|
||
s := Session{AuthURL: "example.com"} | ||
authServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { | ||
b, err := json.Marshal(validAuthResponseTestData) | ||
if err != nil { | ||
writer.WriteHeader(http.StatusInternalServerError) | ||
return | ||
} | ||
writer.Header().Add("Content-Type", "application/json") | ||
writer.WriteHeader(http.StatusOK) | ||
writer.Write(b) | ||
})) | ||
|
||
tokenURL := authServer.URL | ||
|
||
p := New("CLIENT_ID", "CLIENT_SECRET", "URI", "DURATION", tokenURL, "SCOPE_STRING1", "SCOPE_STRING2") | ||
u := url.Values{} | ||
u.Set("code", "12345678") | ||
|
||
got, err := s.Authorize(&p, u) | ||
if err != nil { | ||
t.Fatal("did not expect an error: ", err) | ||
} | ||
|
||
want := validAuthResponseTestData.AccessToken | ||
|
||
if got != want { | ||
t.Errorf("got %q want %q", got, want) | ||
} | ||
}) | ||
|
||
t.Run("validates access token", func(t *testing.T) { | ||
s := Session{AuthURL: "example.com"} | ||
authServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { | ||
b, err := json.Marshal(invalidAuthResponseTestData) | ||
if err != nil { | ||
writer.WriteHeader(http.StatusInternalServerError) | ||
return | ||
} | ||
writer.Header().Add("Content-Type", "application/json") | ||
writer.WriteHeader(http.StatusOK) | ||
writer.Write(b) | ||
})) | ||
|
||
tokenURL := authServer.URL | ||
|
||
p := New("CLIENT_ID", "CLIENT_SECRET", "URI", "DURATION", tokenURL, "SCOPE_STRING1", "SCOPE_STRING2") | ||
u := url.Values{} | ||
u.Set("code", "12345678") | ||
|
||
_, err := s.Authorize(&p, u) | ||
if err == nil { | ||
t.Errorf("expected an error but didn't get one") | ||
} | ||
}) | ||
} |
Oops, something went wrong.