Skip to content

Commit eb90a02

Browse files
authored
fix: returning all columns with "on conflict do update" must considered as ScanUpdate (#7534)
1 parent 22d5239 commit eb90a02

2 files changed

Lines changed: 51 additions & 1 deletion

File tree

callbacks/create.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) {
8080
ok, mode := hasReturning(db, supportReturning)
8181
if ok {
8282
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
83-
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
83+
onConflict, _ := c.Expression.(clause.OnConflict)
84+
if onConflict.DoNothing {
8485
mode |= gorm.ScanOnConflictDoNothing
86+
} else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll {
87+
mode |= gorm.ScanUpdate
8588
}
8689
}
8790

tests/upsert_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,53 @@ func TestUpsertSlice(t *testing.T) {
135135
}
136136
}
137137

138+
func TestUpsertSliceWithReturning(t *testing.T) {
139+
langs := []Language{
140+
{Code: "upsert-slice1", Name: "Upsert-slice1"},
141+
{Code: "upsert-slice2", Name: "Upsert-slice2"},
142+
{Code: "upsert-slice3", Name: "Upsert-slice3"},
143+
}
144+
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
145+
146+
var langs2 []Language
147+
if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil {
148+
t.Errorf("no error should happen when find languages with code, but got %v", err)
149+
} else if len(langs2) != 3 {
150+
t.Errorf("should only find only 3 languages, but got %+v", langs2)
151+
}
152+
153+
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs)
154+
var langs3 []Language
155+
if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil {
156+
t.Errorf("no error should happen when find languages with code, but got %v", err)
157+
} else if len(langs3) != 3 {
158+
t.Errorf("should only find only 3 languages, but got %+v", langs3)
159+
}
160+
161+
for idx, lang := range langs {
162+
lang.Name = lang.Name + "_new"
163+
langs[idx] = lang
164+
}
165+
166+
if err := DB.Clauses(clause.OnConflict{
167+
Columns: []clause.Column{{Name: "code"}},
168+
DoUpdates: clause.AssignmentColumns([]string{"name"}),
169+
}, clause.Returning{}).CreateInBatches(&langs, len(langs)).Error; err != nil {
170+
t.Fatalf("failed to upsert, got %v", err)
171+
}
172+
173+
for _, lang := range langs {
174+
var results []Language
175+
if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil {
176+
t.Errorf("no error should happen when find languages with code, but got %v", err)
177+
} else if len(results) != 1 {
178+
t.Errorf("should only find only 1 languages, but got %+v", langs)
179+
} else if results[0].Name != lang.Name {
180+
t.Errorf("should update name on conflict, but got name %+v", results[0].Name)
181+
}
182+
}
183+
}
184+
138185
func TestUpsertWithSave(t *testing.T) {
139186
langs := []Language{
140187
{Code: "upsert-save-1", Name: "Upsert-save-1"},

0 commit comments

Comments
 (0)