Skip to content

Commit cd30884

Browse files
committed
address
1 parent ffe9c4e commit cd30884

File tree

3 files changed

+259
-1
lines changed

3 files changed

+259
-1
lines changed

src/arrow_parquet/field_ids.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{collections::HashMap, fmt::Display, str::FromStr};
22

3+
use arrow_schema::{DataType, Schema};
34
use serde::{Deserialize, Serialize};
45

56
#[derive(Debug, Clone, Default)]
@@ -50,6 +51,7 @@ pub(crate) struct FieldIdMapping {
5051
}
5152

5253
impl FieldIdMapping {
54+
/// Returns the field ID, if any, from `FieldIdMapping` for the given field path.
5355
pub(crate) fn field_id(&self, field_path: &[String]) -> Option<i32> {
5456
if field_path.is_empty() {
5557
panic!("Field path is empty");
@@ -65,6 +67,94 @@ impl FieldIdMapping {
6567
None => None,
6668
}
6769
}
70+
71+
/// Validates that every field name in the `FieldIdMapping` exists in the provided Arrow schema
72+
fn validate_against_schema(&self, arrow_schema: &Schema) -> Result<(), String> {
73+
// Build a map from field name to &Field for quick lookups
74+
let mut arrow_field_map = HashMap::new();
75+
for field in arrow_schema.fields() {
76+
arrow_field_map.insert(field.name().clone(), field);
77+
}
78+
79+
// Check every field name in the JSON mapping
80+
for (field_name, mapping_item) in &self.fields {
81+
if field_name == "__root_field_id" {
82+
// Skip the root field, as it doesn't exist in the Arrow schema
83+
continue;
84+
}
85+
86+
// Ensure the field exists in the Arrow schema
87+
let arrow_field = match arrow_field_map.get(field_name) {
88+
Some(f) => f,
89+
None => {
90+
return Err(format!(
91+
"Field '{}' in the mapping does not exist in the Arrow schema.\nAvailable fields: {:?}",
92+
field_name,
93+
arrow_schema
94+
.fields()
95+
.iter()
96+
.map(|f| f.name())
97+
.collect::<Vec<_>>()
98+
));
99+
}
100+
};
101+
102+
match mapping_item {
103+
// If the JSON item is an integer field ID, we're done
104+
FieldIdMappingItem::FieldId(_id) => {}
105+
106+
// If the JSON item is a nested mapping, we need to validate it
107+
FieldIdMappingItem::FieldIdMapping(mapping) => match arrow_field.data_type() {
108+
DataType::Struct(subfields) => {
109+
// We expect the JSON keys to include something like:
110+
// "__root_field_id": <int>,
111+
// "field_name": <int or nested mapping>
112+
113+
let subschema = Schema::new(subfields.clone());
114+
mapping.validate_against_schema(&subschema)?;
115+
}
116+
DataType::List(element_field) => {
117+
// We expect the JSON keys to include something like:
118+
// "__root_field_id": <int>,
119+
// "element": <int or nested mapping>
120+
//
121+
122+
let element_schema = Schema::new(vec![element_field.clone()]);
123+
mapping.validate_against_schema(&element_schema)?;
124+
}
125+
DataType::Map(entry_field, _) => {
126+
// We expect the JSON keys to include something like:
127+
// "__root_field_id": <int>,
128+
// "key": <int or nested mapping>
129+
// "val": <int or nested mapping>
130+
131+
match entry_field.data_type() {
132+
DataType::Struct(entry_fields) => {
133+
let entry_schema = Schema::new(entry_fields.clone());
134+
mapping.validate_against_schema(&entry_schema)?;
135+
}
136+
other_type => {
137+
panic!(
138+
"Map entry field should be a struct, but got '{:?}' for field '{}'",
139+
other_type, field_name
140+
);
141+
}
142+
};
143+
144+
return Ok(());
145+
}
146+
other_type => {
147+
panic!(
148+
"Unexpected data type '{:?}' for field '{}'",
149+
other_type, field_name
150+
);
151+
}
152+
},
153+
}
154+
}
155+
156+
Ok(())
157+
}
68158
}
69159

70160
pub(crate) fn field_id_mapping_from_json_string(
@@ -76,3 +166,15 @@ pub(crate) fn field_id_mapping_from_json_string(
76166
fn field_id_mapping_to_json_string(field_id_mapping: &FieldIdMapping) -> String {
77167
serde_json::to_string(field_id_mapping).unwrap()
78168
}
169+
170+
/// Validate that every field name in the `FieldIdMapping` exists in the provided Arrow schema
171+
/// when the `FieldIds` are explicitly specified.
172+
pub(crate) fn validate_field_ids(field_ids: FieldIds, arrow_schema: &Schema) -> Result<(), String> {
173+
match field_ids {
174+
FieldIds::None => Ok(()),
175+
FieldIds::Auto => Ok(()),
176+
FieldIds::Explicit(field_id_mapping) => {
177+
field_id_mapping.validate_against_schema(arrow_schema)
178+
}
179+
}
180+
}

src/arrow_parquet/parquet_writer.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use pgrx::{heap_tuple::PgHeapTuple, AllocatedByRust, PgTupleDesc};
1212
use crate::{
1313
arrow_parquet::{
1414
compression::PgParquetCompressionWithLevel,
15+
field_ids::validate_field_ids,
1516
pg_to_arrow::context::collect_pg_to_arrow_attribute_contexts,
1617
schema_parser::{
1718
parquet_schema_string_from_attributes, parse_arrow_schema_from_attributes,
@@ -62,7 +63,10 @@ impl ParquetWriterContext {
6263
parquet_schema_string_from_attributes(&attributes, field_ids.clone())
6364
);
6465

65-
let schema = parse_arrow_schema_from_attributes(&attributes, field_ids);
66+
let schema = parse_arrow_schema_from_attributes(&attributes, field_ids.clone());
67+
68+
validate_field_ids(field_ids, &schema).unwrap_or_else(|e| panic!("{e}"));
69+
6670
let schema = Arc::new(schema);
6771

6872
let writer_props = Self::writer_props(tupledesc, options);

src/pgrx_tests/copy_options.rs

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,4 +1006,156 @@ mod tests {
10061006
]
10071007
);
10081008
}
1009+
1010+
#[pg_test]
1011+
fn test_explicit_field_ids_with_missing_field_ids() {
1012+
let setup_commands = "create type dog as (id int, name text);
1013+
create type person as (id int, dog dog, dogs dog[]);
1014+
create table test_table(a int, b text, c person, d person[]);";
1015+
Spi::run(setup_commands).unwrap();
1016+
1017+
let explicit_field_ids = "{\"a\": 10,
1018+
\"c\": {
1019+
\"__root_field_id\": 100,
1020+
\"id\": 200,
1021+
\"dog\": {
1022+
\"id\": 400,
1023+
\"name\": 500
1024+
},
1025+
\"dogs\": {
1026+
\"__root_field_id\": 600,
1027+
\"element\": {
1028+
\"__root_field_id\": 700,
1029+
\"id\": 800,
1030+
\"name\": 900
1031+
}
1032+
}
1033+
},
1034+
\"d\": {
1035+
\"__root_field_id\": 1000,
1036+
\"element\": {
1037+
\"dog\": {
1038+
\"__root_field_id\": 1300,
1039+
\"id\": 1400,
1040+
\"name\": 1500
1041+
},
1042+
\"dogs\": {
1043+
\"__root_field_id\": 1600,
1044+
\"element\": {
1045+
\"__root_field_id\": 1700,
1046+
\"id\": 1800
1047+
}
1048+
}
1049+
}
1050+
}
1051+
}";
1052+
1053+
let copy_to_parquet = format!(
1054+
"copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
1055+
);
1056+
Spi::run(&copy_to_parquet).unwrap();
1057+
1058+
let fields = Spi::connect(|client| {
1059+
let parquet_schema_command = format!(
1060+
"select field_id, name from parquet.schema('{}') order by 1,2;",
1061+
LOCAL_TEST_FILE_PATH
1062+
);
1063+
1064+
let tup_table = client.select(&parquet_schema_command, None, &[]).unwrap();
1065+
let mut results = Vec::new();
1066+
1067+
for row in tup_table {
1068+
let field_id = row["field_id"].value::<i64>().unwrap();
1069+
let name = row["name"].value::<String>().unwrap().unwrap();
1070+
1071+
results.push((field_id, name));
1072+
}
1073+
1074+
results
1075+
});
1076+
1077+
assert_eq!(
1078+
fields,
1079+
vec![
1080+
(Some(10), "a".into()),
1081+
(Some(100), "c".into()),
1082+
(Some(200), "id".into()),
1083+
(Some(400), "id".into()),
1084+
(Some(500), "name".into()),
1085+
(Some(600), "dogs".into()),
1086+
(Some(700), "element".into()),
1087+
(Some(800), "id".into()),
1088+
(Some(900), "name".into()),
1089+
(Some(1000), "d".into()),
1090+
(Some(1300), "dog".into()),
1091+
(Some(1400), "id".into()),
1092+
(Some(1500), "name".into()),
1093+
(Some(1600), "dogs".into()),
1094+
(Some(1700), "element".into()),
1095+
(Some(1800), "id".into()),
1096+
(None, "arrow_schema".into()),
1097+
(None, "b".into()),
1098+
(None, "dog".into()),
1099+
(None, "element".into()),
1100+
(None, "id".into()),
1101+
(None, "list".into()),
1102+
(None, "list".into()),
1103+
(None, "list".into()),
1104+
(None, "name".into()),
1105+
]
1106+
);
1107+
}
1108+
1109+
#[pg_test]
1110+
#[should_panic(expected = "Available fields: [\"a\", \"b\", \"c\", \"d\"]")]
1111+
fn test_explicit_field_ids_invalid_json() {
1112+
let setup_commands = "create type dog as (id int, name text);
1113+
create type person as (id int, dog dog, dogs dog[]);
1114+
create table test_table(a int, b text, c person, d person[]);";
1115+
Spi::run(setup_commands).unwrap();
1116+
1117+
let explicit_field_ids = "{\"aa\": 10, \"b\": 12}";
1118+
1119+
let copy_to_parquet = format!(
1120+
"copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
1121+
);
1122+
Spi::run(&copy_to_parquet).unwrap();
1123+
}
1124+
1125+
#[pg_test]
1126+
#[should_panic(expected = "Available fields: [\"id\", \"name\"]")]
1127+
fn test_explicit_field_ids_another_invalid_json() {
1128+
let setup_commands = "create type dog as (id int, name text);
1129+
create type person as (id int, dog dog, dogs dog[]);
1130+
create table test_table(a int, b text, c person, d person[]);";
1131+
Spi::run(setup_commands).unwrap();
1132+
1133+
let explicit_field_ids = "{\"a\": 10,
1134+
\"d\": {
1135+
\"__root_field_id\": 1000,
1136+
\"element\": {
1137+
\"__root_field_id\": 1100,
1138+
\"id\": 1200,
1139+
\"dog\": {
1140+
\"__root_field_id\": 1300,
1141+
\"id\": 1400,
1142+
\"name\": 1500
1143+
},
1144+
\"dogs\": {
1145+
\"__root_field_id\": 1600,
1146+
\"element\": {
1147+
\"__root_field_id\": 1700,
1148+
\"iddd\": 1800,
1149+
\"name\": 1900
1150+
}
1151+
}
1152+
}
1153+
}
1154+
}";
1155+
1156+
let copy_to_parquet = format!(
1157+
"copy test_table to '{LOCAL_TEST_FILE_PATH}' with (field_ids '{explicit_field_ids}');"
1158+
);
1159+
Spi::run(&copy_to_parquet).unwrap();
1160+
}
10091161
}

0 commit comments

Comments
 (0)