diff --git a/callbacks/query.go b/callbacks/query.go index 67936766f..97fe8a49c 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -185,7 +185,7 @@ func BuildQuerySQL(db *gorm.DB) { } fromClause.Joins = append(fromClause.Joins, clause.Join{ - Type: clause.LeftJoin, + Type: join.JoinType, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, ON: clause.Where{Exprs: exprs}, }) diff --git a/chainable_api.go b/chainable_api.go index 68ec7a672..8a92a9e3f 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -235,6 +235,16 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.LeftJoin, query, args...) +} + +// InnerJoins specify inner joins conditions +// db.InnerJoins("Account").Find(&user) +func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.InnerJoin, query, args...) +} + +func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if len(args) == 1 { @@ -248,7 +258,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { } } - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) return } diff --git a/statement.go b/statement.go index d4d20cbff..9f49d5840 100644 --- a/statement.go +++ b/statement.go @@ -49,11 +49,12 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where - Selects []string - Omits []string + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + JoinType clause.JoinType } // StatementModifier statement modifier interface diff --git a/tests/joins_test.go b/tests/joins_test.go index 091fb9864..057ad333e 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -230,6 +230,28 @@ func TestJoinWithSoftDeleted(t *testing.T) { } } +func TestInnerJoins(t *testing.T) { + user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false}) + + DB.Create(&user) + + var user2 User + var err error + err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user2, user) + + // inner join and NamedPet is nil + err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, gorm.ErrRecordNotFound) + + // mixed inner join and left join + var user3 User + err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user3, user) +} + func TestJoinWithSameColumnName(t *testing.T) { user := GetUser("TestJoinWithSameColumnName", Config{ Languages: 1,