Skip to content

Commit 32a7dd3

Browse files
committed
Add support for complex types
1 parent 6b69c1b commit 32a7dd3

File tree

2 files changed

+326
-86
lines changed

2 files changed

+326
-86
lines changed

presto-flight-shim/src/main/java/com/facebook/presto/flightshim/ArrowBatchSource.java

Lines changed: 124 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import com.facebook.presto.common.Page;
1717
import com.facebook.presto.common.block.Block;
18-
import com.facebook.presto.common.block.BlockBuilder;
1918
import com.facebook.presto.common.block.IntArrayBlock;
2019
import com.facebook.presto.common.type.ArrayType;
2120
import com.facebook.presto.common.type.BigintType;
@@ -25,7 +24,9 @@
2524
import com.facebook.presto.common.type.DecimalType;
2625
import com.facebook.presto.common.type.DoubleType;
2726
import com.facebook.presto.common.type.IntegerType;
27+
import com.facebook.presto.common.type.MapType;
2828
import com.facebook.presto.common.type.RealType;
29+
import com.facebook.presto.common.type.RowType;
2930
import com.facebook.presto.common.type.SmallintType;
3031
import com.facebook.presto.common.type.TimeType;
3132
import com.facebook.presto.common.type.TimestampType;
@@ -57,7 +58,8 @@
5758
import org.apache.arrow.vector.VectorSchemaRoot;
5859
import org.apache.arrow.vector.complex.BaseRepeatedValueVector;
5960
import org.apache.arrow.vector.complex.ListVector;
60-
import org.apache.arrow.vector.complex.impl.UnionListWriter;
61+
import org.apache.arrow.vector.complex.MapVector;
62+
import org.apache.arrow.vector.complex.StructVector;
6163
import org.apache.arrow.vector.types.Types;
6264
import org.apache.arrow.vector.types.pojo.ArrowType;
6365
import org.apache.arrow.vector.types.pojo.Field;
@@ -72,12 +74,17 @@
7274
import java.util.ArrayList;
7375
import java.util.List;
7476
import java.util.Map;
77+
import java.util.concurrent.atomic.AtomicInteger;
7578
import java.util.stream.Collectors;
7679

7780
import static com.facebook.presto.common.type.Decimals.decodeUnscaledValue;
81+
import static com.facebook.presto.common.type.StandardTypes.ARRAY;
82+
import static com.facebook.presto.common.type.StandardTypes.MAP;
83+
import static com.facebook.presto.common.type.StandardTypes.ROW;
7884
import static com.facebook.presto.common.type.Varchars.isVarcharType;
7985
import static java.lang.Float.intBitsToFloat;
8086
import static java.lang.Math.toIntExact;
87+
import static java.lang.String.format;
8188
import static java.util.Collections.unmodifiableList;
8289
import static java.util.Objects.requireNonNull;
8390

