Skip to content

Commit 567b819

Browse files
indielokiAndrey Kuznetsov
andauthored
Reduce per-row allocations (#149)
Co-authored-by: Andrey Kuznetsov <[email protected]>
1 parent 7e6e67b commit 567b819

File tree

2 files changed

+43
-17
lines changed

2 files changed

+43
-17
lines changed

dbscan/dbscan.go

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -274,21 +274,38 @@ func (api *API) parseSliceDestination(dst interface{}) (*sliceDestinationMeta, e
274274
}
275275

276276
func scanSliceElement(rs *RowScanner, sliceMeta *sliceDestinationMeta) error {
277-
dstValPtr := reflect.New(sliceMeta.elementBaseType)
278-
if err := rs.Scan(dstValPtr.Interface()); err != nil {
279-
return fmt.Errorf("scanning: %w", err)
280-
}
281-
var elemVal reflect.Value
277+
s := sliceMeta.val
278+
l := s.Len()
279+
growSliceByOne(s)
280+
var dstValPtr reflect.Value
282281
if sliceMeta.elementByPtr {
283-
elemVal = dstValPtr
282+
dstValPtr = reflect.New(sliceMeta.elementBaseType)
283+
s.Index(l).Set(dstValPtr)
284284
} else {
285-
elemVal = dstValPtr.Elem()
285+
dstValPtr = s.Index(l).Addr()
286+
}
287+
if err := rs.Scan(dstValPtr.Interface()); err != nil {
288+
// Undo growing the slice. Zero the value to ensure it doesn't retain garbage.
289+
s.Index(l).Set(reflect.Zero(s.Type().Elem()))
290+
s.SetLen(l)
291+
return fmt.Errorf("scanning: %w", err)
286292
}
287-
288-
sliceMeta.val.Set(reflect.Append(sliceMeta.val, elemVal))
289293
return nil
290294
}
291295

296+
func growSliceByOne(s reflect.Value) {
297+
// In go 1.20 and above, this could be made simpler (and possibly more efficient)
298+
// by using Value.Grow.
299+
l := s.Len()
300+
c := s.Cap()
301+
if l < c {
302+
s.SetLen(l + 1)
303+
return
304+
}
305+
t := s.Type().Elem()
306+
s.Set(reflect.Append(s, reflect.Zero(t)))
307+
}
308+
292309
// ScanRow is a package-level helper function that uses the DefaultAPI object.
293310
// See API.ScanRow for details.
294311
func ScanRow(dst interface{}, rows Rows) error {

dbscan/rowscanner.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type RowScanner struct {
3434
started bool
3535
scanFn func(dstVal reflect.Value) error
3636
start startScannerFunc
37+
scans []interface{}
3738
}
3839

3940
// NewRowScanner is a package-level helper function that uses the DefaultAPI object.
@@ -130,13 +131,15 @@ func (*noOpScanType) Scan(value interface{}) error {
130131
}
131132

132133
func (rs *RowScanner) scanStruct(structValue reflect.Value) error {
133-
scans := make([]interface{}, len(rs.columns))
134+
if rs.scans == nil {
135+
rs.scans = make([]interface{}, len(rs.columns))
136+
}
134137
for i, column := range rs.columns {
135138
fieldIndex, ok := rs.columnToFieldIndex[column]
136139
if !ok {
137140
if rs.api.allowUnknownColumns {
138141
var tmp noOpScanType
139-
scans[i] = &tmp
142+
rs.scans[i] = &tmp
140143
continue
141144
}
142145
return fmt.Errorf(
@@ -150,9 +153,9 @@ func (rs *RowScanner) scanStruct(structValue reflect.Value) error {
150153
initializeNested(structValue, fieldIndex)
151154

152155
fieldVal := structValue.FieldByIndex(fieldIndex)
153-
scans[i] = fieldVal.Addr().Interface()
156+
rs.scans[i] = fieldVal.Addr().Interface()
154157
}
155-
if err := rs.rows.Scan(scans...); err != nil {
158+
if err := rs.rows.Scan(rs.scans...); err != nil {
156159
return fmt.Errorf("scany: scan row into struct fields: %w", err)
157160
}
158161
return nil
@@ -163,14 +166,16 @@ func (rs *RowScanner) scanMap(mapValue reflect.Value) error {
163166
mapValue.Set(reflect.MakeMap(mapValue.Type()))
164167
}
165168

166-
scans := make([]interface{}, len(rs.columns))
169+
if rs.scans == nil {
170+
rs.scans = make([]interface{}, len(rs.columns))
171+
}
167172
values := make([]reflect.Value, len(rs.columns))
168173
for i := range rs.columns {
169174
valuePtr := reflect.New(rs.mapElementType)
170-
scans[i] = valuePtr.Interface()
175+
rs.scans[i] = valuePtr.Interface()
171176
values[i] = valuePtr.Elem()
172177
}
173-
if err := rs.rows.Scan(scans...); err != nil {
178+
if err := rs.rows.Scan(rs.scans...); err != nil {
174179
return fmt.Errorf("scany: scan rows into map: %w", err)
175180
}
176181
// We can't set reflect values into destination map before scanning them,
@@ -185,7 +190,11 @@ func (rs *RowScanner) scanMap(mapValue reflect.Value) error {
185190
}
186191

187192
func (rs *RowScanner) scanPrimitive(value reflect.Value) error {
188-
if err := rs.rows.Scan(value.Addr().Interface()); err != nil {
193+
if rs.scans == nil {
194+
rs.scans = make([]interface{}, 1)
195+
}
196+
rs.scans[0] = value.Addr().Interface()
197+
if err := rs.rows.Scan(rs.scans...); err != nil {
189198
return fmt.Errorf("scany: scan row value into a primitive type: %w", err)
190199
}
191200
return nil

0 commit comments

Comments
 (0)