Skip to content

Commit 44f0939

Browse files
Copilotshueybubbles
andcommitted
Fix TVP support for nullable civil types
Co-authored-by: shueybubbles <[email protected]>
1 parent 60e5558 commit 44f0939

File tree

2 files changed

+203
-0
lines changed

2 files changed

+203
-0
lines changed

tvp_go19.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"strings"
1414
"time"
1515

16+
"github.com/golang-sql/civil"
1617
"github.com/microsoft/go-mssqldb/msdsn"
1718
)
1819

@@ -108,6 +109,23 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
108109
if tvp.verifyStandardTypeOnNull(buf, tvpVal) {
109110
continue
110111
}
112+
113+
// Extract inner value from nullable civil types when they are valid
114+
switch v := tvpVal.(type) {
115+
case NullDate:
116+
if v.Valid {
117+
tvpVal = v.Date
118+
}
119+
case NullDateTime:
120+
if v.Valid {
121+
tvpVal = v.DateTime
122+
}
123+
case NullTime:
124+
if v.Valid {
125+
tvpVal = v.Time
126+
}
127+
}
128+
111129
valOf := reflect.ValueOf(tvpVal)
112130
elemKind := field.Kind()
113131
if elemKind == reflect.Ptr && valOf.IsNil() {
@@ -279,6 +297,12 @@ func (tvp TVP) createZeroType(fieldVal interface{}) interface{} {
279297
return defaultInt64
280298
case sql.NullString:
281299
return defaultString
300+
case NullDate:
301+
return civil.Date{}
302+
case NullDateTime:
303+
return civil.DateTime{}
304+
case NullTime:
305+
return civil.Time{}
282306
}
283307
return fieldVal
284308
}
@@ -310,6 +334,21 @@ func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) b
310334
binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL))
311335
return true
312336
}
337+
case NullDate:
338+
if !val.Valid {
339+
binary.Write(buf, binary.LittleEndian, defaultNull)
340+
return true
341+
}
342+
case NullDateTime:
343+
if !val.Valid {
344+
binary.Write(buf, binary.LittleEndian, defaultNull)
345+
return true
346+
}
347+
case NullTime:
348+
if !val.Valid {
349+
binary.Write(buf, binary.LittleEndian, defaultNull)
350+
return true
351+
}
313352
}
314353
return false
315354
}

tvp_go19_test.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"testing"
99
"time"
1010

11+
"github.com/golang-sql/civil"
1112
"github.com/microsoft/go-mssqldb/msdsn"
1213
)
1314

