
Commit 915fb79

address
1 parent ffe9c4e

3 files changed: +260 -1 lines changed

src/arrow_parquet/field_ids.rs

Lines changed: 103 additions & 0 deletions
@@ -1,5 +1,6 @@
 use std::{collections::HashMap, fmt::Display, str::FromStr};
 
+use arrow_schema::{DataType, Schema};
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Clone, Default)]
@@ -10,6 +11,7 @@ pub(crate) enum FieldIds {
     Explicit(FieldIdMapping),
 }
 
+/// Implements parsing for the field_ids option in COPY .. TO statements
 impl FromStr for FieldIds {
     type Err = String;
 
@@ -50,6 +52,7 @@ pub(crate) struct FieldIdMapping {
 }
 
 impl FieldIdMapping {
+    /// Returns the field ID, if any, from `FieldIdMapping` for the given field path.
     pub(crate) fn field_id(&self, field_path: &[String]) -> Option<i32> {
         if field_path.is_empty() {
             panic!("Field path is empty");
@@ -65,6 +68,94 @@ impl FieldIdMapping {
             None => None,
         }
     }
+
+    /// Validates that every field name in the `FieldIdMapping` exists in the provided Arrow schema.
+    fn validate_against_schema(&self, arrow_schema: &Schema) -> Result<(), String> {
+        // Build a map from field name to &Field for quick lookups
+        let mut arrow_field_map = HashMap::new();
+        for field in arrow_schema.fields() {
+            arrow_field_map.insert(field.name().clone(), field);
+        }
+
+        // Check every field name in the JSON mapping
+        for (field_name, mapping_item) in &self.fields {
+            if field_name == "__root_field_id" {
+                // Skip the root field id, as it doesn't name a field in the Arrow schema
+                continue;
+            }
+
+            // Ensure the field exists in the Arrow schema
+            let arrow_field = match arrow_field_map.get(field_name) {
+                Some(f) => f,
+                None => {
+                    return Err(format!(
+                        "Field '{}' in the mapping does not exist in the Arrow schema.\nAvailable fields: {:?}",
+                        field_name,
+                        arrow_schema
+                            .fields()
+                            .iter()
+                            .map(|f| f.name())
+                            .collect::<Vec<_>>()
+                    ));
+                }
+            };
+
+            match mapping_item {
+                // If the JSON item is an integer field ID, we're done
+                FieldIdMappingItem::FieldId(_id) => {}
+
+                // If the JSON item is a nested mapping, we need to validate it recursively
+                FieldIdMappingItem::FieldIdMapping(mapping) => match arrow_field.data_type() {
+                    DataType::Struct(subfields) => {
+                        // We expect the JSON keys to include something like:
+                        //   "__root_field_id": <int>,
+                        //   "field_name": <int or nested mapping>
+                        let subschema = Schema::new(subfields.clone());
+                        mapping.validate_against_schema(&subschema)?;
+                    }
+                    DataType::List(element_field) => {
+                        // We expect the JSON keys to include something like:
+                        //   "__root_field_id": <int>,
+                        //   "element": <int or nested mapping>
+                        let element_schema = Schema::new(vec![element_field.clone()]);
+                        mapping.validate_against_schema(&element_schema)?;
+                    }
+                    DataType::Map(entry_field, _) => {
+                        // We expect the JSON keys to include something like:
+                        //   "__root_field_id": <int>,
+                        //   "key": <int or nested mapping>,
+                        //   "val": <int or nested mapping>
+                        match entry_field.data_type() {
+                            DataType::Struct(entry_fields) => {
+                                let entry_schema = Schema::new(entry_fields.clone());
+                                mapping.validate_against_schema(&entry_schema)?;
+                            }
+                            other_type => {
+                                panic!(
+                                    "Map entry field should be a struct, but got '{:?}' for field '{}'",
+                                    other_type, field_name
+                                );
+                            }
+                        }
+                    }
+                    other_type => {
+                        panic!(
+                            "Unexpected data type '{:?}' for field '{}'",
+                            other_type, field_name
+                        );
+                    }
+                },
+            }
+        }
+
+        Ok(())
+    }
 }
 
 pub(crate) fn field_id_mapping_from_json_string(
@@ -76,3 +167,15 @@ pub(crate) fn field_id_mapping_from_json_string(
 fn field_id_mapping_to_json_string(field_id_mapping: &FieldIdMapping) -> String {
     serde_json::to_string(field_id_mapping).unwrap()
 }
+
+/// Validates that every field name in the `FieldIdMapping` exists in the provided Arrow schema
+/// when the `FieldIds` are explicitly specified.
+pub(crate) fn validate_field_ids(field_ids: FieldIds, arrow_schema: &Schema) -> Result<(), String> {
+    match field_ids {
+        FieldIds::None => Ok(()),
+        FieldIds::Auto => Ok(()),
+        FieldIds::Explicit(field_id_mapping) => {
+            field_id_mapping.validate_against_schema(arrow_schema)
+        }
+    }
+}
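
For orientation, here is a minimal sketch (not part of the commit) of how the new validation is driven. It assumes, based on the `FromStr` impl above and the tests below, that parsing a JSON object string yields `FieldIds::Explicit`; the `sketch` function and its hand-built schema are illustrative only.

    use arrow_schema::{DataType, Field, Schema};

    fn sketch() {
        // Hand-built Arrow schema with a single int32 column "a".
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        // A mapping that names an existing column passes validation...
        let ok: FieldIds = r#"{"a": 10}"#.parse().unwrap();
        assert!(validate_field_ids(ok, &schema).is_ok());

        // ...while an unknown name is rejected with an error that lists
        // the available fields, as the tests below exercise.
        let bad: FieldIds = r#"{"missing": 10}"#.parse().unwrap();
        assert!(validate_field_ids(bad, &schema).is_err());
    }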

src/arrow_parquet/parquet_writer.rs

Lines changed: 5 additions & 1 deletion
@@ -12,6 +12,7 @@ use pgrx::{heap_tuple::PgHeapTuple, AllocatedByRust, PgTupleDesc};
 use crate::{
     arrow_parquet::{
         compression::PgParquetCompressionWithLevel,
+        field_ids::validate_field_ids,
         pg_to_arrow::context::collect_pg_to_arrow_attribute_contexts,
         schema_parser::{
            parquet_schema_string_from_attributes, parse_arrow_schema_from_attributes,
@@ -62,7 +63,10 @@ impl ParquetWriterContext {
             parquet_schema_string_from_attributes(&attributes, field_ids.clone())
         );
 
-        let schema = parse_arrow_schema_from_attributes(&attributes, field_ids);
+        let schema = parse_arrow_schema_from_attributes(&attributes, field_ids.clone());
+
+        validate_field_ids(field_ids, &schema).unwrap_or_else(|e| panic!("{e}"));
+
         let schema = Arc::new(schema);
 
         let writer_props = Self::writer_props(tupledesc, options);
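
Design note: the validation runs after the Arrow schema has been parsed but before any row is written, so a bad mapping fails the COPY up front. `field_ids` is cloned because `validate_field_ids` takes it by value, and the `panic!` on a validation error surfaces as a regular Postgres error under pgrx, which is what the `#[should_panic]` tests below rely on.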

src/pgrx_tests/copy_options.rs

Lines changed: 152 additions & 0 deletions
@@ -1006,4 +1006,156 @@ mod tests {
             ]
         );
     }
+
+    #[pg_test]
+    fn test_explicit_field_ids_with_missing_field_ids() {
+        let setup_commands = "create type dog as (id int, name text);
+                              create type person as (id int, dog dog, dogs dog[]);
+                              create table test_table(a int, b text, c person, d person[]);";
+        Spi::run(setup_commands).unwrap();
+
+        let explicit_field_ids = "{\"a\": 10,
+                                   \"c\": {
+                                       \"__root_field_id\": 100,
+                                       \"id\": 200,
+                                       \"dog\": {
+                                           \"id\": 400,
+                                           \"name\": 500
+                                       },
+                                       \"dogs\": {
+                                           \"__root_field_id\": 600,
+                                           \"element\": {
+                                               \"__root_field_id\": 700,
+                                               \"id\": 800,
+                                               \"name\": 900
+                                           }
+                                       }
+                                   },
+                                   \"d\": {
+                                       \"__root_field_id\": 1000,
+                                       \"element\": {
+                                           \"dog\": {
+                                               \"__root_field_id\": 1300,
+                                               \"id\": 1400,
+                                               \"name\": 1500
+                                           },
+                                           \"dogs\": {
+                                               \"__root_field_id\": 1600,
+                                               \"element\": {
+                                                   \"__root_field_id\": 1700,
+                                                   \"id\": 1800
+                                               }
+                                           }
+                                       }
+                                   }
+                                  }";
+
+        let copy_to_parquet = format!(
+            "copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
+        );
+        Spi::run(&copy_to_parquet).unwrap();
+
+        let fields = Spi::connect(|client| {
+            let parquet_schema_command = format!(
+                "select field_id, name from parquet.schema('{}') order by 1,2;",
+                LOCAL_TEST_FILE_PATH
+            );
+
+            let tup_table = client.select(&parquet_schema_command, None, &[]).unwrap();
+            let mut results = Vec::new();
+
+            for row in tup_table {
+                let field_id = row["field_id"].value::<i64>().unwrap();
+                let name = row["name"].value::<String>().unwrap().unwrap();
+
+                results.push((field_id, name));
+            }
+
+            results
+        });
+
+        assert_eq!(
+            fields,
+            vec![
+                (Some(10), "a".into()),
+                (Some(100), "c".into()),
+                (Some(200), "id".into()),
+                (Some(400), "id".into()),
+                (Some(500), "name".into()),
+                (Some(600), "dogs".into()),
+                (Some(700), "element".into()),
+                (Some(800), "id".into()),
+                (Some(900), "name".into()),
+                (Some(1000), "d".into()),
+                (Some(1300), "dog".into()),
+                (Some(1400), "id".into()),
+                (Some(1500), "name".into()),
+                (Some(1600), "dogs".into()),
+                (Some(1700), "element".into()),
+                (Some(1800), "id".into()),
+                (None, "arrow_schema".into()),
+                (None, "b".into()),
+                (None, "dog".into()),
+                (None, "element".into()),
+                (None, "id".into()),
+                (None, "list".into()),
+                (None, "list".into()),
+                (None, "list".into()),
+                (None, "name".into()),
+            ]
+        );
+    }
+
+    #[pg_test]
+    #[should_panic(expected = "Available fields: [\"a\", \"b\", \"c\", \"d\"]")]
+    fn test_explicit_field_ids_invalid_json() {
+        let setup_commands = "create type dog as (id int, name text);
+                              create type person as (id int, dog dog, dogs dog[]);
+                              create table test_table(a int, b text, c person, d person[]);";
+        Spi::run(setup_commands).unwrap();
+
+        let explicit_field_ids = "{\"aa\": 10, \"b\": 12}";
+
+        let copy_to_parquet = format!(
+            "copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
+        );
+        Spi::run(&copy_to_parquet).unwrap();
+    }
+
+    #[pg_test]
+    #[should_panic(expected = "Available fields: [\"id\", \"name\"]")]
+    fn test_explicit_field_ids_another_invalid_json() {
+        let setup_commands = "create type dog as (id int, name text);
+                              create type person as (id int, dog dog, dogs dog[]);
+                              create table test_table(a int, b text, c person, d person[]);";
+        Spi::run(setup_commands).unwrap();
+
+        let explicit_field_ids = "{\"a\": 10,
+                                   \"d\": {
+                                       \"__root_field_id\": 1000,
+                                       \"element\": {
+                                           \"__root_field_id\": 1100,
+                                           \"id\": 1200,
+                                           \"dog\": {
+                                               \"__root_field_id\": 1300,
+                                               \"id\": 1400,
+                                               \"name\": 1500
+                                           },
+                                           \"dogs\": {
+                                               \"__root_field_id\": 1600,
+                                               \"element\": {
+                                                   \"__root_field_id\": 1700,
+                                                   \"iddd\": 1800,
+                                                   \"name\": 1900
+                                               }
+                                           }
+                                       }
+                                   }
+                                  }";
+
+        let copy_to_parquet = format!(
+            "copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
+        );
+        Spi::run(&copy_to_parquet).unwrap();
+    }
 }
