Skip to content

Commit 1901982

Browse files
committed
Fix GetBSON() method usage
Original issue --- You can't use type with custom GetBSON() method mixed with structure field type and structure field reference type. For example, you can't create custom GetBSON() for Bar type: ``` struct Foo { a Bar b *Bar } ``` Type implementation (`func (t Bar) GetBSON()` ) would crash on `Foo.b = nil` value encoding. Reference implementation (`func (t *Bar) GetBSON()` ) would not call on `Foo.a` value encoding. After this change --- For type implementation `func (t Bar) GetBSON()` would not call on `Foo.b = nil` value encoding. In this case `nil` value would be seariazied as `nil` BSON value. For reference implementation `func (t *Bar) GetBSON()` would call even on `Foo.a` value encoding.
1 parent aead58f commit 1901982

File tree

2 files changed

+126
-1
lines changed

2 files changed

+126
-1
lines changed

bson/bson_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"reflect"
3737
"testing"
3838
"time"
39+
"strings"
3940

4041
"github.com/globalsign/mgo/bson"
4142
. "gopkg.in/check.v1"
@@ -381,8 +382,54 @@ func (s *S) Test64bitInt(c *C) {
381382
// --------------------------------------------------------------------------
382383
// Generic two-way struct marshaling tests.
383384

385+
type prefixPtr string
386+
type prefixVal string
387+
388+
func (t *prefixPtr) GetBSON() (interface{}, error) {
389+
if t == nil {
390+
return nil, nil
391+
}
392+
return "foo-" + string(*t), nil
393+
}
394+
395+
func (t *prefixPtr) SetBSON(raw bson.Raw) error {
396+
var s string
397+
if raw.Kind == 0x0A {
398+
return bson.SetZero
399+
}
400+
if err := raw.Unmarshal(&s); err != nil {
401+
return err
402+
}
403+
if !strings.HasPrefix(s, "foo-") {
404+
return errors.New("Prefix not found: " + s)
405+
}
406+
*t = prefixPtr(s[4:])
407+
return nil
408+
}
409+
410+
func (t prefixVal) GetBSON() (interface{}, error) {
411+
return "foo-" + string(t), nil
412+
}
413+
414+
func (t *prefixVal) SetBSON(raw bson.Raw) error {
415+
var s string
416+
if raw.Kind == 0x0A {
417+
return bson.SetZero
418+
}
419+
if err := raw.Unmarshal(&s); err != nil {
420+
return err
421+
}
422+
if !strings.HasPrefix(s, "foo-") {
423+
return errors.New("Prefix not found: " + s)
424+
}
425+
*t = prefixVal(s[4:])
426+
return nil
427+
}
428+
384429
var bytevar = byte(8)
385430
var byteptr = &bytevar
431+
var prefixptr = prefixPtr("bar")
432+
var prefixval = prefixVal("bar")
386433

387434
var structItems = []testItemType{
388435
{&struct{ Ptr *byte }{nil},
@@ -419,6 +466,24 @@ var structItems = []testItemType{
419466
// Byte arrays.
420467
{&struct{ V [2]byte }{[2]byte{'y', 'o'}},
421468
"\x05v\x00\x02\x00\x00\x00\x00yo"},
469+
470+
{&struct{ V prefixPtr }{prefixPtr("buzz")},
471+
"\x02v\x00\x09\x00\x00\x00foo-buzz\x00"},
472+
473+
{&struct{ V *prefixPtr }{&prefixptr},
474+
"\x02v\x00\x08\x00\x00\x00foo-bar\x00"},
475+
476+
{&struct{ V *prefixPtr }{nil},
477+
"\x0Av\x00"},
478+
479+
{&struct{ V prefixVal }{prefixVal("buzz")},
480+
"\x02v\x00\x09\x00\x00\x00foo-buzz\x00"},
481+
482+
{&struct{ V *prefixVal }{&prefixval},
483+
"\x02v\x00\x08\x00\x00\x00foo-bar\x00"},
484+
485+
{&struct{ V *prefixVal }{nil},
486+
"\x0Av\x00"},
422487
}
423488

424489
func (s *S) TestMarshalStructItems(c *C) {

bson/encode.go

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"reflect"
3636
"sort"
3737
"strconv"
38+
"sync"
3839
"time"
3940
)
4041

@@ -60,13 +61,28 @@ var (
6061

6162
const itoaCacheSize = 32
6263

64+
const (
65+
getterUnknown = iota
66+
getterNone
67+
getterTypeVal
68+
getterTypePtr
69+
getterAddr
70+
)
71+
6372
var itoaCache []string
6473

74+
var getterStyles map[reflect.Type]int
75+
var getterIface reflect.Type
76+
var getterMutex sync.RWMutex
77+
6578
func init() {
6679
itoaCache = make([]string, itoaCacheSize)
6780
for i := 0; i != itoaCacheSize; i++ {
6881
itoaCache[i] = strconv.Itoa(i)
6982
}
83+
var iface Getter
84+
getterIface = reflect.TypeOf(&iface).Elem()
85+
getterStyles = make(map[reflect.Type]int)
7086
}
7187

7288
func itoa(i int) string {
@@ -76,6 +92,50 @@ func itoa(i int) string {
7692
return strconv.Itoa(i)
7793
}
7894

95+
func getterStyle(outt reflect.Type) int {
96+
getterMutex.RLock()
97+
style := getterStyles[outt]
98+
getterMutex.RUnlock()
99+
if style == getterUnknown {
100+
getterMutex.Lock()
101+
defer getterMutex.Unlock()
102+
if outt.Implements(getterIface) {
103+
vt := outt
104+
for vt.Kind() == reflect.Ptr {
105+
vt = vt.Elem()
106+
}
107+
if vt.Implements(getterIface) {
108+
getterStyles[outt] = getterTypeVal
109+
} else {
110+
getterStyles[outt] = getterTypePtr
111+
}
112+
} else if reflect.PtrTo(outt).Implements(getterIface) {
113+
getterStyles[outt] = getterAddr
114+
} else {
115+
getterStyles[outt] = getterNone
116+
}
117+
style = getterStyles[outt]
118+
}
119+
return style
120+
}
121+
122+
func getGetter(outt reflect.Type, out reflect.Value) Getter {
123+
style := getterStyle(outt)
124+
if style == getterNone {
125+
return nil
126+
}
127+
if style == getterAddr {
128+
if !out.CanAddr() {
129+
return nil
130+
}
131+
return out.Addr().Interface().(Getter)
132+
}
133+
if style == getterTypeVal && out.Kind() == reflect.Ptr && out.IsNil() {
134+
return nil
135+
}
136+
return out.Interface().(Getter)
137+
}
138+
79139
// --------------------------------------------------------------------------
80140
// Marshaling of the document value itself.
81141

@@ -253,7 +313,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) {
253313
return
254314
}
255315

256-
if getter, ok := v.Interface().(Getter); ok {
316+
if getter := getGetter(v.Type(), v); getter != nil {
257317
getv, err := getter.GetBSON()
258318
if err != nil {
259319
panic(err)

0 commit comments

Comments
 (0)