Skip to content

Commit ed4907a

Browse files
committed
Match fields by position via option
We add an option for `COPY FROM` called `match_by_position` which matches Parquet file fields to PostgreSQL table columns by their position in the schema rather than by their names. By default, the option is `false`. The option is useful when field names differ between the Parquet file and the table, but their order aligns. Closes #39.
1 parent fbaeadb commit ed4907a

File tree

6 files changed

+96
-23
lines changed

6 files changed

+96
-23
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ Alternatively, you can use the following environment variables when starting pos
193193

194194
`pg_parquet` supports the following options in the `COPY FROM` command:
195195
- `format parquet`: you need to specify this option to read or write Parquet files which do not end with `.parquet[.<compression>]` extension,
196+
- `match_by_position <bool>`: matches Parquet file fields to PostgreSQL table columns by their position in the schema rather than by their names. By default, the option is `false`. The option is useful when field names differ between the Parquet file and the table, but their order aligns.
196197

197198
## Configuration
198199
There is currently only one GUC parameter to enable/disable the `pg_parquet`:

src/arrow_parquet/parquet_reader.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@ pub(crate) struct ParquetReaderContext {
3838
parquet_reader: ParquetRecordBatchStream<ParquetObjectReader>,
3939
attribute_contexts: Vec<ArrowToPgAttributeContext>,
4040
binary_out_funcs: Vec<PgBox<FmgrInfo>>,
41+
match_by_position: bool,
4142
}
4243

