Skip to content

Commit e87a1f0

Browse files
committed
fix(socket): truncation when calling recv_slice
When calling `recv_slice`/`peek_slice` in the udp, icmp and raw sockets, data is lost when the provided buffer is smaller than the payload to be copied (not the case for `peek_slice` because it is not dequeued). With this commit, an `RecvError::Truncated` error is returned. In case of UDP, the endpoint is also returned in the error. In case of ICMP, the IP address is also returned in the error. I implemented it the way Whitequark proposes it. Data is still lost, but at least the caller knows that the data was truncated to the size of the provided buffer. As Whitequark says, it is preferred to call `recv` instead of `recv_slice` as there would be no truncation.
1 parent a5da783 commit e87a1f0

File tree

4 files changed

+97
-16
lines changed

4 files changed

+97
-16
lines changed

src/socket/icmp.rs

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,14 @@ impl std::error::Error for SendError {}
6161
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
6262
pub enum RecvError {
6363
Exhausted,
64+
Truncated(IpAddress),
6465
}
6566

6667
impl core::fmt::Display for RecvError {
6768
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
6869
match self {
6970
RecvError::Exhausted => write!(f, "exhausted"),
71+
RecvError::Truncated(_) => write!(f, "truncated"),
7072
}
7173
}
7274
}
@@ -130,8 +132,8 @@ impl<'a> Socket<'a> {
130132
/// Create an ICMP socket with the given buffers.
131133
pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> {
132134
Socket {
133-
rx_buffer: rx_buffer,
134-
tx_buffer: tx_buffer,
135+
rx_buffer,
136+
tx_buffer,
135137
endpoint: Default::default(),
136138
hop_limit: None,
137139
#[cfg(feature = "async")]
@@ -394,12 +396,20 @@ impl<'a> Socket<'a> {
394396
/// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
395397
/// and return the amount of octets copied as well as the `IpAddress`
396398
///
399+
/// The payload is copied partially when the size of the given slice is smaller than the size
400+
/// of the payload. In this case, a `RecvError::Truncated` error is returned.
401+
///
397402
/// See also [recv](#method.recv).
398403
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpAddress), RecvError> {
399404
let (buffer, endpoint) = self.recv()?;
400405
let length = cmp::min(data.len(), buffer.len());
401406
data[..length].copy_from_slice(&buffer[..length]);
402-
Ok((length, endpoint))
407+
408+
if data.len() < buffer.len() {
409+
Err(RecvError::Truncated(endpoint))
410+
} else {
411+
Ok((length, endpoint))
412+
}
403413
}
404414

405415
/// Filter determining which packets received by the interface are appended to
@@ -555,7 +565,7 @@ impl<'a> Socket<'a> {
555565
dst_addr,
556566
next_header: IpProtocol::Icmp,
557567
payload_len: repr.buffer_len(),
558-
hop_limit: hop_limit,
568+
hop_limit,
559569
});
560570
emit(cx, (ip_repr, IcmpRepr::Ipv4(repr)))
561571
}
@@ -592,7 +602,7 @@ impl<'a> Socket<'a> {
592602
dst_addr,
593603
next_header: IpProtocol::Icmpv6,
594604
payload_len: repr.buffer_len(),
595-
hop_limit: hop_limit,
605+
hop_limit,
596606
});
597607
emit(cx, (ip_repr, IcmpRepr::Ipv6(repr)))
598608
}
@@ -1096,6 +1106,44 @@ mod test_ipv6 {
10961106
assert!(!socket.can_recv());
10971107
}
10981108

1109+
#[rstest]
1110+
#[case::ethernet(Medium::Ethernet)]
1111+
#[cfg(feature = "medium-ethernet")]
1112+
fn test_truncated_recv_slice(#[case] medium: Medium) {
1113+
let (mut iface, _, _) = setup(medium);
1114+
let cx = iface.context();
1115+
1116+
let mut socket = socket(buffer(1), buffer(1));
1117+
assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(()));
1118+
1119+
let checksum = ChecksumCapabilities::default();
1120+
1121+
let mut bytes = [0xff; 24];
1122+
let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]);
1123+
ECHOV6_REPR.emit(
1124+
&LOCAL_IPV6.into(),
1125+
&REMOTE_IPV6.into(),
1126+
&mut packet,
1127+
&checksum,
1128+
);
1129+
let data = &*packet.into_inner();
1130+
1131+
assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()));
1132+
socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into());
1133+
assert!(socket.can_recv());
1134+
1135+
assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()));
1136+
socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into());
1137+
1138+
let mut buffer = [0u8; 1];
1139+
assert_eq!(
1140+
socket.recv_slice(&mut buffer[..]),
1141+
Err(RecvError::Truncated(REMOTE_IPV6.into()))
1142+
);
1143+
assert_eq!(buffer[0], data[0]);
1144+
assert!(!socket.can_recv());
1145+
}
1146+
10991147
#[rstest]
11001148
#[case::ethernet(Medium::Ethernet)]
11011149
#[cfg(feature = "medium-ethernet")]

src/socket/raw.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,14 @@ impl std::error::Error for SendError {}
5757
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
5858
pub enum RecvError {
5959
Exhausted,
60+
Truncated,
6061
}
6162

