Skip to content

Commit

Permalink
Merge pull request #58 from scribd/laynax/SERF-498/upgrade-gorm
Browse files Browse the repository at this point in the history
[SERF-498] Upgrade gorm to v2
  • Loading branch information
laynax committed Feb 17, 2023
2 parents 85448f8 + 19d6969 commit cccd416
Show file tree
Hide file tree
Showing 16 changed files with 414 additions and 385 deletions.
15 changes: 9 additions & 6 deletions README.md
Expand Up @@ -744,7 +744,7 @@ func main() {
### ORM Integration
`go-sdk` comes with an integration with the popular
[gorm](https://github.com/jinzhu/gorm) as an object-relational mapper (ORM).
[gorm](https://gorm.io/gorm) as an object-relational mapper (ORM).
Using the configuration details, namely the [data source
name](https://en.wikipedia.org/wiki/Data_source_name) (DSN) as their product,
gorm is able to open a connection and give the `go-sdk` users a preconfigured
Expand Down Expand Up @@ -773,7 +773,7 @@ words the `NewConnection` function, so they remain opaque for the user.
#### Usage of ORM
Invoking the constructor for a database connection, `go-sdk` returns a
[Gorm-powered](https://github.com/jinzhu/gorm) database connection. It can be
[Gorm-powered](https://gorm.io/gorm) database connection. It can be
used right away to query the database:
```go
Expand All @@ -783,7 +783,7 @@ import (
"fmt"

sdkdb "github.com/scribd/go-sdk/pkg/database"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

type User struct {
Expand Down Expand Up @@ -1041,7 +1041,7 @@ if err != nil {
`DatabaseLogging` for both HTTP and gRPC servers.
The `Database` middleware which instruments the
[Gorm-powered](https://github.com/jinzhu/gorm) database connection. It utilizes
[Gorm-powered](https://gorm.io/gorm) database connection. It utilizes
Gorm-specific callbacks that report spans and traces to Datadog. The
instrumented Gorm database connection is injected in the request `Context` and
it is always scoped within the request.
Expand All @@ -1051,8 +1051,11 @@ The `DatabaseLogging` middleware checks for a logger injected in the request
which in turn uses the logger to produce database query logs. A nice
side-effect of this approach is that, if the logger is tagged with a
`request_id`, there's a logs correlation between the HTTP requests and the
database queries. Also, if the logger is tagged with `treace_id` we can easily
correlate logs with traces and see corresponding database queries.
database queries. Also, if the logger is tagged with `trace_id` we can easily
correlate logs with traces and see corresponding database queries. Keep in mind
that ORM logging happens at the same level as the base logger. The additional
database fields ('duration', 'affected_rows' and 'sql') are available
only when using the `trace` levels.
#### HTTP server middleware example
Expand Down
10 changes: 7 additions & 3 deletions go.mod
Expand Up @@ -3,16 +3,15 @@ module github.com/scribd/go-sdk
go 1.19

require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/DATA-DOG/go-txdb v0.1.3
github.com/DataDog/datadog-go v4.8.2+incompatible
github.com/aws/aws-sdk-go v1.34.28
github.com/getsentry/sentry-go v0.12.0
github.com/google/go-cmp v0.5.7
github.com/google/uuid v1.3.0
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0
github.com/jinzhu/gorm v1.9.11
github.com/magefile/mage v1.13.0
github.com/mattn/go-sqlite3 v1.14.14
github.com/rs/cors v1.7.0
github.com/sirupsen/logrus v1.7.0
github.com/spf13/viper v1.10.1
Expand All @@ -22,6 +21,9 @@ require (
google.golang.org/protobuf v1.28.0
gopkg.in/DataDog/dd-trace-go.v1 v1.47.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gorm.io/driver/mysql v1.4.6
gorm.io/driver/sqlite v1.4.4
gorm.io/gorm v1.24.5
)

require (
Expand All @@ -37,18 +39,20 @@ require (
github.com/dgraph-io/ristretto v0.1.0 // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
github.com/gorilla/mux v1.7.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/klauspost/compress v1.15.2 // indirect
github.com/magiconair/properties v1.8.5 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-sqlite3 v1.14.15 // indirect
github.com/mitchellh/mapstructure v1.4.3 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect
github.com/philhofer/fwd v1.1.1 // indirect
Expand Down
103 changes: 17 additions & 86 deletions go.sum

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pkg/context/database/context.go
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"fmt"

"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

type ctxDatabaseMarker struct{}
Expand Down
51 changes: 35 additions & 16 deletions pkg/database/gorm.go
@@ -1,31 +1,50 @@
package database

import (
"github.com/DATA-DOG/go-txdb"
"github.com/jinzhu/gorm"
"strconv"
"time"

// Imports required gorm MySQL dialect.
_ "github.com/jinzhu/gorm/dialects/mysql"
"github.com/DATA-DOG/go-txdb"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)

const testEnv = "test"

// NewConnection returns a new Gorm database connection.
func NewConnection(config *Config, environment string) (*gorm.DB, error) {
var db *gorm.DB
var err error

connectionDetails := NewConnectionDetails(config)

switch environment {
case "test":
txdb.Register("txdb", connectionDetails.Dialect, connectionDetails.String())
db, err = gorm.Open(connectionDetails.Dialect, "txdb", "tx_1")
default:
db, err = gorm.Open(connectionDetails.Dialect, connectionDetails.String())
connectionString := connectionDetails.String()
driverName := connectionDetails.Dialect

// Register the test driver and mock driver name and connection string in test environment.
if environment == testEnv {
// Using time.Now() as a unique identifier for the test database so that we can call NewConnection()
// multiple times without getting an error.
testDriverName := strconv.Itoa(int(time.Now().UnixNano()))

txdb.Register(testDriverName, connectionDetails.Dialect, connectionString)
driverName = testDriverName
connectionString = testDriverName
}

if err == nil {
db.DB().SetMaxIdleConns(connectionDetails.Pool)
dialector := mysql.New(mysql.Config{
DSN: connectionString,
DriverName: driverName,
})

db, err := gorm.Open(dialector)
if err != nil {
return nil, err
}

return db, err
sqlDB, err := db.DB()
if err != nil {
return nil, err
}

sqlDB.SetMaxIdleConns(config.Pool)

return db, nil
}
101 changes: 67 additions & 34 deletions pkg/instrumentation/database.go
Expand Up @@ -5,9 +5,9 @@ import (
"fmt"
"strings"

"github.com/jinzhu/gorm"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
ddtrace "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gorm.io/gorm"
)

const (
Expand All @@ -34,7 +34,7 @@ func InstrumentDatabase(db *gorm.DB, appName string) {
registerCallbacks(db, "query", callbacks)
registerCallbacks(db, "update", callbacks)
registerCallbacks(db, "delete", callbacks)
registerCallbacks(db, "row_query", callbacks)
registerCallbacks(db, "row", callbacks)
}

type callbacks struct {
Expand All @@ -47,18 +47,18 @@ func newCallbacks(appName string) *callbacks {
}
}

func (c *callbacks) beforeCreate(scope *gorm.Scope) { c.before(scope, "INSERT", c.serviceName) }
func (c *callbacks) afterCreate(scope *gorm.Scope) { c.after(scope) }
func (c *callbacks) beforeQuery(scope *gorm.Scope) { c.before(scope, "SELECT", c.serviceName) }
func (c *callbacks) afterQuery(scope *gorm.Scope) { c.after(scope) }
func (c *callbacks) beforeUpdate(scope *gorm.Scope) { c.before(scope, "UPDATE", c.serviceName) }
func (c *callbacks) afterUpdate(scope *gorm.Scope) { c.after(scope) }
func (c *callbacks) beforeDelete(scope *gorm.Scope) { c.before(scope, "DELETE", c.serviceName) }
func (c *callbacks) afterDelete(scope *gorm.Scope) { c.after(scope) }
func (c *callbacks) beforeRowQuery(scope *gorm.Scope) { c.before(scope, "", c.serviceName) }
func (c *callbacks) afterRowQuery(scope *gorm.Scope) { c.after(scope) }
func (c *callbacks) before(scope *gorm.Scope, operationName string, serviceName string) {
val, ok := scope.Get(ParentSpanGormKey)
func (c *callbacks) beforeCreate(db *gorm.DB) { c.before(db, "INSERT", c.serviceName) }
func (c *callbacks) afterCreate(db *gorm.DB) { c.after(db) }
func (c *callbacks) beforeQuery(db *gorm.DB) { c.before(db, "SELECT", c.serviceName) }
func (c *callbacks) afterQuery(db *gorm.DB) { c.after(db) }
func (c *callbacks) beforeUpdate(db *gorm.DB) { c.before(db, "UPDATE", c.serviceName) }
func (c *callbacks) afterUpdate(db *gorm.DB) { c.after(db) }
func (c *callbacks) beforeDelete(db *gorm.DB) { c.before(db, "DELETE", c.serviceName) }
func (c *callbacks) afterDelete(db *gorm.DB) { c.after(db) }
func (c *callbacks) beforeRow(db *gorm.DB) { c.before(db, "", c.serviceName) }
func (c *callbacks) afterRow(db *gorm.DB) { c.after(db) }
func (c *callbacks) before(db *gorm.DB, operationName string, serviceName string) {
val, ok := db.Get(ParentSpanGormKey)
if !ok {
return
}
Expand All @@ -70,46 +70,79 @@ func (c *callbacks) before(scope *gorm.Scope, operationName string, serviceName
ddtrace.ServiceName(serviceName),
}
if operationName == "" {
operationName = strings.Split(scope.SQL, " ")[0]
operationName = strings.Split(db.Statement.SQL.String(), " ")[0]
}
sp := ddtrace.StartSpan(operationName, spanOpts...)
scope.Set(SpanGormKey, sp)
db.Set(SpanGormKey, sp)
}

func (c *callbacks) after(scope *gorm.Scope) {
val, ok := scope.Get(SpanGormKey)
func (c *callbacks) after(db *gorm.DB) {
val, ok := db.Get(SpanGormKey)
if !ok {
return
}

sp := val.(ddtrace.Span)
sp.SetTag(ext.ResourceName, strings.ToUpper(scope.SQL))
sp.SetTag("db.table", scope.TableName())
sp.SetTag("db.query", strings.ToUpper(scope.SQL))
sp.SetTag("db.err", scope.HasError())
sp.SetTag("db.count", scope.DB().RowsAffected)
sp.SetTag(ext.ResourceName, strings.ToUpper(db.Statement.SQL.String()))
sp.SetTag("db.table", db.Statement.Table)
sp.SetTag("db.query", strings.ToUpper(db.Statement.SQL.String()))
sp.SetTag("db.err", db.Error)
sp.SetTag("db.count", db.RowsAffected)
sp.Finish()
}

func registerCallbacks(db *gorm.DB, name string, c *callbacks) {
var err error

beforeName := fmt.Sprintf("tracing:%v_before", name)
afterName := fmt.Sprintf("tracing:%v_after", name)
gormCallbackName := fmt.Sprintf("gorm:%v", name)
// gorm does some magic, if you pass CallbackProcessor here - nothing works
switch name {
case "create":
db.Callback().Create().Before(gormCallbackName).Register(beforeName, c.beforeCreate)
db.Callback().Create().After(gormCallbackName).Register(afterName, c.afterCreate)
err = db.Callback().Create().Before(gormCallbackName).Register(beforeName, c.beforeCreate)
if err != nil {
return
}
err = db.Callback().Create().After(gormCallbackName).Register(afterName, c.afterCreate)
if err != nil {
return
}
case "query":
db.Callback().Query().Before(gormCallbackName).Register(beforeName, c.beforeQuery)
db.Callback().Query().After(gormCallbackName).Register(afterName, c.afterQuery)
err = db.Callback().Query().Before(gormCallbackName).Register(beforeName, c.beforeQuery)
if err != nil {
return
}
err = db.Callback().Query().After(gormCallbackName).Register(afterName, c.afterQuery)
if err != nil {
return
}
case "update":
db.Callback().Update().Before(gormCallbackName).Register(beforeName, c.beforeUpdate)
db.Callback().Update().After(gormCallbackName).Register(afterName, c.afterUpdate)
err = db.Callback().Update().Before(gormCallbackName).Register(beforeName, c.beforeUpdate)
if err != nil {
return
}
err = db.Callback().Update().After(gormCallbackName).Register(afterName, c.afterUpdate)
if err != nil {
return
}
case "delete":
db.Callback().Delete().Before(gormCallbackName).Register(beforeName, c.beforeDelete)
db.Callback().Delete().After(gormCallbackName).Register(afterName, c.afterDelete)
case "row_query":
db.Callback().RowQuery().Before(gormCallbackName).Register(beforeName, c.beforeRowQuery)
db.Callback().RowQuery().After(gormCallbackName).Register(afterName, c.afterRowQuery)
err = db.Callback().Delete().Before(gormCallbackName).Register(beforeName, c.beforeDelete)
if err != nil {
return
}
err = db.Callback().Delete().After(gormCallbackName).Register(afterName, c.afterDelete)
if err != nil {
return
}
case "row":
err = db.Callback().Row().Before(gormCallbackName).Register(beforeName, c.beforeRow)
if err != nil {
return
}
err = db.Callback().Row().After(gormCallbackName).Register(afterName, c.afterRow)
if err != nil {
return
}
}
}

0 comments on commit cccd416

Please sign in to comment.