@@ -163,7 +170,12 @@ private static void writeValueFromBlock(ArrowShimWriter writer, int row, Type ty
163170
writer.writeBoolean(row, type.getBoolean(block, position));
164171
}
165172
else if (javaType == long.class) {
166-
writer.writeLong(row, type.getLong(block, position));
173+
if (block instanceof IntArrayBlock) {
174+
writer.writeLong(row, block.toLong(position));
175+
}
176+
else {
177+
writer.writeLong(row, type.getLong(block, position));
178+
}
167179
}
168180
else if (javaType == double.class) {
169181
writer.writeDouble(row, type.getDouble(block, position));
@@ -176,7 +188,6 @@ else if (javaType == Block.class) {
176188
writer.writeBlock(row, block, position, type);
177189
}
178190
else {
179-
// TODO handle Object cursor.getObject(column)
180191
throw new UnsupportedOperationException();
181192
}
182193
}
@@ -201,11 +212,30 @@ private static Field prestoToArrowField(ColumnMetadata column)
201212
{
202213
Field field;
203214
Map<String, String> metadata = column.getProperties().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Object::toString));
204-
if (column.getType() instanceof ArrayType) {
215+
if (column.getType().getTypeSignature().getBase().equals(ARRAY)) {
205216
ArrayType arrayType = (ArrayType) column.getType();
206217
Field childField = prestoToArrowField(ColumnMetadata.builder().setName(BaseRepeatedValueVector.DATA_VECTOR_NAME).setType(arrayType.getElementType()).build());
207218
field = new Field(column.getName(), new FieldType(column.isNullable(), ArrowType.List.INSTANCE, null, metadata), ImmutableList.of(childField));
208219
}
220+
else if (column.getType().getTypeSignature().getBase().equals(MAP)) {
221+
MapType mapType = (MapType) column.getType();
222+
// NOTE: Arrow key type must be non-nullable
223+
Field keyField = prestoToArrowField(ColumnMetadata.builder().setName(MapVector.KEY_NAME).setType(mapType.getKeyType()).setNullable(false).build());
224+
Field valueField = prestoToArrowField(ColumnMetadata.builder().setName(MapVector.VALUE_NAME).setType(mapType.getValueType()).build());
225+
Field entriesField = new Field(MapVector.DATA_VECTOR_NAME, FieldType.notNullable(ArrowType.Struct.INSTANCE), ImmutableList.of(keyField, valueField));
226+
field = new Field(column.getName(), new FieldType(column.isNullable(), new ArrowType.Map(false), null, metadata), ImmutableList.of(entriesField));
227+
}
228+
else if (column.getType().getTypeSignature().getBase().equals(ROW)) {
229+
RowType rowType = (RowType) column.getType();
230+
List<RowType.Field> rowFields = rowType.getFields();
231+
232+
AtomicInteger childCount = new AtomicInteger();
233+
List<Field> childFields = rowFields.stream().map(f -> prestoToArrowField(
234+
ColumnMetadata.builder().setName(f.getName().orElse(format("$child%s$", childCount.incrementAndGet()))).setType(f.getType()).build()))
235+
.collect(Collectors.toList());
236+
237+
field = new Field(column.getName(), new FieldType(column.isNullable(), ArrowType.Struct.INSTANCE, null, metadata), childFields);
238+
}
209239
else {
210240
ArrowType arrowType = prestoToArrowType(column.getType());
211241
field = new Field(column.getName(), new FieldType(column.isNullable(), arrowType, null, metadata), ImmutableList.of());
@@ -302,6 +332,10 @@ private static ArrowShimWriter createArrowWriter(FieldVector vector)
302332
return new ArrowShimTimeStampWriter((TimeStampVector) vector);
303333
case LIST:
304334
return new ArrowShimListWriter((ListVector) vector);
335+
case MAP:
336+
return new ArrowShimMapWriter((MapVector) vector);
337+
case STRUCT:
338+
return new ArrowShimStructWriter((StructVector) vector);
305339
default:
306340
throw new UnsupportedOperationException("Unsupported Arrow type: " + vector.getMinorType().name());
307341
}
@@ -669,37 +703,94 @@ public void writeBlock(int index, Block block, int position, Type type)
669703
{
670704
if (type instanceof ArrayType) {
671705
ArrayType arrayType = ((ArrayType) type);
672-
Object value = type.getObject(block, position);
673-
if (value instanceof List<?>) {
674-
List<?> valuesList = (List<?>) value;
675-
//UnionListWriter listWriter = vector.getWriter();
676-
//listWriter.setPosition(index);
677-
//listWriter.startList();
678-
vector.startNewValue(index);
679-
680-
for (Object element : valuesList) {
681-
int stop = 20;
682-
}
683-
684-
((List<?>) value).forEach(element ->
685-
type.getTypeParameters().get(0)
686-
//appendTo(, element, builder)
687-
);
688-
689-
vector.endValue(index, valuesList.size());
690-
//listWriter.endList();
691-
return;
706+
Block elementBlock = arrayType.getObject(block, position);
707+
int dataIndex = vector.startNewValue(index);
708+
for (int i = 0; i < elementBlock.getPositionCount(); ++i) {
709+
writeValueFromBlock(childWriter, dataIndex + i, arrayType.getElementType(), elementBlock, i);
710+
}
711+
vector.endValue(index, elementBlock.getPositionCount());
712+
}
713+
else {
714+
throw new UnsupportedOperationException("Unknown type for writeBlock: " + type);
715+
}
716+
}
717+
}
718+
719+
private static class ArrowShimMapWriter
720+
extends ArrowShimWriter
721+
{
722+
private final MapVector vector;
723+
private final StructVector structVector;
724+
private final ArrowShimWriter keyWriter;
725+
private final ArrowShimWriter valueWriter;
726+
727+
public ArrowShimMapWriter(MapVector vector)
728+
{
729+
this.vector = vector;
730+
this.structVector = (StructVector) vector.getDataVector();
731+
this.keyWriter = createArrowWriter((FieldVector) structVector.getChildByOrdinal(0));
732+
this.valueWriter = createArrowWriter((FieldVector) structVector.getChildByOrdinal(1));
733+
}
734+
735+
@Override
736+
public void writeNull(int index)
737+
{
738+
vector.setNull(index);
739+
}
740+
741+
@Override
742+
public void writeBlock(int index, Block block, int position, Type type)
743+
{
744+
if (type instanceof MapType) {
745+
MapType mapType = ((MapType) type);
746+
Block singleMapBlock = mapType.getObject(block, position);
747+
int dataIndex = vector.startNewValue(index);
748+
int numPairs = singleMapBlock.getPositionCount() / 2;
749+
for (int i = 0; i < numPairs; ++i) {
750+
writeValueFromBlock(keyWriter, dataIndex + i, mapType.getKeyType(), singleMapBlock, i * 2);
751+
writeValueFromBlock(valueWriter, dataIndex + i, mapType.getValueType(), singleMapBlock, (i * 2) + 1);
752+
structVector.setIndexDefined(dataIndex + i);
692753
}
693-
else if (value instanceof Block) {
694-
Block elementBlock = (Block) value;
695-
//IntArrayBlock intArrayBlock = (IntArrayBlock) value;
696-
//intArrayBlock.getLong();
697-
writeValueFromBlock(childWriter, 0, ((ArrayType) type).getElementType(), elementBlock, 0);
754+
vector.endValue(index, numPairs);
755+
}
756+
else {
757+
throw new UnsupportedOperationException("Unknown type for writeBlock: " + type);
758+
}
759+
}
760+
}
761+
762+
private static class ArrowShimStructWriter
763+
extends ArrowShimWriter
764+
{
765+
private final StructVector vector;
766+
private final List<ArrowShimWriter> childWriters;
767+
768+
public ArrowShimStructWriter(StructVector vector)
769+
{
770+
this.vector = vector;
771+
this.childWriters = vector.getChildrenFromFields().stream().map(ArrowBatchSource::createArrowWriter).collect(Collectors.toList());
772+
}
773+
774+
@Override
775+
public void writeNull(int index)
776+
{
777+
vector.setNull(index);
778+
}
779+
780+
@Override
781+
public void writeBlock(int index, Block block, int position, Type type)
782+
{
783+
if (type instanceof RowType) {
784+
RowType rowType = ((RowType) type);
785+
Block singleRowBlock = rowType.getObject(block, position);
786+
for (int i = 0; i < childWriters.size(); ++i) {
787+
writeValueFromBlock(childWriters.get(i), index, rowType.getTypeParameters().get(i), singleRowBlock, i);
698788
}
699-
int stop = 10;
789+
vector.setIndexDefined(index);
790+
}
791+
else {
792+
throw new UnsupportedOperationException("Unknown type for writeBlock: " + type);
700793
}
701-
//vector.set(index, intBitsToFloat(toIntExact(value)));
702-
int stop = 10;
703794
}
704795
}
705796
}

0 commit comments

Comments
 (0)