Skip to content

Commit c69fd25

Browse files
committed
refactor: modularize tests
1 parent c8af064 commit c69fd25

File tree

10 files changed

+4134
-4067
lines changed

10 files changed

+4134
-4067
lines changed

src/lib.rs

Lines changed: 2 additions & 4067 deletions
Large diffs are not rendered by default.

src/pgrx_tests/common.rs

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
use std::fs::File;
2+
use std::marker::PhantomData;
3+
use std::{collections::HashMap, fmt::Debug};
4+
5+
use crate::type_compat::map::Map;
6+
7+
use arrow::array::RecordBatch;
8+
use arrow_schema::SchemaRef;
9+
use parquet::arrow::ArrowWriter;
10+
use pgrx::{
11+
datum::{Time, TimeWithTimeZone},
12+
FromDatum, IntoDatum, Spi,
13+
};
14+
15+
pub(crate) enum CopyOptionValue {
16+
StringOption(String),
17+
IntOption(i64),
18+
}
19+
20+
pub(crate) fn comma_separated_copy_options(options: &HashMap<String, CopyOptionValue>) -> String {
21+
let mut comma_sepated_options = String::new();
22+
23+
for (option_idx, (key, value)) in options.iter().enumerate() {
24+
match value {
25+
CopyOptionValue::StringOption(value) => {
26+
comma_sepated_options.push_str(&format!("{} '{}'", key, value));
27+
}
28+
CopyOptionValue::IntOption(value) => {
29+
comma_sepated_options.push_str(&format!("{} {}", key, value));
30+
}
31+
}
32+
33+
if option_idx < options.len() - 1 {
34+
comma_sepated_options.push_str(", ");
35+
}
36+
}
37+
38+
comma_sepated_options
39+
}
40+
41+
pub(crate) struct TestTable<T: IntoDatum + FromDatum> {
42+
uri: String,
43+
order_by_col: String,
44+
copy_to_options: HashMap<String, CopyOptionValue>,
45+
copy_from_options: HashMap<String, CopyOptionValue>,
46+
_data: PhantomData<T>,
47+
}
48+
49+
impl<T: IntoDatum + FromDatum> TestTable<T> {
50+
pub(crate) fn new(typename: String) -> Self {
51+
Spi::run("DROP TABLE IF EXISTS test_expected, test_result;").unwrap();
52+
53+
let create_table_command = format!("CREATE TABLE test_expected (a {});", &typename);
54+
Spi::run(create_table_command.as_str()).unwrap();
55+
56+
let create_table_command = format!("CREATE TABLE test_result (a {});", &typename);
57+
Spi::run(create_table_command.as_str()).unwrap();
58+
59+
let mut copy_to_options = HashMap::new();
60+
copy_to_options.insert(
61+
"format".to_string(),
62+
CopyOptionValue::StringOption("parquet".to_string()),
63+
);
64+
65+
let mut copy_from_options = HashMap::new();
66+
copy_from_options.insert(
67+
"format".to_string(),
68+
CopyOptionValue::StringOption("parquet".to_string()),
69+
);
70+
71+
let uri = "/tmp/test.parquet".to_string();
72+
73+
let order_by_col = "a".to_string();
74+
75+
Self {
76+
uri,
77+
order_by_col,
78+
copy_to_options,
79+
copy_from_options,
80+
_data: PhantomData,
81+
}
82+
}
83+
84+
pub(crate) fn with_order_by_col(mut self, order_by_col: String) -> Self {
85+
self.order_by_col = order_by_col;
86+
self
87+
}
88+
89+
pub(crate) fn with_copy_to_options(
90+
mut self,
91+
copy_to_options: HashMap<String, CopyOptionValue>,
92+
) -> Self {
93+
self.copy_to_options = copy_to_options;
94+
self
95+
}
96+
97+
pub(crate) fn with_copy_from_options(
98+
mut self,
99+
copy_from_options: HashMap<String, CopyOptionValue>,
100+
) -> Self {
101+
self.copy_from_options = copy_from_options;
102+
self
103+
}
104+
105+
pub(crate) fn with_uri(mut self, uri: String) -> Self {
106+
self.uri = uri;
107+
self
108+
}
109+
110+
pub(crate) fn insert(&self, insert_command: &str) {
111+
Spi::run(insert_command).unwrap();
112+
}
113+
114+
pub(crate) fn select_all(&self, table_name: &str) -> Vec<(Option<T>,)> {
115+
let select_command = format!(
116+
"SELECT a FROM {} ORDER BY {};",
117+
table_name, self.order_by_col
118+
);
119+
120+
Spi::connect(|client| {
121+
let mut results = Vec::new();
122+
let tup_table = client.select(&select_command, None, None).unwrap();
123+
124+
for row in tup_table {
125+
let val = row["a"].value::<T>();
126+
results.push((val.expect("could not select"),));
127+
}
128+
129+
results
130+
})
131+
}
132+
133+
pub(crate) fn copy_to_parquet(&self) {
134+
let mut copy_to_query = format!("COPY (SELECT a FROM test_expected) TO '{}'", self.uri);
135+
136+
if !self.copy_to_options.is_empty() {
137+
copy_to_query.push_str(" WITH (");
138+
139+
let options_str = comma_separated_copy_options(&self.copy_to_options);
140+
copy_to_query.push_str(&options_str);
141+
142+
copy_to_query.push(')');
143+
}
144+
145+
copy_to_query.push(';');
146+
147+
Spi::run(copy_to_query.as_str()).unwrap();
148+
}
149+
150+
pub(crate) fn copy_from_parquet(&self) {
151+
let mut copy_from_query = format!("COPY test_result FROM '{}'", self.uri);
152+
153+
if !self.copy_from_options.is_empty() {
154+
copy_from_query.push_str(" WITH (");
155+
156+
let options_str = comma_separated_copy_options(&self.copy_from_options);
157+
copy_from_query.push_str(&options_str);
158+
159+
copy_from_query.push(')');
160+
}
161+
162+
copy_from_query.push(';');
163+
164+
Spi::run(copy_from_query.as_str()).unwrap();
165+
}
166+
}
167+
168+
pub(crate) fn timetz_to_utc_time(timetz: TimeWithTimeZone) -> Option<Time> {
169+
Some(timetz.to_utc())
170+
}
171+
172+
pub(crate) fn timetz_array_to_utc_time_array(
173+
timetz_array: Vec<Option<TimeWithTimeZone>>,
174+
) -> Option<Vec<Option<Time>>> {
175+
Some(
176+
timetz_array
177+
.into_iter()
178+
.map(|timetz| timetz.map(|timetz| timetz.to_utc()))
179+
.collect(),
180+
)
181+
}
182+
183+
pub(crate) fn assert_int_text_map(expected: Option<Map>, actual: Option<Map>) {
184+
if expected.is_none() {
185+
assert!(actual.is_none());
186+
} else {
187+
assert!(actual.is_some());
188+
189+
let expected = expected.unwrap().entries;
190+
let actual = actual.unwrap().entries;
191+
192+
for (expected, actual) in expected.iter().zip(actual.iter()) {
193+
if expected.is_none() {
194+
assert!(actual.is_none());
195+
} else {
196+
assert!(actual.is_some());
197+
198+
let expected = expected.unwrap();
199+
let actual = actual.unwrap();
200+
201+
let expected_key: Option<i32> = expected.get_by_name("key").unwrap();
202+
let actual_key: Option<i32> = actual.get_by_name("key").unwrap();
203+
204+
assert_eq!(expected_key, actual_key);
205+
206+
let expected_val: Option<String> = expected.get_by_name("val").unwrap();
207+
let actual_val: Option<String> = actual.get_by_name("val").unwrap();
208+
209+
assert_eq!(expected_val, actual_val);
210+
}
211+
}
212+
}
213+
}
214+
215+
pub(crate) struct TestResult<T> {
216+
pub(crate) expected: Vec<(Option<T>,)>,
217+
pub(crate) result: Vec<(Option<T>,)>,
218+
}
219+
220+
pub(crate) fn test_common<T: IntoDatum + FromDatum>(test_table: TestTable<T>) -> TestResult<T> {
221+
test_table.copy_to_parquet();
222+
test_table.copy_from_parquet();
223+
224+
let expected = test_table.select_all("test_expected");
225+
let result = test_table.select_all("test_result");
226+
227+
TestResult { expected, result }
228+
}
229+
230+
pub(crate) fn test_assert<T>(expected_result: Vec<(Option<T>,)>, result: Vec<(Option<T>,)>)
231+
where
232+
T: Debug + PartialEq,
233+
{
234+
for (expected, actual) in expected_result.into_iter().zip(result.into_iter()) {
235+
assert_eq!(expected, actual);
236+
}
237+
}
238+
239+
pub(crate) fn test_assert_float(expected_result: Vec<Option<f32>>, result: Vec<Option<f32>>) {
240+
for (expected, actual) in expected_result.into_iter().zip(result.into_iter()) {
241+
if expected.is_none() {
242+
assert!(actual.is_none());
243+
}
244+
245+
if expected.is_some() {
246+
assert!(actual.is_some());
247+
248+
let expected = expected.unwrap();
249+
let actual = actual.unwrap();
250+
251+
if expected.is_nan() {
252+
assert!(actual.is_nan());
253+
} else if expected.is_infinite() {
254+
assert!(actual.is_infinite());
255+
assert!(expected.is_sign_positive() == actual.is_sign_positive());
256+
} else {
257+
assert_eq!(expected, actual);
258+
}
259+
}
260+
}
261+
}
262+
263+
pub(crate) fn test_assert_double(expected_result: Vec<Option<f64>>, result: Vec<Option<f64>>) {
264+
for (expected, actual) in expected_result.into_iter().zip(result.into_iter()) {
265+
if expected.is_none() {
266+
assert!(actual.is_none());
267+
}
268+
269+
if expected.is_some() {
270+
assert!(actual.is_some());
271+
272+
let expected = expected.unwrap();
273+
let actual = actual.unwrap();
274+
275+
if expected.is_nan() {
276+
assert!(actual.is_nan());
277+
} else if expected.is_infinite() {
278+
assert!(actual.is_infinite());
279+
assert!(expected.is_sign_positive() == actual.is_sign_positive());
280+
} else {
281+
assert_eq!(expected, actual);
282+
}
283+
}
284+
}
285+
}
286+
287+
pub(crate) fn test_helper<T: IntoDatum + FromDatum + Debug + PartialEq>(test_table: TestTable<T>) {
288+
let test_result = test_common(test_table);
289+
test_assert(test_result.expected, test_result.result);
290+
}
291+
292+
pub(crate) fn extension_exists(extension_name: &str) -> bool {
293+
let query = format!(
294+
"select count(*) = 1 from pg_available_extensions where name = '{}'",
295+
extension_name
296+
);
297+
298+
Spi::get_one(&query).unwrap().unwrap()
299+
}
300+
301+
pub(crate) fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) {
302+
let file = File::create("/tmp/test.parquet").unwrap();
303+
let mut writer = ArrowWriter::try_new(file, schema, None).unwrap();
304+
305+
writer.write(&record_batch).unwrap();
306+
writer.close().unwrap();
307+
}

0 commit comments

Comments
 (0)