Skip to content

Commit

Permalink
feat: migrator support type aliases (#5627)
Browse files Browse the repository at this point in the history
* feat: migrator support type aliases

* perf: check type
  • Loading branch information
a631807682 authored and jinzhu committed Sep 22, 2022
1 parent 642bc4e commit a289321
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 7 deletions.
16 changes: 12 additions & 4 deletions gorm.go
Expand Up @@ -248,10 +248,18 @@ func (db *DB) Session(config *Session) *DB {
if config.PrepareStmt {
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt := v.(*PreparedStmtDB)
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Statement.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
Expand Down
1 change: 1 addition & 0 deletions migrator.go
Expand Up @@ -68,6 +68,7 @@ type Migrator interface {
// Database
CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string

// Tables
CreateTable(dst ...interface{}) error
Expand Down
29 changes: 26 additions & 3 deletions migrator/migrator.go
Expand Up @@ -408,9 +408,27 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy

alterColumn := false

// check type
if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) {
alterColumn = true
if !field.PrimaryKey {
// check type
var isSameType bool
if strings.HasPrefix(fullDataType, realDataType) {
isSameType = true
}

// check type aliases
if !isSameType {
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
isSameType = true
break
}
}
}

if !isSameType {
alterColumn = true
}
}

// check size
Expand Down Expand Up @@ -863,3 +881,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
return nil, errors.New("not support")
}

// GetTypeAliases return database type aliases
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
return nil
}
17 changes: 17 additions & 0 deletions tests/prepared_stmt_test.go
Expand Up @@ -2,6 +2,7 @@ package tests_test

import (
"context"
"errors"
"testing"
"time"

Expand Down Expand Up @@ -88,3 +89,19 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
}
tx2.Commit()
}

func TestPreparedStmtInTransaction(t *testing.T) {
user := User{Name: "jinzhu"}

if err := DB.Transaction(func(tx *gorm.DB) error {
tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user)
return errors.New("test")
}); err == nil {
t.Error(err)
}

var result User
if err := DB.First(&result, user.ID).Error; err == nil {
t.Errorf("Failed, got error: %v", err)
}
}

0 comments on commit a289321

Please sign in to comment.