@@ -587,3 +588,166 @@ func TestTVP_encode_WithGuidConversion(t *testing.T) {
587588
func TestTVP_encode(t *testing.T) {
588589
testTVP_encode(t, false /*guidConversion*/)
589590
}
591+
592+
// TestTVPWithNullCivilTypes tests that nullable civil types work correctly in TVP operations
593+
func TestTVPWithNullCivilTypes(t *testing.T) {
594+
type tvpDataRowNullDateTime struct {
595+
T NullDateTime
596+
}
597+
598+
type tvpDataRowNullDate struct {
599+
D NullDate
600+
}
601+
602+
type tvpDataRowNullTime struct {
603+
T NullTime
604+
}
605+
606+
type tvpDataRowMixed struct {
607+
Date NullDate `tvp:"date_col"`
608+
DateTime NullDateTime `tvp:"datetime_col"`
609+
Time NullTime `tvp:"time_col"`
610+
}
611+
612+
tests := []struct {
613+
name string
614+
tvpData interface{}
615+
wantErr bool
616+
}{
617+
{
618+
name: "NullDateTime with Valid=false",
619+
tvpData: []tvpDataRowNullDateTime{
620+
{T: NullDateTime{Valid: false}},
621+
},
622+
wantErr: false,
623+
},
624+
{
625+
name: "NullDateTime with Valid=true",
626+
tvpData: []tvpDataRowNullDateTime{
627+
{T: NullDateTime{DateTime: civil.DateTime{Date: civil.Date{Year: 2025, Month: 10, Day: 2}, Time: civil.Time{Hour: 16, Minute: 10, Second: 55}}, Valid: true}},
628+
},
629+
wantErr: false,
630+
},
631+
{
632+
name: "NullDate with Valid=false",
633+
tvpData: []tvpDataRowNullDate{
634+
{D: NullDate{Valid: false}},
635+
},
636+
wantErr: false,
637+
},
638+
{
639+
name: "NullDate with Valid=true",
640+
tvpData: []tvpDataRowNullDate{
641+
{D: NullDate{Date: civil.Date{Year: 2025, Month: 10, Day: 2}, Valid: true}},
642+
},
643+
wantErr: false,
644+
},
645+
{
646+
name: "NullTime with Valid=false",
647+
tvpData: []tvpDataRowNullTime{
648+
{T: NullTime{Valid: false}},
649+
},
650+
wantErr: false,
651+
},
652+
{
653+
name: "NullTime with Valid=true",
654+
tvpData: []tvpDataRowNullTime{
655+
{T: NullTime{Time: civil.Time{Hour: 16, Minute: 10, Second: 55}, Valid: true}},
656+
},
657+
wantErr: false,
658+
},
659+
{
660+
name: "Mixed nullable civil types with some null, some valid",
661+
tvpData: []tvpDataRowMixed{
662+
{
663+
Date: NullDate{Valid: false},
664+
DateTime: NullDateTime{DateTime: civil.DateTime{Date: civil.Date{Year: 2025, Month: 10, Day: 2}, Time: civil.Time{Hour: 16, Minute: 10, Second: 55}}, Valid: true},
665+
Time: NullTime{Valid: false},
666+
},
667+
{
668+
Date: NullDate{Date: civil.Date{Year: 2025, Month: 12, Day: 25}, Valid: true},
669+
DateTime: NullDateTime{Valid: false},
670+
Time: NullTime{Time: civil.Time{Hour: 9, Minute: 30, Second: 0}, Valid: true},
671+
},
672+
},
673+
wantErr: false,
674+
},
675+
{
676+
name: "User example 1: Empty NullDateTime",
677+
tvpData: []tvpDataRowNullDateTime{
678+
{T: NullDateTime{}}, // Valid defaults to false
679+
},
680+
wantErr: false,
681+
},
682+
{
683+
name: "User example 2: Valid NullDateTime",
684+
tvpData: func() []tvpDataRowNullDateTime {
685+
t1, _ := civil.ParseDateTime("2025-10-02T16:10:55")
686+
return []tvpDataRowNullDateTime{
687+
{T: NullDateTime{DateTime: t1, Valid: true}},
688+
}
689+
}(),
690+
wantErr: false,
691+
},
692+
}
693+
694+
for _, tt := range tests {
695+
t.Run(tt.name, func(t *testing.T) {
696+
tvp := TVP{
697+
TypeName: "dbo.TestType",
698+
Value: tt.tvpData,
699+
}
700+
701+
// Test columnTypes
702+
columnStr, tvpFieldIndexes, err := tvp.columnTypes()
703+
if (err != nil) != tt.wantErr {
704+
t.Errorf("TVP.columnTypes() error = %v, wantErr %v", err, tt.wantErr)
705+
return
706+
}
707+
708+
if err != nil {
709+
return // Skip encoding test if columnTypes failed
710+
}
711+
712+
// Test encode
713+
_, err = tvp.encode("dbo", "TestType", columnStr, tvpFieldIndexes, msdsn.EncodeParameters{})
714+
if (err != nil) != tt.wantErr {
715+
t.Errorf("TVP.encode() error = %v, wantErr %v", err, tt.wantErr)
716+
return
717+
}
718+
})
719+
}
720+
}
721+
722+
// TestTVPNullCivilTypesCreateZeroType tests that nullable civil types are handled correctly
723+
// in the createZeroType method when building column type information
724+
func TestTVPNullCivilTypesCreateZeroType(t *testing.T) {
725+
type tvpDataRowMixed struct {
726+
Date NullDate `tvp:"date_col"`
727+
DateTime NullDateTime `tvp:"datetime_col"`
728+
Time NullTime `tvp:"time_col"`
729+
}
730+
731+
tvp := TVP{
732+
TypeName: "dbo.TestType",
733+
Value: []tvpDataRowMixed{
734+
{}, // Empty struct to trigger createZeroType for all fields
735+
},
736+
}
737+
738+
// Test that we can get column types without error
739+
columnStr, tvpFieldIndexes, err := tvp.columnTypes()
740+
if err != nil {
741+
t.Errorf("TVP.columnTypes() with empty struct failed: %v", err)
742+
return
743+
}
744+
745+
// Should have 3 columns for the 3 fields
746+
if len(columnStr) != 3 {
747+
t.Errorf("Expected 3 columns, got %d", len(columnStr))
748+
}
749+
750+
if len(tvpFieldIndexes) != 3 {
751+
t.Errorf("Expected 3 field indexes, got %d", len(tvpFieldIndexes))
752+
}
753+
}

0 commit comments

Comments
 (0)