Skip to content

Commit

Permalink
all: properly implement regenerate session ID
Browse files Browse the repository at this point in the history
  • Loading branch information
unknwon committed May 2, 2024
1 parent 15e62f0 commit 3afea68
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 54 deletions.
25 changes: 17 additions & 8 deletions file.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@ type fileStore struct {
nowFunc func() time.Time // The function to return the current time
lifetime time.Duration // The duration to have no access to a session before being recycled
rootDir string // The root directory of file session items stored on the local file system
encoder Encoder // The encoder to encode the session data before saving
decoder Decoder // The decoder to decode binary to session data after reading

encoder Encoder
decoder Decoder
idWriter IDWriter
}

// newFileStore returns a new file session store based on given configuration.
func newFileStore(cfg FileConfig) *fileStore {
func newFileStore(cfg FileConfig, idWriter IDWriter) *fileStore {
return &fileStore{
nowFunc: cfg.nowFunc,
lifetime: cfg.Lifetime,
rootDir: cfg.RootDir,
encoder: cfg.Encoder,
decoder: cfg.Decoder,
idWriter: idWriter,
}
}

Expand Down Expand Up @@ -70,7 +73,7 @@ func (s *fileStore) Read(_ context.Context, sid string) (Session, error) {
return nil, errors.Wrap(err, "create parent directory")
}

return NewBaseSession(sid, s.encoder), nil
return NewBaseSession(sid, s.encoder, s.idWriter), nil
}

// Discard existing data if it's expired
Expand All @@ -79,7 +82,7 @@ func (s *fileStore) Read(_ context.Context, sid string) (Session, error) {
return nil, errors.Wrap(err, "stat file")
}
if !fi.ModTime().Add(s.lifetime).After(s.nowFunc()) {
return NewBaseSession(sid, s.encoder), nil
return NewBaseSession(sid, s.encoder, s.idWriter), nil
}

binary, err := os.ReadFile(filename)
Expand All @@ -91,7 +94,7 @@ func (s *fileStore) Read(_ context.Context, sid string) (Session, error) {
if err != nil {
return nil, errors.Wrap(err, "decode")
}
return NewBaseSessionWithData(sid, s.encoder, data), nil
return NewBaseSessionWithData(sid, s.encoder, s.idWriter, data), nil
}

