Skip to content

Commit e4a832b

Browse files
committed
add more tests
1 parent 2688cb0 commit e4a832b

File tree

5 files changed

+235
-38
lines changed

5 files changed

+235
-38
lines changed

src/arrow_parquet/parquet_reader.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ use arrow::array::RecordBatch;
77
use arrow_cast::{cast_with_options, CastOptions};
88
use arrow_schema::SchemaRef;
99
use futures::StreamExt;
10+
use glob::Pattern;
1011
use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStream};
1112
use pgrx::{
1213
check_for_interrupts,
@@ -207,9 +208,9 @@ impl ParquetReaderContext {
207208
SingleParquetReader::try_new(uri_info, match_by, tupledesc_schema.clone(), &attributes)
208209
.map(|reader| vec![reader])
209210
.unwrap_or_else(|e| {
210-
// if uri contains any pattern, try to create readers from the pattern uri
211+
// if uri contains a valid pattern, try to create readers from the pattern uri
211212
// otherwise, panic with the original error
212-
if !uri_info.path.as_ref().contains('*') {
213+
if Pattern::try_from(uri_info).is_err() {
213214
panic!("{e}");
214215
}
215216

src/parquet_udfs/list.rs

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -22,27 +22,19 @@ mod parquet {
2222
}
2323
}
2424

25-
fn error_if_list_http_store(uri_info: &ParsedUriInfo) {
26-
if uri_info.uri.scheme() == "http" || uri_info.uri.scheme() == "https" {
27-
panic!("list operation on http(s) object stores is not supported");
28-
}
29-
}
30-
3125
pub(crate) fn list_uri(uri_info: &ParsedUriInfo) -> Vec<(String, i64)> {
3226
ensure_access_privilege_to_uri(uri_info, true);
3327

34-
error_if_list_http_store(uri_info);
35-
3628
let base_uri = object_store_base_uri(&uri_info.uri);
3729

38-
let copy_from = true;
39-
let (parquet_object_store, location) = get_or_create_object_store(uri_info, copy_from);
40-
4130
// build the pattern before we start the stream to bail out early
42-
let pattern = Pattern::new(location.as_ref()).unwrap_or_else(|e| {
31+
let pattern = Pattern::try_from(uri_info).unwrap_or_else(|e| {
4332
panic!("{}", e);
4433
});
4534

35+
let copy_from = true;
36+
let (parquet_object_store, location) = get_or_create_object_store(uri_info, copy_from);
37+
4638
// prefix is the part of the location that doesn't contain any patterns
4739
let prefix = location
4840
.parts()
@@ -90,3 +82,15 @@ pub(crate) fn list_uri(uri_info: &ParsedUriInfo) -> Vec<(String, i64)> {
9082
})
9183
.collect::<Vec<_>>()
9284
}
85+
86+
impl TryFrom<&ParsedUriInfo> for Pattern {
87+
type Error = String;
88+
89+
fn try_from(uri_info: &ParsedUriInfo) -> Result<Self, Self::Error> {
90+
if uri_info.uri.scheme() == "http" || uri_info.uri.scheme() == "https" {
91+
return Err("list operation on http(s) object stores is not supported".into());
92+
}
93+
94+
Self::new(uri_info.path.as_ref()).map_err(|e| e.to_string())
95+
}
96+
}

src/pgrx_tests/copy_pattern.rs

Lines changed: 41 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,16 +2,22 @@
22
mod tests {
33
use pgrx::{pg_test, Spi};
44

5+
use crate::pgrx_tests::common::FileCleanup;
6+
57
#[pg_test]
68
#[should_panic(expected = "EOF: file size of 2 is less than footer")]
7-
fn test_pattern_invalid_parquet_file() {
9+
fn test_non_parquet_path_in_pattern() {
10+
let filename1 = "/tmp/pg_parquet_test/dummy1.parquet";
11+
let filename2 = "/tmp/pg_parquet_test/dummy2.csv";
12+
let _file1 = FileCleanup::new(filename1);
13+
let _file2 = FileCleanup::new(filename2);
14+
815
let copy_to_parquet =
9-
"copy (select 1 as a) to '/tmp/pg_parquet_test/dummy1.parquet' with (format parquet);";
10-
Spi::run(copy_to_parquet).unwrap();
16+
format!("copy (select 1 as a) to '{filename1}' with (format parquet);");
17+
Spi::run(&copy_to_parquet).unwrap();
1118

12-
let copy_to_csv =
13-
"copy (select 1 as a) to '/tmp/pg_parquet_test/dummy2.csv' with (format csv);";
14-
Spi::run(copy_to_csv).unwrap();
19+
let copy_to_csv = format!("copy (select 1 as a) to '{filename2}' with (format csv);");
20+
Spi::run(&copy_to_csv).unwrap();
1521

1622
let create_table = "create table test_table(a int);";
1723
Spi::run(create_table).unwrap();
@@ -25,8 +31,10 @@ mod tests {
2531
}
2632

2733
#[pg_test]
28-
fn test_pattern_with_special_parquet_file_name() {
34+
fn test_path_with_special_parquet_file_name() {
2935
let filename = "/tmp/pg_parquet_test/du**mm*y1.parquet";
36+
let _file = FileCleanup::new(filename);
37+
3038
let copy_to_parquet =
3139
format!("copy (select 1 as a) to '{filename}' with (format parquet);");
3240
Spi::run(&copy_to_parquet).unwrap();
@@ -41,4 +49,30 @@ mod tests {
4149
let result = Spi::get_one::<i64>(count_query).unwrap().unwrap();
4250
assert_eq!(result, 1);
4351
}
52+
53+
#[pg_test]
54+
#[should_panic(expected = "failed to get object store metadata")]
55+
fn test_path_with_special_chars() {
56+
let filename = "/tmp/pg_parquet_test/du**\\mm*y1.parquet";
57+
let _file = FileCleanup::new(filename);
58+
59+
let create_table = "create table test_table(a int);";
60+
Spi::run(create_table).unwrap();
61+
62+
let copy_from_parquet = format!("COPY test_table FROM '{filename}' WITH (format parquet);");
63+
Spi::run(copy_from_parquet.as_str()).unwrap();
64+
}
65+
66+
#[pg_test]
67+
#[should_panic(expected = "no files found that match the pattern")]
68+
fn test_with_nonexistent_pattern_path() {
69+
let filename = "/tmp/pg_parquet_test/**";
70+
let _file = FileCleanup::new(filename);
71+
72+
let create_table = "create table test_table(a int);";
73+
Spi::run(create_table).unwrap();
74+
75+
let copy_from_parquet = format!("COPY test_table FROM '{filename}' WITH (format parquet);");
76+
Spi::run(copy_from_parquet.as_str()).unwrap();
77+
}
4478
}

src/pgrx_tests/object_store.rs

Lines changed: 163 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,49 @@ mod tests {
9191
test_table.assert_expected_and_result_rows();
9292
}
9393

94+
#[pg_test]
95+
fn test_s3_uri_with_special_chars() {
96+
object_store_cache_clear();
97+
98+
let test_bucket_name: String =
99+
std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found");
100+
101+
let s3_uri = format!("s3://{test_bucket_name}/so\\mek*ey/**testme.parquet");
102+
103+
let copy_to_command = format!(
104+
"COPY (SELECT a FROM generate_series(1,10) a) TO '{s3_uri}' WITH (format parquet);"
105+
);
106+
Spi::run(&copy_to_command).unwrap();
107+
108+
let create_table = "create table test_table(a int);";
109+
Spi::run(create_table).unwrap();
110+
111+
let copy_from_command = format!("COPY test_table FROM '{s3_uri}' WITH (format parquet);");
112+
Spi::run(copy_from_command.as_str()).unwrap();
113+
114+
let count_query = "select count(*) from test_table;";
115+
let result = Spi::get_one::<i64>(count_query).unwrap().unwrap();
116+
assert_eq!(result, 10);
117+
}
118+
119+
#[pg_test]
120+
#[should_panic(expected = "no files found that match the pattern")]
121+
fn test_s3_with_nonexistent_pattern_uri() {
122+
object_store_cache_clear();
123+
124+
let test_bucket_name: String =
125+
std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found");
126+
127+
let s3_uri_pattern = format!("s3://{test_bucket_name}/notexists/**");
128+
129+
let create_table = "create table test_table(a int);";
130+
Spi::run(create_table).unwrap();
131+
132+
let copy_from_command =
133+
format!("COPY test_table FROM '{s3_uri_pattern}' WITH (format parquet);");
134+
Spi::run(copy_from_command.as_str()).unwrap();
135+
}
136+
94137
#[pg_test]
95138
fn test_s3_from_config_file() {
96139
object_store_cache_clear();
@@ -491,6 +534,52 @@ mod tests {
491534
test_table.assert_expected_and_result_rows();
492535
}
493536

537+
#[pg_test]
538+
#[cfg(not(rhel8))]
539+
fn test_azure_uri_with_special_chars() {
540+
object_store_cache_clear();
541+
542+
let test_container_name: String = std::env::var("AZURE_TEST_CONTAINER_NAME")
543+
.expect("AZURE_TEST_CONTAINER_NAME not found");
544+
545+
let azure_uri = format!("az://{test_container_name}/so\\mek*ey/**testme.parquet");
546+
547+
let copy_to_command = format!(
548+
"COPY (SELECT a FROM generate_series(1,10) a) TO '{azure_uri}' WITH (format parquet);"
549+
);
550+
Spi::run(&copy_to_command).unwrap();
551+
552+
let create_table = "create table test_table(a int);";
553+
Spi::run(create_table).unwrap();
554+
555+
let copy_from_command =
556+
format!("COPY test_table FROM '{azure_uri}' WITH (format parquet);");
557+
Spi::run(copy_from_command.as_str()).unwrap();
558+
559+
let count_query = "select count(*) from test_table;";
560+
let result = Spi::get_one::<i64>(count_query).unwrap().unwrap();
561+
assert_eq!(result, 10);
562+
}
563+
564+
#[pg_test]
565+
#[cfg(not(rhel8))]
566+
#[should_panic(expected = "no files found that match the pattern")]
567+
fn test_azure_with_nonexistent_pattern_uri() {
568+
object_store_cache_clear();
569+
570+
let test_container_name: String = std::env::var("AZURE_TEST_CONTAINER_NAME")
571+
.expect("AZURE_TEST_CONTAINER_NAME not found");
572+
573+
let azure_uri_pattern = format!("az://{test_container_name}/notexists/**");
574+
575+
let create_table = "create table test_table(a int);";
576+
Spi::run(create_table).unwrap();
577+
578+
let copy_from_command =
579+
format!("COPY test_table FROM '{azure_uri_pattern}' WITH (format parquet);");
580+
Spi::run(copy_from_command.as_str()).unwrap();
581+
}
582+
494583
#[pg_test]
495584
#[cfg(not(rhel8))]
496585
fn test_azure_from_config_file() {
@@ -767,32 +856,46 @@ mod tests {
767856
}
768857

769858
#[pg_test]
770-
#[should_panic(expected = "list operation on http(s) object stores is not supported")]
771-
fn test_http_uri_glob_pattern() {
859+
fn test_http_uri_with_special_chars() {
772860
object_store_cache_clear();
773861

774862
let http_endpoint: String =
775863
std::env::var("HTTP_ENDPOINT").expect("HTTP_ENDPOINT not found");
776864

777-
let http_uri =
778-
format!("{http_endpoint}/test_http_uri_glob_pattern/some/pg_parquet_test.parquet");
779-
780-
let http_uri_pattern = format!("{http_endpoint}/test_http_uri_glob_pattern/**/*.parquet");
865+
let http_uri = format!("{http_endpoint}/so\\mek*ey/**testme.parquet");
781866

782-
let mut copy_options = HashMap::new();
783-
copy_options.insert(
784-
"file_size_bytes".to_string(),
785-
CopyOptionValue::StringOption("1MB".to_string()),
867+
let copy_to_command = format!(
868+
"COPY (SELECT a FROM generate_series(1,10) a) TO '{http_uri}' WITH (format parquet);"
786869
);
870+
Spi::run(&copy_to_command).unwrap();
787871

788-
let test_table = TestTable::<i32>::new("int4".into())
789-
.with_uri(http_uri)
790-
.with_uri_pattern(http_uri_pattern)
791-
.with_order_by_col("a".into())
792-
.with_copy_to_options(copy_options);
872+
let create_table = "create table test_table(a int);";
873+
Spi::run(create_table).unwrap();
793874

794-
test_table.insert("INSERT INTO test_expected select i from generate_series(1, 1000000) i;");
795-
test_table.assert_expected_and_result_rows();
875+
let copy_from_command = format!("COPY test_table FROM '{http_uri}' WITH (format parquet);");
876+
Spi::run(copy_from_command.as_str()).unwrap();
877+
878+
let count_query = "select count(*) from test_table;";
879+
let result = Spi::get_one::<i64>(count_query).unwrap().unwrap();
880+
assert_eq!(result, 10);
881+
}
882+
883+
#[pg_test]
884+
#[should_panic(expected = "failed to get object store metadata for uri")]
885+
fn test_http_with_nonexistent_pattern_uri() {
886+
object_store_cache_clear();
887+
888+
let http_endpoint: String =
889+
std::env::var("HTTP_ENDPOINT").expect("HTTP_ENDPOINT not found");
890+
891+
let http_uri_pattern = format!("{http_endpoint}/notexists/**");
892+
893+
let create_table = "create table test_table(a int);";
894+
Spi::run(create_table).unwrap();
895+
896+
let copy_from_command =
897+
format!("COPY test_table FROM '{http_uri_pattern}' WITH (format parquet);");
898+
Spi::run(copy_from_command.as_str()).unwrap();
796899
}
797900

798901
#[pg_test]
@@ -839,6 +942,49 @@ mod tests {
839942
test_table.assert_expected_and_result_rows();
840943
}
841944

945+
#[pg_test]
946+
fn test_gcs_uri_with_special_chars() {
947+
object_store_cache_clear();
948+
949+
let test_bucket_name: String =
950+
std::env::var("GOOGLE_TEST_BUCKET").expect("GOOGLE_TEST_BUCKET not found");
951+
952+
let gcs_uri = format!("gs://{test_bucket_name}/so\\mek*ey/**testme.parquet");
953+
954+
let copy_to_command = format!(
955+
"COPY (SELECT a FROM generate_series(1,10) a) TO '{gcs_uri}' WITH (format parquet);"
956+
);
957+
Spi::run(&copy_to_command).unwrap();
958+
959+
let create_table = "create table test_table(a int);";
960+
Spi::run(create_table).unwrap();
961+
962+
let copy_from_command = format!("COPY test_table FROM '{gcs_uri}' WITH (format parquet);");
963+
Spi::run(copy_from_command.as_str()).unwrap();
964+
965+
let count_query = "select count(*) from test_table;";
966+
let result = Spi::get_one::<i64>(count_query).unwrap().unwrap();
967+
assert_eq!(result, 10);
968+
}
969+
970+
#[pg_test]
971+
#[should_panic(expected = "no files found that match the pattern")]
972+
fn test_gcs_with_nonexistent_pattern_uri() {
973+
object_store_cache_clear();
974+
975+
let test_bucket_name: String =
976+
std::env::var("GOOGLE_TEST_BUCKET").expect("GOOGLE_TEST_BUCKET not found");
977+
978+
let gcs_uri_pattern = format!("gs://{test_bucket_name}/notexists/**");
979+
980+
let create_table = "create table test_table(a int);";
981+
Spi::run(create_table).unwrap();
982+
983+
let copy_from_command =
984+
format!("COPY test_table FROM '{gcs_uri_pattern}' WITH (format parquet);");
985+
Spi::run(copy_from_command.as_str()).unwrap();
986+
}
987+
842988
#[pg_test]
843989
#[should_panic(expected = "404 Not Found")]
844990
fn test_gcs_write_wrong_bucket() {

src/pgrx_tests/udfs.rs

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1082,4 +1082,16 @@ mod tests {
10821082
vec![]
10831083
);
10841084
}
1085+
1086+
#[pg_test]
1087+
#[should_panic(expected = "list operation on http(s) object stores is not supported")]
1088+
fn test_http_list_not_allowed() {
1089+
let http_endpoint: String =
1090+
std::env::var("HTTP_ENDPOINT").expect("HTTP_ENDPOINT not found");
1091+
1092+
let http_uri_pattern = format!("{http_endpoint}/**");
1093+
1094+
let parquet_list_command = format!("select * from parquet.list('{http_uri_pattern}');");
1095+
Spi::run(&parquet_list_command).unwrap();
1096+
}
10851097
}

0 commit comments

Comments
 (0)