|
| 1 | +use std::{ffi::CStr, io::Write}; |
| 2 | + |
| 3 | +use pgrx::{ |
| 4 | + ffi::c_char, |
| 5 | + pg_sys::{ |
| 6 | + makeStringInfo, pq_beginmessage, pq_copymsgbytes, pq_endmessage, pq_getmsgstring, |
| 7 | + pq_sendbyte, pq_sendint16, QueryCancelHoldoffCount, StringInfo, |
| 8 | + }, |
| 9 | +}; |
| 10 | + |
| 11 | +use crate::arrow_parquet::uri_utils::{uri_as_string, ParsedUriInfo}; |
| 12 | + |
| 13 | +/* |
| 14 | + * CopyFromStdinState is a simplified version of CopyFromState |
| 15 | + * in PostgreSQL to use in ReceiveDataFromClient with names |
| 16 | + * preserved. |
| 17 | + */ |
| 18 | +struct CopyFromStdinState { |
| 19 | + /* buffer in which we store incoming bytes */ |
| 20 | + fe_msgbuf: StringInfo, |
| 21 | + |
| 22 | + /* whether we reached the end-of-file */ |
| 23 | + raw_reached_eof: bool, |
| 24 | +} |
| 25 | + |
/* maximum number of bytes handed back by a single receive call (64 KiB) */
const MAX_READ_SIZE: usize = 64 * 1024;
| 27 | + |
| 28 | +/* |
| 29 | + * CopyInputToFile copies data from the socket to the given file. |
| 30 | + * We request the client send a specific column count. |
| 31 | + */ |
| 32 | +pub(crate) unsafe fn copy_stdin_to_file(uri_info: ParsedUriInfo, natts: i16, is_binary: bool) { |
| 33 | + let mut cstate = CopyFromStdinState { |
| 34 | + fe_msgbuf: makeStringInfo(), |
| 35 | + raw_reached_eof: false, |
| 36 | + }; |
| 37 | + |
| 38 | + /* open the destination file for writing */ |
| 39 | + let path = uri_as_string(&uri_info.uri); |
| 40 | + |
| 41 | + // create or overwrite the local file |
| 42 | + let mut file = std::fs::OpenOptions::new() |
| 43 | + .write(true) |
| 44 | + .truncate(true) |
| 45 | + .create(true) |
| 46 | + .open(&path) |
| 47 | + .unwrap_or_else(|e| panic!("{}", e)); |
| 48 | + |
| 49 | + /* tell the client we are ready for data */ |
| 50 | + send_copy_in_begin(natts, is_binary); |
| 51 | + |
| 52 | + /* allocate on the heap since it's quite big */ |
| 53 | + let mut receive_buffer = vec![0u8; MAX_READ_SIZE]; |
| 54 | + |
| 55 | + while !cstate.raw_reached_eof { |
| 56 | + /* copy some bytes from the client into fe_msgbuf */ |
| 57 | + let bytes_read = receive_data_from_client(&mut cstate, &mut receive_buffer); |
| 58 | + |
| 59 | + if bytes_read == 0 { |
| 60 | + break; |
| 61 | + } |
| 62 | + |
| 63 | + if bytes_read > 0 { |
| 64 | + /* copy bytes from fe_msgbuf to the destination file */ |
| 65 | + file.write_all(&receive_buffer[..bytes_read]) |
| 66 | + .unwrap_or_else(|e| { |
| 67 | + panic!("could not write to file: {}", e); |
| 68 | + }); |
| 69 | + } |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +/* |
| 74 | + * send_copy_in_begin sends the CopyInResponse message that the client |
| 75 | + * expects after a COPY .. FROM STDIN. |
| 76 | + * |
| 77 | + * This code is adapted from ReceiveCopyBegin in PostgreSQL. |
| 78 | + */ |
| 79 | +unsafe fn send_copy_in_begin(natts: i16, is_binary: bool) { |
| 80 | + let buf = makeStringInfo(); |
| 81 | + |
| 82 | + pq_beginmessage(buf, 'G' as _); |
| 83 | + |
| 84 | + let copy_format = if is_binary { 1 } else { 0 }; |
| 85 | + pq_sendbyte(buf, copy_format); |
| 86 | + |
| 87 | + pq_sendint16(buf, natts as _); |
| 88 | + for _ in 0..natts { |
| 89 | + /* use the same format for all columns */ |
| 90 | + pq_sendint16(buf, copy_format as _); |
| 91 | + } |
| 92 | + |
| 93 | + pq_endmessage(buf); |
| 94 | + ((*PqCommMethods).flush)(); |
| 95 | +} |
| 96 | + |
/* largest message body accepted for CopyData messages: 1 GiB - 3 bytes */
const PQ_LARGE_MESSAGE_LIMIT: i32 = (1 << 30) - 3;

/* largest message body accepted for CopyDone, CopyFail, Flush and Sync messages */
const PQ_SMALL_MESSAGE_LIMIT: i32 = 10_000;
| 99 | + |
| 100 | +unsafe fn receive_data_from_client( |
| 101 | + cstate: &mut CopyFromStdinState, |
| 102 | + receive_buffer: &mut [u8], |
| 103 | +) -> usize { |
| 104 | + let mut databuf = receive_buffer; |
| 105 | + |
| 106 | + let minread = 1; |
| 107 | + let mut maxread = MAX_READ_SIZE; |
| 108 | + |
| 109 | + let mut bytesread = 0; |
| 110 | + |
| 111 | + while maxread > 0 && bytesread < minread && !cstate.raw_reached_eof { |
| 112 | + let mut avail; |
| 113 | + let mut flushed = false; |
| 114 | + |
| 115 | + while flushed || (*cstate.fe_msgbuf).cursor >= (*cstate.fe_msgbuf).len { |
| 116 | + /* Try to receive another message */ |
| 117 | + |
| 118 | + QueryCancelHoldoffCount += 1; |
| 119 | + |
| 120 | + pq_startmsgread(); |
| 121 | + |
| 122 | + let mtype = pq_getbyte(); |
| 123 | + if mtype == -1 { |
| 124 | + panic!("unexpected EOF on client connection with an open transaction"); |
| 125 | + } |
| 126 | + |
| 127 | + /* Validate message type and set packet size limit */ |
| 128 | + let maxmsglen = match mtype as u8 as char { |
| 129 | + 'd' => |
| 130 | + /* CopyData */ |
| 131 | + { |
| 132 | + PQ_LARGE_MESSAGE_LIMIT |
| 133 | + } |
| 134 | + 'c' | 'f' | 'H' | 'S' => |
| 135 | + /* CopyDone, CopyFail, Flush, Sync */ |
| 136 | + { |
| 137 | + PQ_SMALL_MESSAGE_LIMIT |
| 138 | + } |
| 139 | + _ => { |
| 140 | + panic!( |
| 141 | + "unexpected message type 0x{:02X} during COPY from stdin", |
| 142 | + mtype |
| 143 | + ); |
| 144 | + } |
| 145 | + }; |
| 146 | + |
| 147 | + /* Now collect the message body */ |
| 148 | + if pq_getmessage(cstate.fe_msgbuf, maxmsglen) != 0 { |
| 149 | + panic!("unexpected EOF on client connection with an open transaction"); |
| 150 | + } |
| 151 | + |
| 152 | + QueryCancelHoldoffCount -= 1; |
| 153 | + |
| 154 | + /* ... and process it */ |
| 155 | + match mtype as u8 as char { |
| 156 | + 'd' => { |
| 157 | + /* CopyData */ |
| 158 | + break; |
| 159 | + } |
| 160 | + 'c' => { |
| 161 | + /* CopyDone */ |
| 162 | + cstate.raw_reached_eof = true; |
| 163 | + return bytesread; |
| 164 | + } |
| 165 | + 'f' => { |
| 166 | + /* CopyFail */ |
| 167 | + let msg = pq_getmsgstring(cstate.fe_msgbuf); |
| 168 | + let msg = CStr::from_ptr(msg).to_str().expect("invalid CStr"); |
| 169 | + panic!("COPY from stdin failed: {msg}"); |
| 170 | + } |
| 171 | + 'H' | 'S' => { |
| 172 | + /* Flush, Sync */ |
| 173 | + flushed = true; |
| 174 | + continue; |
| 175 | + } |
| 176 | + _ => { |
| 177 | + panic!( |
| 178 | + "unexpected message type 0x{:02X} during COPY from stdin", |
| 179 | + mtype |
| 180 | + ); |
| 181 | + } |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + avail = ((*cstate.fe_msgbuf).len - (*cstate.fe_msgbuf).cursor) as _; |
| 186 | + if avail > maxread { |
| 187 | + avail = maxread; |
| 188 | + } |
| 189 | + |
| 190 | + pq_copymsgbytes(cstate.fe_msgbuf, databuf.as_mut_ptr() as _, avail as _); |
| 191 | + databuf = &mut databuf[avail..]; |
| 192 | + maxread -= avail; |
| 193 | + bytesread += avail; |
| 194 | + } |
| 195 | + |
| 196 | + bytesread |
| 197 | +} |
| 198 | + |
// todo: move to pgrx (include libpq.h)
/*
 * PQcommMethods describes the table of function pointers reached through
 * PqCommMethods. Field order and signatures have to match the PQcommMethods
 * struct declared in PostgreSQL's libpq.h; there the message type is
 * passed as a C `char`, hence c_char rather than a wider integer.
 */
#[repr(C)]
struct PQcommMethods {
    comm_reset: unsafe extern "C" fn(),
    flush: unsafe extern "C" fn() -> i32,
    flush_if_writable: unsafe extern "C" fn() -> i32,
    is_send_pending: unsafe extern "C" fn() -> bool,
    putmessage: unsafe extern "C" fn(msgtype: c_char, s: *const c_char, len: usize) -> i32,
    putmessage_noblock: unsafe extern "C" fn(msgtype: c_char, s: *const c_char, len: usize),
}
| 209 | + |
| 210 | +unsafe extern "C" { |
| 211 | + fn pq_startmsgread(); |
| 212 | + fn pq_getmessage(s: StringInfo, maxlen: i32) -> i32; |
| 213 | + fn pq_getbyte() -> i32; |
| 214 | + |
| 215 | + static PqCommMethods: *mut PQcommMethods; |
| 216 | +} |
0 commit comments