func (s *fileStore) Destroy(_ context.Context, sid string) error {
Expand Down Expand Up @@ -169,7 +172,7 @@ func (s *fileStore) GC(ctx context.Context) error {

// FileConfig contains options for the file session store.
type FileConfig struct {
// For tests only
// For tests only.
nowFunc func() time.Time

// Lifetime is the duration to have no access to a session before being
Expand All @@ -188,12 +191,18 @@ type FileConfig struct {
func FileIniter() Initer {
return func(ctx context.Context, args ...interface{}) (Store, error) {
var cfg *FileConfig
var idWriter IDWriter
for i := range args {
switch v := args[i].(type) {
case FileConfig:
cfg = &v
case IDWriter:
idWriter = v
}
}
if idWriter == nil {
return nil, errors.New("IDWriter not given")
}

if cfg == nil {
return nil, fmt.Errorf("config object with the type '%T' not found", FileConfig{})
Expand All @@ -214,6 +223,6 @@ func FileIniter() Initer {
cfg.Decoder = GobDecoder
}

return newFileStore(*cfg), nil
return newFileStore(*cfg, idWriter), nil
}
}
2 changes: 2 additions & 0 deletions file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func TestFileStore_GC(t *testing.T) {
RootDir: filepath.Join(os.TempDir(), "sessions"),
Lifetime: time.Second,
},
IDWriter(func(http.ResponseWriter, *http.Request, string) {}),
)
require.Nil(t, err)

Expand Down Expand Up @@ -137,6 +138,7 @@ func TestFileStore_Touch(t *testing.T) {
RootDir: filepath.Join(os.TempDir(), "sessions"),
Lifetime: time.Second,
},
IDWriter(func(http.ResponseWriter, *http.Request, string) {}),
)
require.Nil(t, err)

Expand Down
2 changes: 1 addition & 1 deletion manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestIsValidSessionID(t *testing.T) {
}

func TestManager_startGC(t *testing.T) {
m := newManager(newMemoryStore(MemoryConfig{}))
m := newManager(newMemoryStore(MemoryConfig{}, nil))
stop := m.startGC(
context.Background(),
time.Minute,
Expand Down
21 changes: 16 additions & 5 deletions memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"context"
"sync"
"time"

"github.com/pkg/errors"
)

var _ Session = (*memorySession)(nil)
Expand All @@ -24,9 +26,9 @@ type memorySession struct {
}

// newMemorySession returns a new memory session with given session ID.
func newMemorySession(sid string) *memorySession {
func newMemorySession(sid string, idWriter IDWriter) *memorySession {
return &memorySession{
BaseSession: NewBaseSession(sid, nil),
BaseSession: NewBaseSession(sid, nil, idWriter),
}
}

Expand All @@ -52,15 +54,18 @@ type memoryStore struct {
lock sync.RWMutex // The mutex to guard accesses to the heap and index
heap []*memorySession // The heap to be managed by operations of heap.Interface
index map[string]*memorySession // The index to be managed by operations of heap.Interface

idWriter IDWriter
}

// newMemoryStore returns a new memory session store based on given
// configuration.
func newMemoryStore(cfg MemoryConfig) *memoryStore {
func newMemoryStore(cfg MemoryConfig, idWriter IDWriter) *memoryStore {
return &memoryStore{
nowFunc: cfg.nowFunc,
lifetime: cfg.Lifetime,
index: make(map[string]*memorySession),
idWriter: idWriter,
}
}

Expand Down Expand Up @@ -136,7 +141,7 @@ func (s *memoryStore) Read(_ context.Context, sid string) (Session, error) {
return sess, nil
}

sess = newMemorySession(sid)
sess = newMemorySession(sid, s.idWriter)
sess.SetLastAccessedAt(s.nowFunc())
heap.Push(s, sess)
return sess, nil
Expand Down Expand Up @@ -219,12 +224,18 @@ type MemoryConfig struct {
func MemoryIniter() Initer {
return func(_ context.Context, args ...interface{}) (Store, error) {
var cfg *MemoryConfig
var idWriter IDWriter
for i := range args {
switch v := args[i].(type) {
case MemoryConfig:
cfg = &v
case IDWriter:
idWriter = v
}
}
if idWriter == nil {
return nil, errors.New("IDWriter not given")
}

if cfg == nil {
cfg = &MemoryConfig{}
Expand All @@ -237,6 +248,6 @@ func MemoryIniter() Initer {
cfg.Lifetime = 3600 * time.Second
}

return newMemoryStore(*cfg), nil
return newMemoryStore(*cfg, idWriter), nil
}
}
2 changes: 2 additions & 0 deletions memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func TestMemoryStore_GC(t *testing.T) {
nowFunc: func() time.Time { return now },
Lifetime: time.Second,
},
nil,
)

sess1, err := store.Read(ctx, "1")
Expand Down Expand Up @@ -125,6 +126,7 @@ func TestMemoryStore_Touch(t *testing.T) {
nowFunc: func() time.Time { return now },
Lifetime: time.Second,
},
nil,
)

sess, err := store.Read(ctx, "1")
Expand Down
23 changes: 16 additions & 7 deletions mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@ type mongoStore struct {
lifetime time.Duration // The duration to have access to a session before being recycled
db *mongo.Database // The database connection
collection string // The database collection for storing session data
encoder session.Encoder // The encoder to encode the session data before saving
decoder session.Decoder // The decoder to decode binary to session data after reading

encoder session.Encoder
decoder session.Decoder
idWriter session.IDWriter
}

// newMongoStore returns a new MongoDB session store based on given configuration.
func newMongoStore(cfg Config) *mongoStore {
func newMongoStore(cfg Config, idWriter session.IDWriter) *mongoStore {
return &mongoStore{
nowFunc: cfg.nowFunc,
lifetime: cfg.Lifetime,
db: cfg.db,
collection: cfg.Collection,
encoder: cfg.Encoder,
decoder: cfg.Decoder,
idWriter: idWriter,
}
}

Expand All @@ -63,19 +66,19 @@ func (s *mongoStore) Read(ctx context.Context, sid string) (session.Session, err

// Discard existing data if it's expired
if !s.nowFunc().Before(expiredAt.Time().Add(s.lifetime)) {
return session.NewBaseSession(sid, s.encoder), nil
return session.NewBaseSession(sid, s.encoder, s.idWriter), nil
}

data, err := s.decoder(binary.Data)
if err != nil {
return nil, errors.Wrap(err, "decode")
}
return session.NewBaseSessionWithData(sid, s.encoder, data), nil
return session.NewBaseSessionWithData(sid, s.encoder, s.idWriter, data), nil
} else if err != mongo.ErrNoDocuments {
return nil, errors.Wrap(err, "find")
}

return session.NewBaseSession(sid, s.encoder), nil
return session.NewBaseSession(sid, s.encoder, s.idWriter), nil
}

func (s *mongoStore) Destroy(ctx context.Context, sid string) error {
Expand Down Expand Up @@ -157,12 +160,18 @@ type Config struct {
func Initer() session.Initer {
return func(ctx context.Context, args ...interface{}) (session.Store, error) {
var cfg *Config
var idWriter session.IDWriter
for i := range args {
switch v := args[i].(type) {
case Config:
cfg = &v
case session.IDWriter:
idWriter = v
}
}
if idWriter == nil {
return nil, errors.New("IDWriter not given")
}

if cfg == nil {
return nil, fmt.Errorf("config object with the type '%T' not found", Config{})
Expand Down Expand Up @@ -194,6 +203,6 @@ func Initer() session.Initer {
cfg.Decoder = session.GobDecoder
}

return newMongoStore(*cfg), nil
return newMongoStore(*cfg, idWriter), nil
}
}
29 changes: 19 additions & 10 deletions redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,24 @@ var _ session.Store = (*redisStore)(nil)

// redisStore is a Redis implementation of the session store.
type redisStore struct {
client *redis.Client // The client connection
keyPrefix string // The prefix to use for keys
lifetime time.Duration // The duration to have access to a session before being recycled
encoder session.Encoder // The encoder to encode the session data before saving
decoder session.Decoder // The decoder to decode binary to session data after reading
client *redis.Client // The client connection
keyPrefix string // The prefix to use for keys
lifetime time.Duration // The duration to have access to a session before being recycled

encoder session.Encoder
decoder session.Decoder
idWriter session.IDWriter
}

// newRedisStore returns a new Redis session store based on given configuration.
func newRedisStore(cfg Config) *redisStore {
func newRedisStore(cfg Config, idWriter session.IDWriter) *redisStore {
return &redisStore{
client: cfg.Client,
keyPrefix: cfg.KeyPrefix,
lifetime: cfg.Lifetime,
encoder: cfg.Encoder,
decoder: cfg.Decoder,
idWriter: idWriter,
}
}

Expand All @@ -45,8 +48,8 @@ func (s *redisStore) Exist(ctx context.Context, sid string) bool {
func (s *redisStore) Read(ctx context.Context, sid string) (session.Session, error) {
binary, err := s.client.Get(ctx, s.keyPrefix+sid).Result()
if err != nil {
if err == redis.Nil {
return session.NewBaseSession(sid, s.encoder), nil
if errors.Is(err, redis.Nil) {
return session.NewBaseSession(sid, s.encoder, s.idWriter), nil
}
return nil, errors.Wrap(err, "get")
}
Expand All @@ -55,7 +58,7 @@ func (s *redisStore) Read(ctx context.Context, sid string) (session.Session, err
if err != nil {
return nil, errors.Wrap(err, "decode")
}
return session.NewBaseSessionWithData(sid, s.encoder, data), nil
return session.NewBaseSessionWithData(sid, s.encoder, s.idWriter, data), nil
}

func (s *redisStore) Destroy(ctx context.Context, sid string) error {
Expand Down Expand Up @@ -112,12 +115,18 @@ type Config struct {
func Initer() session.Initer {
return func(ctx context.Context, args ...interface{}) (session.Store, error) {
var cfg *Config
var idWriter session.IDWriter
for i := range args {
switch v := args[i].(type) {
case Config:
cfg = &v
case session.IDWriter:
idWriter = v
}
}
if idWriter == nil {
return nil, errors.New("IDWriter not given")
}

if cfg == nil {
return nil, fmt.Errorf("config object with the type '%T' not found", Config{})
Expand All @@ -141,6 +150,6 @@ func Initer() session.Initer {
cfg.Decoder = session.GobDecoder
}

return newRedisStore(*cfg), nil
return newRedisStore(*cfg, idWriter), nil
}
}

0 comments on commit 3afea68

Please sign in to comment.