Skip to content

Commit

Permalink
Update Migrator ColumnType interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 18, 2022
1 parent 373b1f0 commit 11959bf
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 66 deletions.
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.3 h1:PlHq1bSCSZL9K0wUhbm2pGLoTWs2GwVhsP6emvGV/ZI=
github.com/jinzhu/now v1.1.3/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
gorm.io/gorm v1.22.4 h1:8aPcyEJhY0MAt8aY6Dc524Pn+pO29K+ydu+e/cXSpQM=
gorm.io/gorm v1.22.4/go.mod h1:1aeVC+pe9ZmvKZban/gW4QPra7PRoTEssyc922qCAkk=
122 changes: 56 additions & 66 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mysql
import (
"database/sql"
"fmt"
"strings"

"gorm.io/gorm"
"gorm.io/gorm/clause"
Expand All @@ -15,57 +16,6 @@ type Migrator struct {
Dialector
}

type Column struct {
name string
nullable sql.NullString
datatype string
maxLen sql.NullInt64
precision sql.NullInt64
scale sql.NullInt64
datetimePrecision sql.NullInt64
}

func (c Column) Name() string {
return c.name
}

func (c Column) DatabaseTypeName() string {
return c.datatype
}

func (c Column) Length() (int64, bool) {
if c.maxLen.Valid {
return c.maxLen.Int64, c.maxLen.Valid
}

return 0, false
}

func (c Column) Nullable() (bool, bool) {
if c.nullable.Valid {
return c.nullable.String == "YES", true
}

return false, false
}

// DecimalSize return precision int64, scale int64, ok bool
func (c Column) DecimalSize() (int64, int64, bool) {
if c.precision.Valid {
if c.scale.Valid {
return c.precision.Int64, c.scale.Int64, true
}

return c.precision.Int64, 0, true
}

if c.datetimePrecision.Valid {
return c.datetimePrecision.Int64, 0, true
}

return 0, 0, false
}

func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
expr := m.Migrator.FullDataTypeOf(field)

Expand Down Expand Up @@ -159,17 +109,17 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error

func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
return m.DB.Connection(func(tx *gorm.DB) error {
tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
}
tx.Exec("SET FOREIGN_KEY_CHECKS = 1;")
return nil
return tx.Exec("SET FOREIGN_KEY_CHECKS = 1;").Error
})
}

func (m Migrator) DropConstraint(value interface{}, name string) error {
Expand All @@ -194,9 +144,20 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
var (
currentDatabase = m.DB.Migrator().CurrentDatabase()
columnTypeSQL = "SELECT column_name, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_scale "
columnTypeSQL = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale "
rows, err = m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
)

if err != nil {
return err
}

defer func() {
err = rows.Close()
}()

rawColumnTypes, err := rows.ColumnTypes()

if !m.DisableDatetimePrecision {
columnTypeSQL += ", datetime_precision "
}
Expand All @@ -210,17 +171,46 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
defer columns.Close()

for columns.Next() {
var column Column
var values = []interface{}{&column.name, &column.nullable, &column.datatype,
&column.maxLen, &column.precision, &column.scale}
var (
column migrator.ColumnType
datetimePrecision sql.NullInt64
extraValue sql.NullString
columnKey sql.NullString
values = []interface{}{
&column.NameValue, &column.DefaultValueValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.ColumnTypeValue, &columnKey, &extraValue, &column.CommentValue, &column.DecimalSizeValue, &column.ScaleValue,
}
)

if !m.DisableDatetimePrecision {
values = append(values, &column.datetimePrecision)
values = append(values, &datetimePrecision)
}

if scanErr := columns.Scan(values...); scanErr != nil {
return scanErr
}

switch columnKey.String {
case "PRI":
column.PrimayKeyValue = sql.NullBool{Bool: true, Valid: true}
case "UNI":
column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}

if strings.Contains(extraValue.String, "auto_increment") {
column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
}

if datetimePrecision.Valid {
column.DecimalSizeValue = datetimePrecision
}

for _, c := range rawColumnTypes {
if c.Name() == column.NameValue.String {
column.SQLColumnType = c
break
}
}

columnTypes = append(columnTypes, column)
}

Expand Down

0 comments on commit 11959bf

Please sign in to comment.