6263
impl core::fmt::Display for RecvError {
6364
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
6465
match self {
6566
RecvError::Exhausted => write!(f, "exhausted"),
67+
RecvError::Truncated => write!(f, "truncated"),
6668
}
6769
}
6870
}
@@ -273,12 +275,20 @@ impl<'a> Socket<'a> {
273275

274276
/// Dequeue a packet, and copy the payload into the given slice.
275277
///
278+
/// The payload is copied partially when the size of the given slice is smaller than the size
279+
/// of the payload. In this case, a `RecvError::Truncated` error is returned.
280+
///
276281
/// See also [recv](#method.recv).
277282
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> {
278283
let buffer = self.recv()?;
279284
let length = min(data.len(), buffer.len());
280285
data[..length].copy_from_slice(&buffer[..length]);
281-
Ok(length)
286+
287+
if data.len() < buffer.len() {
288+
Err(RecvError::Truncated)
289+
} else {
290+
Ok(length)
291+
}
282292
}
283293

284294
/// Peek at a packet in the receive buffer and return a pointer to the
@@ -308,7 +318,12 @@ impl<'a> Socket<'a> {
308318
let buffer = self.peek()?;
309319
let length = min(data.len(), buffer.len());
310320
data[..length].copy_from_slice(&buffer[..length]);
311-
Ok(length)
321+
322+
if data.len() < buffer.len() {
323+
Err(RecvError::Truncated)
324+
} else {
325+
Ok(length)
326+
}
312327
}
313328

314329
pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool {
@@ -602,7 +617,7 @@ mod test {
602617
socket.process(&mut cx, &$hdr, &$payload);
603618

604619
let mut slice = [0; 4];
605-
assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4));
620+
assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
606621
assert_eq!(&slice, &$packet[..slice.len()]);
607622
}
608623

@@ -641,9 +656,9 @@ mod test {
641656
socket.process(&mut cx, &$hdr, &$payload);
642657

643658
let mut slice = [0; 4];
644-
assert_eq!(socket.peek_slice(&mut slice[..]), Ok(4));
659+
assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated));
645660
assert_eq!(&slice, &$packet[..slice.len()]);
646-
assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4));
661+
assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
647662
assert_eq!(&slice, &$packet[..slice.len()]);
648663
assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted));
649664
}

src/socket/tcp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1331,7 +1331,7 @@ impl<'a> Socket<'a> {
13311331
// Rate-limit to 1 per second max.
13321332
self.challenge_ack_timer = cx.now() + Duration::from_secs(1);
13331333

1334-
return Some(self.ack_reply(ip_repr, repr));
1334+
Some(self.ack_reply(ip_repr, repr))
13351335
}
13361336

13371337
pub(crate) fn accepts(&self, _cx: &mut Context, ip_repr: &IpRepr, repr: &TcpRepr) -> bool {

src/socket/udp.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,14 @@ impl std::error::Error for SendError {}
8888
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
8989
pub enum RecvError {
9090
Exhausted,
91+
Truncated(UdpMetadata),
9192
}
9293

9394
impl core::fmt::Display for RecvError {
9495
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
9596
match self {
9697
RecvError::Exhausted => write!(f, "exhausted"),
98+
RecvError::Truncated(_) => write!(f, "truncated"),
9799
}
98100
}
99101
}
@@ -393,12 +395,20 @@ impl<'a> Socket<'a> {
393395
/// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
394396
/// and return the amount of octets copied as well as the endpoint.
395397
///
398+
/// The payload is copied partially when the size of the given slice is smaller than the size
399+
/// of the payload. In this case, a `RecvError::Truncated` error is returned.
400+
///
396401
/// See also [recv](#method.recv).
397402
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> {
398403
let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?;
399404
let length = min(data.len(), buffer.len());
400405
data[..length].copy_from_slice(&buffer[..length]);
401-
Ok((length, endpoint))
406+
407+
if data.len() < buffer.len() {
408+
Err(RecvError::Truncated(endpoint))
409+
} else {
410+
Ok((length, endpoint))
411+
}
402412
}
403413

404414
/// Peek at a packet received from a remote endpoint, and return the endpoint as well
@@ -426,12 +436,20 @@ impl<'a> Socket<'a> {
426436
/// packet from the receive buffer.
427437
/// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
428438
///
439+
/// The payload is copied partially when the size of the given slice is smaller than the size
440+
/// of the payload. In this case, a `RecvError::Truncated` error is returned.
441+
///
429442
/// See also [peek](#method.peek).
430443
pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> {
431444
let (buffer, endpoint) = self.peek()?;
432445
let length = min(data.len(), buffer.len());
433446
data[..length].copy_from_slice(&buffer[..length]);
434-
Ok((length, endpoint))
447+
448+
if data.len() < buffer.len() {
449+
Err(RecvError::Truncated(*endpoint))
450+
} else {
451+
Ok((length, endpoint))
452+
}
435453
}
436454

437455
pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, repr: &UdpRepr) -> bool {
@@ -853,7 +871,7 @@ mod test {
853871
let mut slice = [0; 4];
854872
assert_eq!(
855873
socket.recv_slice(&mut slice[..]),
856-
Ok((4, REMOTE_END.into()))
874+
Err(RecvError::Truncated(REMOTE_END.into()))
857875
);
858876
assert_eq!(&slice, b"abcd");
859877
}
@@ -884,12 +902,12 @@ mod test {
884902
let mut slice = [0; 4];
885903
assert_eq!(
886904
socket.peek_slice(&mut slice[..]),
887-
Ok((4, &REMOTE_END.into()))
905+
Err(RecvError::Truncated(REMOTE_END.into()))
888906
);
889907
assert_eq!(&slice, b"abcd");
890908
assert_eq!(
891909
socket.recv_slice(&mut slice[..]),
892-
Ok((4, REMOTE_END.into()))
910+
Err(RecvError::Truncated(REMOTE_END.into()))
893911
);
894912
assert_eq!(&slice, b"abcd");
895913
assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted));

0 commit comments

Comments
 (0)