From 1648e291104c3e1882e40674fa9941ad46e57f27 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Apr 2022 09:38:20 +0800 Subject: [PATCH] Fix auto migration with transaction, close https://github.com/go-gorm/gorm/issues/5175 --- migrator.go | 135 +++++++++++++++++++++++++++------------------------- 1 file changed, 69 insertions(+), 66 deletions(-) diff --git a/migrator.go b/migrator.go index 5660d6d..30f5e97 100644 --- a/migrator.go +++ b/migrator.go @@ -155,90 +155,93 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return err } - defer func() { - err = rows.Close() - }() - - var ( - rawColumnTypes, _ = rows.ColumnTypes() - columnTypeSQL = "SELECT column_name, data_type, column_default, is_nullable, character_maximum_length, numeric_precision, numeric_precision_radix, numeric_scale, datetime_precision FROM INFORMATION_SCHEMA.COLUMNS WHERE table_catalog = ? AND table_name = ?" - columns, rowErr = m.DB.Raw(columnTypeSQL, m.CurrentDatabase(), stmt.Table).Rows() - ) - - if rowErr != nil { - return rowErr - } - - defer columns.Close() + rawColumnTypes, _ := rows.ColumnTypes() + rows.Close() - for columns.Next() { + { var ( - column = migrator.ColumnType{ - PrimaryKeyValue: sql.NullBool{Valid: true}, - UniqueValue: sql.NullBool{Valid: true}, - } - datetimePrecision sql.NullInt64 - radixValue sql.NullInt64 - nullableValue sql.NullString - values = []interface{}{ - &column.NameValue, &column.ColumnTypeValue, &column.DefaultValueValue, &nullableValue, &column.LengthValue, &column.DecimalSizeValue, &radixValue, &column.ScaleValue, &datetimePrecision, - } + columnTypeSQL = "SELECT column_name, data_type, column_default, is_nullable, character_maximum_length, numeric_precision, numeric_precision_radix, numeric_scale, datetime_precision FROM INFORMATION_SCHEMA.COLUMNS WHERE table_catalog = ? AND table_name = ?" + columns, rowErr = m.DB.Raw(columnTypeSQL, m.CurrentDatabase(), stmt.Table).Rows() ) - if scanErr := columns.Scan(values...); scanErr != nil { - return scanErr + if rowErr != nil { + return rowErr } - if nullableValue.Valid { - column.NullableValue = sql.NullBool{Bool: strings.EqualFold(nullableValue.String, "YES"), Valid: true} - } + for columns.Next() { + var ( + column = migrator.ColumnType{ + PrimaryKeyValue: sql.NullBool{Valid: true}, + UniqueValue: sql.NullBool{Valid: true}, + } + datetimePrecision sql.NullInt64 + radixValue sql.NullInt64 + nullableValue sql.NullString + values = []interface{}{ + &column.NameValue, &column.ColumnTypeValue, &column.DefaultValueValue, &nullableValue, &column.LengthValue, &column.DecimalSizeValue, &radixValue, &column.ScaleValue, &datetimePrecision, + } + ) - if datetimePrecision.Valid { - column.DecimalSizeValue = datetimePrecision - } + if scanErr := columns.Scan(values...); scanErr != nil { + return scanErr + } - if column.DefaultValueValue.Valid { - matches := defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String) - for len(matches) > 1 { - column.DefaultValueValue.String = matches[1] - matches = defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String) + if nullableValue.Valid { + column.NullableValue = sql.NullBool{Bool: strings.EqualFold(nullableValue.String, "YES"), Valid: true} + } + + if datetimePrecision.Valid { + column.DecimalSizeValue = datetimePrecision + } + + if column.DefaultValueValue.Valid { + matches := defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String) + for len(matches) > 1 { + column.DefaultValueValue.String = matches[1] + matches = defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String) + } + } else { + column.DefaultValueValue.Valid = true } - } else { - column.DefaultValueValue.Valid = true - } - for _, c := range rawColumnTypes { - if c.Name() == column.NameValue.String { - column.SQLColumnType = c - break + for _, c := range rawColumnTypes { + if c.Name() == column.NameValue.String { + column.SQLColumnType = c + break + } } + + columnTypes = append(columnTypes, column) } - columnTypes = append(columnTypes, column) + columns.Close() } - columnTypeRows, err := m.DB.Raw("SELECT c.column_name, t.constraint_type FROM information_schema.table_constraints t JOIN information_schema.constraint_column_usage c ON c.constraint_name=t.constraint_name WHERE t.constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_name = ?", m.CurrentDatabase(), stmt.Table).Rows() - if err != nil { - return err - } - defer columnTypeRows.Close() - - for columnTypeRows.Next() { - var name, columnType string - columnTypeRows.Scan(&name, &columnType) - for idx, c := range columnTypes { - mc := c.(migrator.ColumnType) - if mc.NameValue.String == name { - switch columnType { - case "PRIMARY KEY": - mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} - case "UNIQUE": - mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} + { + columnTypeRows, err := m.DB.Raw("SELECT c.column_name, t.constraint_type FROM information_schema.table_constraints t JOIN information_schema.constraint_column_usage c ON c.constraint_name=t.constraint_name WHERE t.constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_name = ?", m.CurrentDatabase(), stmt.Table).Rows() + if err != nil { + return err + } + + for columnTypeRows.Next() { + var name, columnType string + columnTypeRows.Scan(&name, &columnType) + for idx, c := range columnTypes { + mc := c.(migrator.ColumnType) + if mc.NameValue.String == name { + switch columnType { + case "PRIMARY KEY": + mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} + case "UNIQUE": + mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} + } + columnTypes[idx] = mc + break } - columnTypes[idx] = mc - break } } + + columnTypeRows.Close() } return