Skip to content

Commit 1e514a0

Browse files
Copilotshueybubbles
andcommitted
Fix OUT parameter support for nullable civil types
Co-authored-by: shueybubbles <[email protected]>
1 parent f4b57b4 commit 1e514a0

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

civil_null_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,117 @@ func TestNullCivilTypesImplementInterfaces(t *testing.T) {
268268
)
269269
// Note: Scanner interface is verified by successful compilation of Scan methods
270270
}
271+
272+
// TestNullCivilTypesParameterEncoding tests that nullable civil types are properly encoded
273+
// as typed NULL parameters rather than untyped NULLs, which is important for OUT parameters
274+
func TestNullCivilTypesParameterEncoding(t *testing.T) {
275+
// Create a mock connection and statement for testing
276+
c := &Conn{}
277+
c.sess = &tdsSession{}
278+
c.sess.loginAck.TDSVersion = verTDS74 // Use modern TDS version
279+
s := &Stmt{c: c}
280+
281+
t.Run("NullDate parameter encoding", func(t *testing.T) {
282+
// Test valid NullDate
283+
validDate := NullDate{Date: civil.Date{Year: 2023, Month: time.December, Day: 25}, Valid: true}
284+
param, err := s.makeParam(validDate)
285+
if err != nil {
286+
t.Errorf("Unexpected error for valid NullDate: %v", err)
287+
}
288+
if param.ti.TypeId != typeDateN {
289+
t.Errorf("Expected TypeId %v for valid NullDate, got %v", typeDateN, param.ti.TypeId)
290+
}
291+
if len(param.buffer) == 0 {
292+
t.Error("Expected non-empty buffer for valid NullDate")
293+
}
294+
295+
// Test invalid NullDate (NULL)
296+
nullDate := NullDate{Valid: false}
297+
param, err = s.makeParam(nullDate)
298+
if err != nil {
299+
t.Errorf("Unexpected error for NULL NullDate: %v", err)
300+
}
301+
if param.ti.TypeId != typeDateN {
302+
t.Errorf("Expected TypeId %v for NULL NullDate, got %v", typeDateN, param.ti.TypeId)
303+
}
304+
if param.ti.TypeId == typeNull {
305+
t.Error("NULL NullDate should not use untyped NULL (typeNull)")
306+
}
307+
if len(param.buffer) != 0 {
308+
t.Error("Expected empty buffer for NULL NullDate")
309+
}
310+
if param.ti.Size != 3 {
311+
t.Errorf("Expected Size 3 for NULL NullDate, got %v", param.ti.Size)
312+
}
313+
})
314+
315+
t.Run("NullDateTime parameter encoding", func(t *testing.T) {
316+
// Test valid NullDateTime
317+
testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC)
318+
validDateTime := NullDateTime{DateTime: civil.DateTimeOf(testTime), Valid: true}
319+
param, err := s.makeParam(validDateTime)
320+
if err != nil {
321+
t.Errorf("Unexpected error for valid NullDateTime: %v", err)
322+
}
323+
if param.ti.TypeId != typeDateTime2N {
324+
t.Errorf("Expected TypeId %v for valid NullDateTime, got %v", typeDateTime2N, param.ti.TypeId)
325+
}
326+
if len(param.buffer) == 0 {
327+
t.Error("Expected non-empty buffer for valid NullDateTime")
328+
}
329+
330+
// Test invalid NullDateTime (NULL)
331+
nullDateTime := NullDateTime{Valid: false}
332+
param, err = s.makeParam(nullDateTime)
333+
if err != nil {
334+
t.Errorf("Unexpected error for NULL NullDateTime: %v", err)
335+
}
336+
if param.ti.TypeId != typeDateTime2N {
337+
t.Errorf("Expected TypeId %v for NULL NullDateTime, got %v", typeDateTime2N, param.ti.TypeId)
338+
}
339+
if param.ti.TypeId == typeNull {
340+
t.Error("NULL NullDateTime should not use untyped NULL (typeNull)")
341+
}
342+
if len(param.buffer) != 0 {
343+
t.Error("Expected empty buffer for NULL NullDateTime")
344+
}
345+
if param.ti.Scale != 7 {
346+
t.Errorf("Expected Scale 7 for NULL NullDateTime, got %v", param.ti.Scale)
347+
}
348+
})
349+
350+
t.Run("NullTime parameter encoding", func(t *testing.T) {
351+
// Test valid NullTime
352+
testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC)
353+
validTime := NullTime{Time: civil.TimeOf(testTime), Valid: true}
354+
param, err := s.makeParam(validTime)
355+
if err != nil {
356+
t.Errorf("Unexpected error for valid NullTime: %v", err)
357+
}
358+
if param.ti.TypeId != typeTimeN {
359+
t.Errorf("Expected TypeId %v for valid NullTime, got %v", typeTimeN, param.ti.TypeId)
360+
}
361+
if len(param.buffer) == 0 {
362+
t.Error("Expected non-empty buffer for valid NullTime")
363+
}
364+
365+
// Test invalid NullTime (NULL)
366+
nullTime := NullTime{Valid: false}
367+
param, err = s.makeParam(nullTime)
368+
if err != nil {
369+
t.Errorf("Unexpected error for NULL NullTime: %v", err)
370+
}
371+
if param.ti.TypeId != typeTimeN {
372+
t.Errorf("Expected TypeId %v for NULL NullTime, got %v", typeTimeN, param.ti.TypeId)
373+
}
374+
if param.ti.TypeId == typeNull {
375+
t.Error("NULL NullTime should not use untyped NULL (typeNull)")
376+
}
377+
if len(param.buffer) != 0 {
378+
t.Error("Expected empty buffer for NULL NullTime")
379+
}
380+
if param.ti.Scale != 7 {
381+
t.Errorf("Expected Scale 7 for NULL NullTime, got %v", param.ti.Scale)
382+
}
383+
})
384+
}

mssql.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,18 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
996996
if valuer.Valid {
997997
return s.makeParam(valuer.Int32)
998998
}
999+
case NullDate:
1000+
if valuer.Valid {
1001+
return s.makeParamExtra(valuer.Date)
1002+
}
1003+
case NullDateTime:
1004+
if valuer.Valid {
1005+
return s.makeParamExtra(valuer.DateTime)
1006+
}
1007+
case NullTime:
1008+
if valuer.Valid {
1009+
return s.makeParamExtra(valuer.Time)
1010+
}
9991011
case UniqueIdentifier:
10001012
case NullUniqueIdentifier:
10011013
default:
@@ -1143,6 +1155,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
11431155
} else {
11441156
res.ti.TypeId = typeDateTimeN
11451157
}
1158+
case NullDate: // only null values reach here
1159+
res.ti.TypeId = typeDateN
1160+
res.ti.Size = 3
1161+
res.buffer = []byte{}
1162+
case NullDateTime: // only null values reach here
1163+
res.ti.TypeId = typeDateTime2N
1164+
res.ti.Scale = 7
1165+
res.ti.Size = calcTimeSize(int(res.ti.Scale)) + 3
1166+
res.buffer = []byte{}
1167+
case NullTime: // only null values reach here
1168+
res.ti.TypeId = typeTimeN
1169+
res.ti.Scale = 7
1170+
res.ti.Size = calcTimeSize(int(res.ti.Scale))
1171+
res.buffer = []byte{}
11461172
case driver.Valuer:
11471173
// We have a custom Valuer implementation with a nil value
11481174
return s.makeParam(nil)

0 commit comments

Comments
 (0)