Skip to content

Commit

Permalink
Add reddit provider and session (#523)
Browse files Browse the repository at this point in the history
* add reddit provider and session

* fix: wechat failing test
  • Loading branch information
ccaneke committed Oct 3, 2023
1 parent 08df1f0 commit ef6c303
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 1 deletion.
136 changes: 136 additions & 0 deletions providers/reddit/reddit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package reddit

import (
"encoding/json"
"fmt"
"github.com/markbates/goth"
"golang.org/x/oauth2"
"io"
"net/http"
"time"
)

const (
authURL = "https://www.reddit.com/api/v1/authorize"
)

type Provider struct {
providerName string
duration string
config oauth2.Config
client http.Client
// TODO: userURL should be a constant
userURL string
}

func New(clientID string, clientSecret string, redirectURI string, duration string, tokenEndpoint string, userURL string, scopes ...string) Provider {
return Provider{
providerName: "reddit",
duration: duration,
config: oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenEndpoint,
AuthStyle: 0,
},
RedirectURL: redirectURI,
Scopes: scopes,
},
client: http.Client{},
userURL: userURL,
}
}

func (p *Provider) Name() string {
return p.providerName
}

func (p *Provider) SetName(name string) {
p.providerName = name
}

func (p *Provider) UnmarshalSession(s string) (goth.Session, error) {
session := &Session{}
err := json.Unmarshal([]byte(s), session)
if err != nil {
return nil, err
}

return session, nil
}

func (p *Provider) Debug(b bool) {}

func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
return nil, nil
}

func (p *Provider) RefreshTokenAvailable() bool {
return true
}

func (p *Provider) BeginAuth(state string) (goth.Session, error) {
authCodeOption := oauth2.SetAuthURLParam("duration", p.duration)
return &Session{AuthURL: p.config.AuthCodeURL(state, authCodeOption)}, nil
}

type redditResponse struct {
Id string `json:"id"`
Name string `json:"name"`
}

func (p *Provider) FetchUser(s goth.Session) (goth.User, error) {
session := s.(*Session)
request, err := http.NewRequest("GET", p.userURL, nil)
if err != nil {
return goth.User{}, err
}

bearer := "Bearer " + session.AccessToken
request.Header.Add("Authorization", bearer)

res, err := p.client.Do(request)
if err != nil {
return goth.User{}, err
}

defer res.Body.Close()

if res.StatusCode != http.StatusOK {
if res.StatusCode == http.StatusForbidden {
return goth.User{}, fmt.Errorf("%s responded with a %s because you did not provide the identity scope which is required to fetch user profile", p.providerName, res.Status)
}
return goth.User{}, fmt.Errorf("%s responded with a %d trying to fetch user profile", p.providerName, res.StatusCode)
}

bits, err := io.ReadAll(res.Body)
if err != nil {
return goth.User{}, err
}

var r redditResponse

err = json.Unmarshal(bits, &r)
if err != nil {
return goth.User{}, err
}

gothUser := goth.User{
RawData: nil,
Provider: p.Name(),
Name: r.Name,
UserID: r.Id,
AccessToken: session.AccessToken,
RefreshToken: session.RefreshToken,
ExpiresAt: time.Time{},
}

err = json.Unmarshal(bits, &gothUser.RawData)
if err != nil {
return goth.User{}, err
}

return gothUser, nil
}
88 changes: 88 additions & 0 deletions providers/reddit/reddit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package reddit

import (
"encoding/json"
"github.com/markbates/goth"
"golang.org/x/oauth2"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
)

var response = redditResponse{
Id: "invader21",
Name: "JohnDoe",
}

func TestProvider(t *testing.T) {
t.Run("create a new provider", func(t *testing.T) {
got := New("client id", "client secret", "redirect uri", "duration", "example.com", "userURL", "scope1", "scope2", "scope 3")
want := Provider{
providerName: "reddit",
duration: "duration",
config: oauth2.Config{
ClientID: "client id",
ClientSecret: "client secret",
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: "example.com",
AuthStyle: 0,
},
RedirectURL: "redirect uri",
Scopes: []string{"scope1", "scope2", "scope 3"},
},
userURL: "userURL",
}

if !reflect.DeepEqual(got, want) {
t.Errorf("\033[31;1;4mgot\033[0m %+v, \n\t \033[31;1;4mwant\033[0m %+v", got, want)
}
})

