Commit 7e0e0db ("address")

1 parent 5576a85

File tree: 3 files changed (+258, -1 lines changed)

src/arrow_parquet/field_ids.rs

Lines changed: 101 additions & 0 deletions
@@ -1,5 +1,6 @@
 use std::{collections::HashMap, fmt::Display, str::FromStr};

+use arrow_schema::{DataType, Schema};
 use serde::{Deserialize, Serialize};

 #[derive(Debug, Clone, Default)]
@@ -10,6 +11,7 @@ pub(crate) enum FieldIds {
     Explicit(FieldIdMapping),
 }

+/// Implements parsing for the field_ids option in COPY .. TO statements
 impl FromStr for FieldIds {
     type Err = String;

@@ -50,6 +52,7 @@ pub(crate) struct FieldIdMapping {
 }

 impl FieldIdMapping {
+    /// Returns the field ID, if any, from `FieldIdMapping` for the given field path.
     pub(crate) fn field_id(&self, field_path: &[String]) -> Option<i32> {
         if field_path.is_empty() {
             panic!("Field path is empty");
@@ -65,6 +68,92 @@ impl FieldIdMapping {
             None => None,
         }
     }
+
+    /// Validates that every field name in the `FieldIdMapping` exists in the provided Arrow schema
+    fn validate_against_schema(&self, arrow_schema: &Schema) -> Result<(), String> {
+        // Build a map from field name to &Field for quick lookups
+        let mut arrow_field_map = HashMap::new();
+        for field in arrow_schema.fields() {
+            arrow_field_map.insert(field.name().clone(), field);
+        }
+
+        // Check every field name in the JSON mapping
+        for (field_name, mapping_item) in &self.fields {
+            if field_name == "__root_field_id" {
+                // Skip the root field, as it doesn't exist in the Arrow schema
+                continue;
+            }
+
+            // Ensure the field exists in the Arrow schema
+            let arrow_field = match arrow_field_map.get(field_name) {
+                Some(f) => f,
+                None => {
+                    return Err(format!(
+                        "Field '{}' in the mapping does not exist in the Arrow schema.\nAvailable fields: {:?}",
+                        field_name,
+                        arrow_schema
+                            .fields()
+                            .iter()
+                            .map(|f| f.name())
+                            .collect::<Vec<_>>()
+                    ));
+                }
+            };
+
+            match mapping_item {
+                // If the JSON item is an integer field ID, we're done
+                FieldIdMappingItem::FieldId(_id) => {}
+
+                // If the JSON item is a nested mapping, we need to validate it
+                FieldIdMappingItem::FieldIdMapping(mapping) => match arrow_field.data_type() {
+                    DataType::Struct(subfields) => {
+                        // We expect the JSON keys to include something like:
+                        //   "__root_field_id": <int>,
+                        //   "field_name": <int or nested mapping>
+                        let subschema = Schema::new(subfields.clone());
+                        mapping.validate_against_schema(&subschema)?;
+                    }
+                    DataType::List(element_field) => {
+                        // We expect the JSON keys to include something like:
+                        //   "__root_field_id": <int>,
+                        //   "element": <int or nested mapping>
+                        let element_schema = Schema::new(vec![element_field.clone()]);
+                        mapping.validate_against_schema(&element_schema)?;
+                    }
+                    DataType::Map(entry_field, _) => {
+                        // We expect the JSON keys to include something like:
+                        //   "__root_field_id": <int>,
+                        //   "key": <int or nested mapping>,
+                        //   "val": <int or nested mapping>
+                        match entry_field.data_type() {
+                            DataType::Struct(entry_fields) => {
+                                let entry_schema = Schema::new(entry_fields.clone());
+                                mapping.validate_against_schema(&entry_schema)?;
+                            }
+                            other_type => {
+                                panic!(
+                                    "Map entry field should be a struct, but got '{:?}' for field '{}'",
+                                    other_type, field_name
+                                );
+                            }
+                        }
+                    }
+                    other_type => {
+                        panic!(
+                            "Unexpected data type '{:?}' for field '{}'",
+                            other_type, field_name
+                        );
+                    }
+                },
+            }
+        }
+
+        Ok(())
+    }
 }

 pub(crate) fn field_id_mapping_from_json_string(
@@ -76,3 +167,15 @@ pub(crate) fn field_id_mapping_from_json_string(
 fn field_id_mapping_to_json_string(field_id_mapping: &FieldIdMapping) -> String {
     serde_json::to_string(field_id_mapping).unwrap()
 }
+
+/// Validate that every field name in the `FieldIdMapping` exists in the provided Arrow schema
+/// when the `FieldIds` are explicitly specified.
+pub(crate) fn validate_field_ids(field_ids: FieldIds, arrow_schema: &Schema) -> Result<(), String> {
+    match field_ids {
+        FieldIds::None => Ok(()),
+        FieldIds::Auto => Ok(()),
+        FieldIds::Explicit(field_id_mapping) => {
+            field_id_mapping.validate_against_schema(arrow_schema)
+        }
+    }
+}
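A minimal sketch of how the new entry point could be exercised on its own, assuming `FieldIds::from_str` accepts the JSON mapping form that the COPY option uses in the tests below (the schema and mapping here are made up for illustration):

    use std::str::FromStr;
    use arrow_schema::{DataType, Field, Fields, Schema};

    fn check_mapping() -> Result<(), String> {
        // Hypothetical schema: an int column "a" and a struct column "c".
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "c",
                DataType::Struct(Fields::from(vec![Field::new("id", DataType::Int32, true)])),
                true,
            ),
        ]);

        // "__root_field_id" carries the struct column's own field ID.
        let field_ids =
            FieldIds::from_str(r#"{"a": 10, "c": {"__root_field_id": 100, "id": 200}}"#)?;

        // Ok(()) for None/Auto; for Explicit, every mapped name must exist in the schema.
        validate_field_ids(field_ids, &schema)
    }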

src/arrow_parquet/parquet_writer.rs

Lines changed: 5 additions & 1 deletion
@@ -12,6 +12,7 @@ use pgrx::{heap_tuple::PgHeapTuple, AllocatedByRust, PgTupleDesc};
 use crate::{
     arrow_parquet::{
         compression::PgParquetCompressionWithLevel,
+        field_ids::validate_field_ids,
         pg_to_arrow::context::collect_pg_to_arrow_attribute_contexts,
         schema_parser::{
             parquet_schema_string_from_attributes, parse_arrow_schema_from_attributes,
@@ -62,7 +63,10 @@ impl ParquetWriterContext {
             parquet_schema_string_from_attributes(&attributes, field_ids.clone())
         );

-        let schema = parse_arrow_schema_from_attributes(&attributes, field_ids);
+        let schema = parse_arrow_schema_from_attributes(&attributes, field_ids.clone());
+
+        validate_field_ids(field_ids, &schema).unwrap_or_else(|e| panic!("{e}"));
+
         let schema = Arc::new(schema);

         let writer_props = Self::writer_props(tupledesc, options);
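On failure, the error string built in `validate_against_schema` propagates through the `unwrap_or_else(|e| panic!("{e}"))` call above; pgrx turns the panic into a Postgres error, aborting the COPY before any writer state is created. A condensed sketch of that failure path, where `bad_field_ids` is a hypothetical mapping naming a column absent from the schema:

    // Hypothetical: `bad_field_ids` maps a column name that `schema` lacks.
    let err = validate_field_ids(bad_field_ids, &schema)
        .expect_err("mapping references a column missing from the Arrow schema");
    assert!(err.contains("Available fields"));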

src/pgrx_tests/copy_options.rs

Lines changed: 152 additions & 0 deletions
@@ -1002,4 +1002,156 @@ mod tests {
             ]
         );
     }
+
+    #[pg_test]
+    fn test_explicit_field_ids_with_missing_field_ids() {
+        let setup_commands = "create type dog as (id int, name text);
+                              create type person as (id int, dog dog, dogs dog[]);
+                              create table test_table(a int, b text, c person, d person[]);";
+        Spi::run(setup_commands).unwrap();
+
+        let explicit_field_ids = "{\"a\": 10,
+                                   \"c\": {
+                                       \"__root_field_id\": 100,
+                                       \"id\": 200,
+                                       \"dog\": {
+                                           \"id\": 400,
+                                           \"name\": 500
+                                       },
+                                       \"dogs\": {
+                                           \"__root_field_id\": 600,
+                                           \"element\": {
+                                               \"__root_field_id\": 700,
+                                               \"id\": 800,
+                                               \"name\": 900
+                                           }
+                                       }
+                                   },
+                                   \"d\": {
+                                       \"__root_field_id\": 1000,
+                                       \"element\": {
+                                           \"dog\": {
+                                               \"__root_field_id\": 1300,
+                                               \"id\": 1400,
+                                               \"name\": 1500
+                                           },
+                                           \"dogs\": {
+                                               \"__root_field_id\": 1600,
+                                               \"element\": {
+                                                   \"__root_field_id\": 1700,
+                                                   \"id\": 1800
+                                               }
+                                           }
+                                       }
+                                   }
+                                  }";
+
+        let copy_to_parquet = format!(
+            "copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
+        );
+        Spi::run(&copy_to_parquet).unwrap();
+
+        let fields = Spi::connect(|client| {
+            let parquet_schema_command = format!(
+                "select field_id, name from parquet.schema('{}') order by 1,2;",
+                LOCAL_TEST_FILE_PATH
+            );
+
+            let tup_table = client.select(&parquet_schema_command, None, &[]).unwrap();
+            let mut results = Vec::new();
+
+            for row in tup_table {
+                let field_id = row["field_id"].value::<i64>().unwrap();
+                let name = row["name"].value::<String>().unwrap().unwrap();
+
+                results.push((field_id, name));
+            }
+
+            results
+        });
+
+        assert_eq!(
+            fields,
+            vec![
+                (Some(10), "a".into()),
+                (Some(100), "c".into()),
+                (Some(200), "id".into()),
+                (Some(400), "id".into()),
+                (Some(500), "name".into()),
+                (Some(600), "dogs".into()),
+                (Some(700), "element".into()),
+                (Some(800), "id".into()),
+                (Some(900), "name".into()),
+                (Some(1000), "d".into()),
+                (Some(1300), "dog".into()),
+                (Some(1400), "id".into()),
+                (Some(1500), "name".into()),
+                (Some(1600), "dogs".into()),
+                (Some(1700), "element".into()),
+                (Some(1800), "id".into()),
+                (None, "arrow_schema".into()),
+                (None, "b".into()),
+                (None, "dog".into()),
+                (None, "element".into()),
+                (None, "id".into()),
+                (None, "list".into()),
+                (None, "list".into()),
+                (None, "list".into()),
+                (None, "name".into()),
+            ]
+        );
+    }
+
+    #[pg_test]
+    #[should_panic(expected = "Available fields: [\"a\", \"b\", \"c\", \"d\"]")]
+    fn test_explicit_field_ids_invalid_json() {
+        let setup_commands = "create type dog as (id int, name text);
+                              create type person as (id int, dog dog, dogs dog[]);
+                              create table test_table(a int, b text, c person, d person[]);";
+        Spi::run(setup_commands).unwrap();
+
+        let explicit_field_ids = "{\"aa\": 10, \"b\": 12}";
+
+        let copy_to_parquet = format!(
+            "copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
+        );
+        Spi::run(&copy_to_parquet).unwrap();
+    }
+
+    #[pg_test]
+    #[should_panic(expected = "Available fields: [\"id\", \"name\"]")]
+    fn test_explicit_field_ids_another_invalid_json() {
+        let setup_commands = "create type dog as (id int, name text);
+                              create type person as (id int, dog dog, dogs dog[]);
+                              create table test_table(a int, b text, c person, d person[]);";
+        Spi::run(setup_commands).unwrap();
+
+        let explicit_field_ids = "{\"a\": 10,
+                                   \"d\": {
+                                       \"__root_field_id\": 1000,
+                                       \"element\": {
+                                           \"__root_field_id\": 1100,
+                                           \"id\": 1200,
+                                           \"dog\": {
+                                               \"__root_field_id\": 1300,
+                                               \"id\": 1400,
+                                               \"name\": 1500
+                                           },
+                                           \"dogs\": {
+                                               \"__root_field_id\": 1600,
+                                               \"element\": {
+                                                   \"__root_field_id\": 1700,
+                                                   \"iddd\": 1800,
+                                                   \"name\": 1900
+                                               }
+                                           }
+                                       }
+                                   }
+                                  }";
+
+        let copy_to_parquet = format!(
+            "copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
+        );
+        Spi::run(&copy_to_parquet).unwrap();
+    }
 }
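The escaped JSON literals in these tests could equivalently be built with serde_json's `json!` macro (serde_json is already used in `field_ids.rs`), which sidesteps the quote escaping. A sketch, not part of the commit:

    use serde_json::json;

    // Same mapping shape as the tests, without backslash-escaped quotes.
    let explicit_field_ids = json!({
        "a": 10,
        "c": { "__root_field_id": 100, "id": 200 }
    })
    .to_string();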