4344
impl ParquetReaderContext {
44-
pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self {
45+
pub(crate) fn new(uri: Url, match_by_position: bool, tupledesc: &PgTupleDesc) -> Self {
4546
// Postgis and Map contexts are used throughout reading the parquet file.
4647
// We need to reset them to avoid reading the stale data. (e.g. extension could be dropped)
4748
reset_postgis_context();
@@ -69,6 +70,7 @@ impl ParquetReaderContext {
6970
parquet_file_schema.clone(),
7071
tupledesc_schema.clone(),
7172
&attributes,
73+
match_by_position,
7274
);
7375

7476
let attribute_contexts = collect_arrow_to_pg_attribute_contexts(
@@ -85,6 +87,7 @@ impl ParquetReaderContext {
8587
attribute_contexts,
8688
parquet_reader,
8789
binary_out_funcs,
90+
match_by_position,
8891
started: false,
8992
finished: false,
9093
}
@@ -116,15 +119,23 @@ impl ParquetReaderContext {
116119
fn record_batch_to_tuple_datums(
117120
record_batch: RecordBatch,
118121
attribute_contexts: &[ArrowToPgAttributeContext],
122+
match_by_position: bool,
119123
) -> Vec<Option<Datum>> {
120124
let mut datums = vec![];
121125

122-
for attribute_context in attribute_contexts {
126+
for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() {
123127
let name = attribute_context.name();
124128

125-
let column_array = record_batch
126-
.column_by_name(name)
127-
.unwrap_or_else(|| panic!("column {} not found", name));
129+
let column_array = if match_by_position {
130+
record_batch
131+
.columns()
132+
.get(attribute_idx)
133+
.unwrap_or_else(|| panic!("column {} not found", name))
134+
} else {
135+
record_batch
136+
.column_by_name(name)
137+
.unwrap_or_else(|| panic!("column {} not found", name))
138+
};
128139

129140
let datum = if attribute_context.needs_cast() {
130141
// should fail instead of returning None if the cast fails at runtime
@@ -181,8 +192,11 @@ impl ParquetReaderContext {
181192
self.buffer.extend_from_slice(&attnum_len_bytes);
182193

183194
// convert the columnar arrays in record batch to tuple datums
184-
let tuple_datums =
185-
Self::record_batch_to_tuple_datums(record_batch, &self.attribute_contexts);
195+
let tuple_datums = Self::record_batch_to_tuple_datums(
196+
record_batch,
197+
&self.attribute_contexts,
198+
self.match_by_position,
199+
);
186200

187201
// write the tuple datums to the ParquetReader's internal buffer in PG copy format
188202
for (datum, out_func) in tuple_datums.into_iter().zip(self.binary_out_funcs.iter())

src/arrow_parquet/schema_parser.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -349,21 +349,38 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
349349
file_schema: Arc<Schema>,
350350
tupledesc_schema: Arc<Schema>,
351351
attributes: &[FormData_pg_attribute],
352+
match_by_position: bool,
352353
) -> Vec<Option<DataType>> {
353354
let mut cast_to_types = Vec::new();
354355

356+
if match_by_position && tupledesc_schema.fields().len() != file_schema.fields().len() {
357+
panic!(
358+
"column count mismatch between table and parquet file. \
359+
parquet file has {} columns, but table has {} columns",
360+
file_schema.fields().len(),
361+
tupledesc_schema.fields().len()
362+
);
363+
}
364+
355365
for (tupledesc_schema_field, attribute) in
356366
tupledesc_schema.fields().iter().zip(attributes.iter())
357367
{
358368
let field_name = tupledesc_schema_field.name();
359369

360-
let file_schema_field = file_schema.column_with_name(field_name);
370+
let file_schema_field = if match_by_position {
371+
file_schema.field(attribute.attnum as usize - 1)
372+
} else {
373+
let file_schema_field = file_schema.column_with_name(field_name);
361374

362-
if file_schema_field.is_none() {
363-
panic!("column \"{}\" is not found in parquet file", field_name);
364-
}
375+
if file_schema_field.is_none() {
376+
panic!("column \"{}\" is not found in parquet file", field_name);
377+
}
378+
379+
let (_, file_schema_field) = file_schema_field.unwrap();
380+
381+
file_schema_field
382+
};
365383

366-
let (_, file_schema_field) = file_schema_field.unwrap();
367384
let file_schema_field = Arc::new(file_schema_field.clone());
368385

369386
let from_type = file_schema_field.data_type();
@@ -378,7 +395,7 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
378395
if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) {
379396
panic!(
380397
"type mismatch for column \"{}\" between table and parquet file.\n\n\
381-
table has \"{}\"\n\nparquet file has \"{}\"",
398+
table has \"{}\"\n\nparquet file has \"{}\"",
382399
field_name, to_type, from_type
383400
);
384401
}

src/parquet_copy_hook/copy_from.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use crate::{
2020
};
2121

2222
use super::copy_utils::{
23-
copy_stmt_attribute_list, copy_stmt_create_namespace_item, copy_stmt_create_parse_state,
24-
create_filtered_tupledesc_for_relation,
23+
copy_from_stmt_match_by_position, copy_stmt_attribute_list, copy_stmt_create_namespace_item,
24+
copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation,
2525
};
2626

2727
// stack to store parquet reader contexts for COPY FROM.
@@ -131,9 +131,11 @@ pub(crate) fn execute_copy_from(
131131

132132
let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation);
133133

134+
let match_by_position = copy_from_stmt_match_by_position(p_stmt);
135+
134136
unsafe {
135137
// parquet reader context is used throughout the COPY FROM operation.
136-
let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc);
138+
let parquet_reader_context = ParquetReaderContext::new(uri, match_by_position, &tupledesc);
137139
push_parquet_reader_context(parquet_reader_context);
138140

139141
// makes sure to set binary format

src/parquet_copy_hook/copy_utils.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ use std::{ffi::CStr, str::FromStr};
33
use pgrx::{
44
is_a,
55
pg_sys::{
6-
addRangeTableEntryForRelation, defGetInt32, defGetInt64, defGetString, get_namespace_name,
7-
get_rel_namespace, makeDefElem, makeString, make_parsestate, quote_qualified_identifier,
8-
AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, DefElem, List, NoLock, Node,
9-
NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, PlannedStmt, QueryEnvironment,
10-
RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, TupleDescInitEntry,
6+
addRangeTableEntryForRelation, defGetBoolean, defGetInt32, defGetInt64, defGetString,
7+
get_namespace_name, get_rel_namespace, makeDefElem, makeString, make_parsestate,
8+
quote_qualified_identifier, AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc,
9+
DefElem, List, NoLock, Node, NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState,
10+
PlannedStmt, QueryEnvironment, RangeVar, RangeVarGetRelidExtended, RowExclusiveLock,
11+
TupleDescInitEntry,
1112
},
1213
PgBox, PgList, PgRelation, PgTupleDesc,
1314
};
@@ -109,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox<PlannedStmt>, uri: &Url) {
109110
}
110111

111112
pub(crate) fn validate_copy_from_options(p_stmt: &PgBox<PlannedStmt>) {
112-
validate_copy_option_names(p_stmt, &["format", "freeze"]);
113+
validate_copy_option_names(p_stmt, &["format", "match_by_position", "freeze"]);
113114

114115
let format_option = copy_stmt_get_option(p_stmt, "format");
115116

@@ -253,6 +254,16 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox<PlannedStmt>) ->
253254
new_copy_options
254255
}
255256

257+
pub(crate) fn copy_from_stmt_match_by_position(p_stmt: &PgBox<PlannedStmt>) -> bool {
258+
let match_by_position_option = copy_stmt_get_option(p_stmt, "match_by_position");
259+
260+
if match_by_position_option.is_null() {
261+
false
262+
} else {
263+
unsafe { defGetBoolean(match_by_position_option.as_ptr()) }
264+
}
265+
}
266+
256267
pub(crate) fn copy_stmt_get_option(
257268
p_stmt: &PgBox<PlannedStmt>,
258269
option_name: &str,

src/pgrx_tests/copy_from_coerce.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,7 @@ mod tests {
966966
}
967967

968968
#[pg_test]
969-
fn test_table_with_different_field_position() {
969+
fn test_table_with_different_position_match_by_name() {
970970
let copy_to = format!(
971971
"COPY (SELECT 1 as x, 'hello' as y) TO '{}'",
972972
LOCAL_TEST_FILE_PATH
@@ -983,6 +983,34 @@ mod tests {
983983
assert_eq!(result, (Some("hello"), Some(1)));
984984
}
985985

986+
#[pg_test]
987+
fn test_table_with_different_name_match_by_position() {
988+
let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
989+
Spi::run(copy_to).unwrap();
990+
991+
let create_table = "CREATE TABLE test_table (x bigint, y varchar)";
992+
Spi::run(create_table).unwrap();
993+
994+
let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by_position true)";
995+
Spi::run(copy_from).unwrap();
996+
997+
let result = Spi::get_two::<i64, &str>("SELECT x, y FROM test_table LIMIT 1").unwrap();
998+
assert_eq!(result, (Some(1), Some("hello")));
999+
}
1000+
1001+
#[pg_test]
1002+
#[should_panic(expected = "column count mismatch between table and parquet file")]
1003+
fn test_table_with_different_name_match_by_position_fail() {
1004+
let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
1005+
Spi::run(copy_to).unwrap();
1006+
1007+
let create_table = "CREATE TABLE test_table (x bigint, y varchar, z int)";
1008+
Spi::run(create_table).unwrap();
1009+
1010+
let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by_position true)";
1011+
Spi::run(copy_from).unwrap();
1012+
}
1013+
9861014
#[pg_test]
9871015
#[should_panic(expected = "column \"name\" is not found in parquet file")]
9881016
fn test_missing_column_in_parquet() {

0 commit comments

Comments
 (0)