t.Run("fetch reddit user that created the given session", func(t *testing.T) {
redditServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
b, err := json.Marshal(response)
if err != nil {
t.Fatal(err)
}
writer.Header().Add("Content-Type", "application/json")
writer.Write(b)
}))

defer redditServer.Close()

userURL := redditServer.URL
p := New("client id", "client secret", "redirect uri", "duration", "example.com", userURL, "scope1", "scope2", "scope 3")
s := &Session{
AuthURL: "",
AccessToken: "i am a token",
TokenType: "bearer",
RefreshToken: "your refresh token",
Expiry: time.Time{},
}

got, err := p.FetchUser(s)
if err != nil {
t.Errorf("did not expect an error: %s", err)
}

want := goth.User{
RawData: map[string]interface{}{
"id": "invader21",
"name": "JohnDoe",
},
Provider: "reddit",
Name: "JohnDoe",
UserID: "invader21",
AccessToken: "i am a token",
RefreshToken: "your refresh token",
ExpiresAt: time.Time{},
}

if !reflect.DeepEqual(got, want) {
t.Errorf("\033[31;1;4mgot\033[0m %+v, \n\t\t \033[31;1;4mwant\033[0m %+v", got, want)
}
})
}
46 changes: 46 additions & 0 deletions providers/reddit/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package reddit

import (
"context"
"encoding/json"
"errors"
"github.com/markbates/goth"
"golang.org/x/oauth2"
"time"
)

type Session struct {
AuthURL string
AccessToken string `json:"access_token"`
TokenType string `json:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
Expiry time.Time `json:"expiry,omitempty"`
}

func (s *Session) GetAuthURL() (string, error) {
return s.AuthURL, nil
}

func (s *Session) Marshal() string {
b, _ := json.Marshal(s)
return string(b)
}

func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) {
p := provider.(*Provider)
t, err := p.config.Exchange(context.WithValue(context.Background(), oauth2.HTTPClient, p.client), params.Get("code"))
if err != nil {
return "", err
}

if !t.Valid() {
return "", errors.New("invalid token received from provider")
}

s.AccessToken = t.AccessToken
s.TokenType = t.TokenType
s.RefreshToken = t.RefreshToken
s.Expiry = t.Expiry

return s.AccessToken, nil
}
122 changes: 122 additions & 0 deletions providers/reddit/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package reddit

import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)

var validAuthResponseTestData = struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}{
AccessToken: "i am a token",
TokenType: "type",
ExpiresIn: 120,
Scope: "identity",
RefreshToken: "your refresh token",
}

var invalidAuthResponseTestData = struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}{
AccessToken: "",
TokenType: "type",
ExpiresIn: 120,
Scope: "identity",
RefreshToken: "Your refresh token",
}

func TestSession(t *testing.T) {
t.Run("gets the URL for the authentication end-point for the provider", func(t *testing.T) {
s := Session{AuthURL: "example.com"}
got, err := s.GetAuthURL()
if err != nil {
t.Fatal("should return a url string")
}

want := "example.com"

if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("generates a string representation of the session", func(t *testing.T) {
s := Session{
AuthURL: "example",
}
got := s.Marshal()
want := `{"AuthURL":"example","access_token":"","expiry":"0001-01-01T00:00:00Z"}`

if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("return an access token", func(t *testing.T) {

s := Session{AuthURL: "example.com"}
authServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
b, err := json.Marshal(validAuthResponseTestData)
if err != nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
writer.Header().Add("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
writer.Write(b)
}))

tokenURL := authServer.URL

p := New("CLIENT_ID", "CLIENT_SECRET", "URI", "DURATION", tokenURL, "SCOPE_STRING1", "SCOPE_STRING2")
u := url.Values{}
u.Set("code", "12345678")

got, err := s.Authorize(&p, u)
if err != nil {
t.Fatal("did not expect an error: ", err)
}

want := validAuthResponseTestData.AccessToken

if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("validates access token", func(t *testing.T) {
s := Session{AuthURL: "example.com"}
authServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
b, err := json.Marshal(invalidAuthResponseTestData)
if err != nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
writer.Header().Add("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
writer.Write(b)
}))

tokenURL := authServer.URL

p := New("CLIENT_ID", "CLIENT_SECRET", "URI", "DURATION", tokenURL, "SCOPE_STRING1", "SCOPE_STRING2")
u := url.Values{}
u.Set("code", "12345678")

_, err := s.Authorize(&p, u)
if err == nil {
t.Errorf("expected an error but didn't get one")
}
})
}

0 comments on commit ef6c303

Please sign in to comment.