Read IP packets outside of tokio::select to avoid cancel

This commit is contained in:
Jackson Coxson
2025-08-20 12:41:09 -06:00
parent f388aaaf2d
commit 6ce6777479
3 changed files with 175 additions and 111 deletions

View File

@@ -112,7 +112,7 @@ impl ConnectionState {
#[derive(Debug)] #[derive(Debug)]
pub struct Adapter { pub struct Adapter {
/// The underlying transport connection /// The underlying transport connection
peer: Box<dyn ReadWrite>, pub(crate) peer: Box<dyn ReadWrite>,
/// The local IP address /// The local IP address
host_ip: IpAddr, host_ip: IpAddr,
/// The remote peer's IP address /// The remote peer's IP address
@@ -538,63 +538,67 @@ impl Adapter {
} }
pub(crate) async fn process_tcp_packet(&mut self) -> Result<(), std::io::Error> { pub(crate) async fn process_tcp_packet(&mut self) -> Result<(), std::io::Error> {
loop { let ip_packet = self.read_ip_packet().await?;
let ip_packet = self.read_ip_packet().await?; self.process_tcp_packet_from_payload(&ip_packet).await
let res = TcpPacket::parse(&ip_packet)?; }
let mut ack_me = None;
if let Some(state) = self.states.get(&res.destination_port) { pub(crate) async fn process_tcp_packet_from_payload(
// A keep-alive probe: ACK set, no payload, and seq == RCV.NXT - 1 &mut self,
let is_keepalive = res.flags.ack payload: &[u8],
&& res.payload.is_empty() ) -> Result<(), std::io::Error> {
&& res.sequence_number.wrapping_add(1) == state.ack; let res = TcpPacket::parse(payload)?;
let mut ack_me = None;
if is_keepalive { if let Some(state) = self.states.get(&res.destination_port) {
// Don't update any seq/ack state; just ACK what we already expect. // A keep-alive probe: ACK set, no payload, and seq == RCV.NXT - 1
debug!("responding to keep-alive probe"); let is_keepalive = res.flags.ack
let port = res.destination_port; && res.payload.is_empty()
self.ack(port).await?; && res.sequence_number.wrapping_add(1) == state.ack;
break;
} if is_keepalive {
// Don't update any seq/ack state; just ACK what we already expect.
debug!("responding to keep-alive probe");
let port = res.destination_port;
self.ack(port).await?;
return Ok(());
}
}
if let Some(state) = self.states.get_mut(&res.destination_port) {
if state.peer_seq > res.sequence_number {
// ignore retransmission
return Ok(());
} }
if let Some(state) = self.states.get_mut(&res.destination_port) { state.peer_seq = res.sequence_number + res.payload.len() as u32;
if state.peer_seq > res.sequence_number { state.ack = res.sequence_number
// ignore retransmission + if res.payload.is_empty() && state.status != ConnectionStatus::Connected {
continue; 1
} } else {
res.payload.len() as u32
state.peer_seq = res.sequence_number + res.payload.len() as u32; };
state.ack = res.sequence_number if res.flags.psh || !res.payload.is_empty() {
+ if res.payload.is_empty() && state.status != ConnectionStatus::Connected { ack_me = Some(res.destination_port);
1 state.read_buffer.extend(res.payload);
} else {
res.payload.len() as u32
};
if res.flags.psh || !res.payload.is_empty() {
ack_me = Some(res.destination_port);
state.read_buffer.extend(res.payload);
}
if res.flags.rst {
warn!("stream rst");
state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset);
}
if res.flags.fin {
ack_me = Some(res.destination_port);
state.status = ConnectionStatus::Error(ErrorKind::UnexpectedEof);
}
if res.flags.syn && res.flags.ack {
ack_me = Some(res.destination_port);
state.seq = state.seq.wrapping_add(1);
state.status = ConnectionStatus::Connected;
}
} }
if res.flags.rst {
// we have to ack outside of the mutable state borrow warn!("stream rst");
if let Some(a) = ack_me { state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset);
self.ack(a).await?;
} }
break; if res.flags.fin {
ack_me = Some(res.destination_port);
state.status = ConnectionStatus::Error(ErrorKind::UnexpectedEof);
}
if res.flags.syn && res.flags.ack {
ack_me = Some(res.destination_port);
state.seq = state.seq.wrapping_add(1);
state.status = ConnectionStatus::Connected;
}
}
// we have to ack outside of the mutable state borrow
if let Some(a) = ack_me {
self.ack(a).await?;
} }
Ok(()) Ok(())
} }

View File

@@ -8,13 +8,16 @@ use std::{collections::HashMap, path::PathBuf, sync::Mutex, task::Poll};
use crossfire::{AsyncRx, MTx, Tx, mpsc, spsc, stream::AsyncStream}; use crossfire::{AsyncRx, MTx, Tx, mpsc, spsc, stream::AsyncStream};
use futures::{StreamExt, stream::FuturesUnordered}; use futures::{StreamExt, stream::FuturesUnordered};
use log::trace; use log::{debug, trace};
use tokio::{ use tokio::{
io::{AsyncRead, AsyncWrite}, io::{AsyncRead, AsyncReadExt, AsyncWrite},
sync::oneshot, sync::oneshot,
}; };
use crate::tcp::adapter::ConnectionStatus; use crate::tcp::{
adapter::ConnectionStatus,
packets::{IpParseError, Ipv6Packet},
};
pub type ConnectToPortRes = pub type ConnectToPortRes =
oneshot::Sender<Result<(u16, AsyncRx<Result<Vec<u8>, std::io::Error>>), std::io::Error>>; oneshot::Sender<Result<(u16, AsyncRx<Result<Vec<u8>, std::io::Error>>), std::io::Error>>;
@@ -50,6 +53,9 @@ impl AdapterHandle {
tokio::spawn(async move { tokio::spawn(async move {
let mut handles: HashMap<u16, Tx<Result<Vec<u8>, std::io::Error>>> = HashMap::new(); let mut handles: HashMap<u16, Tx<Result<Vec<u8>, std::io::Error>>> = HashMap::new();
let mut tick = tokio::time::interval(std::time::Duration::from_millis(1)); let mut tick = tokio::time::interval(std::time::Duration::from_millis(1));
let mut read_buf = [0u8; 4096];
let mut bytes_in_buf = 0;
loop { loop {
tokio::select! { tokio::select! {
// check for messages for us // check for messages for us
@@ -96,53 +102,85 @@ impl AdapterHandle {
} }
} }
r = adapter.process_tcp_packet() => { result = adapter.peer.read(&mut read_buf[bytes_in_buf..]) => {
if let Err(e) = r { match result {
// propagate error to all streams; close them Ok(0) => {
debug!("Underlying stream closed (EOF)");
break; // Exit the main actor loop
}
Ok(s) => {
bytes_in_buf += s;
loop {
match Ipv6Packet::parse(&read_buf[..bytes_in_buf]) {
IpParseError::Ok { packet, bytes_consumed } => {
// We got a full packet! Process it.
if let Err(e) = adapter.process_tcp_packet_from_payload(&packet.payload).await {
debug!("CRITICAL: Failed to process IP packet: {e:?}");
}
// And remove it from the buffer by shifting the remaining bytes
read_buf.copy_within(bytes_consumed..bytes_in_buf, 0);
bytes_in_buf -= bytes_consumed;
// Push any newly available bytes to per-conn channels
let mut dead = Vec::new();
for (&hp, tx) in &handles {
match adapter.uncache_all(hp) {
Ok(buf) if !buf.is_empty() => {
if tx.send(Ok(buf)).is_err() {
dead.push(hp);
}
}
Err(e) => {
let _ = tx.send(Err(e));
dead.push(hp);
}
_ => {}
}
}
for hp in dead {
handles.remove(&hp);
let _ = adapter.close(hp).await;
}
let mut to_close = Vec::new();
for (&hp, tx) in &handles {
if let Ok(ConnectionStatus::Error(kind)) = adapter.get_status(hp) {
if kind == std::io::ErrorKind::UnexpectedEof {
to_close.push(hp);
} else {
let _ = tx.send(Err(std::io::Error::from(kind)));
to_close.push(hp);
}
}
}
for hp in to_close {
handles.remove(&hp);
// Best-effort close. For RST this just tidies state on our side
let _ = adapter.close(hp).await;
}
}
IpParseError::NotEnough => {
// Buffer doesn't have a full packet, wait for the next read
break;
}
IpParseError::Invalid => {
// Corrupted data, close the connection
// ... (error handling) ...
return;
}
}
}
}
Err(e) => {
debug!("Failed to read: {e:?}, closing stack");
for (hp, tx) in handles.drain() { for (hp, tx) in handles.drain() {
let _ = tx.send(Err(e.kind().into())); // or clone/convert let _ = tx.send(Err(e.kind().into())); // or clone/convert
let _ = adapter.close(hp).await; let _ = adapter.close(hp).await;
} }
break; break;
}
// Push any newly available bytes to per-conn channels
let mut dead = Vec::new();
for (&hp, tx) in &handles {
match adapter.uncache_all(hp) {
Ok(buf) if !buf.is_empty() => {
if tx.send(Ok(buf)).is_err() {
dead.push(hp);
}
}
Err(e) => {
let _ = tx.send(Err(e));
dead.push(hp);
}
_ => {}
} }
} }
for hp in dead {
handles.remove(&hp);
let _ = adapter.close(hp).await;
}
let mut to_close = Vec::new();
for (&hp, tx) in &handles {
if let Ok(ConnectionStatus::Error(kind)) = adapter.get_status(hp) {
if kind == std::io::ErrorKind::UnexpectedEof {
to_close.push(hp);
} else {
let _ = tx.send(Err(std::io::Error::from(kind)));
to_close.push(hp);
}
}
}
for hp in to_close {
handles.remove(&hp);
// Best-effort close. For RST this just tidies state on our side
let _ = adapter.close(hp).await;
}
} }
_ = tick.tick() => { _ = tick.tick() => {

View File

@@ -6,6 +6,7 @@ use std::{
sync::Arc, sync::Arc,
}; };
use log::debug;
use tokio::{ use tokio::{
io::{AsyncRead, AsyncReadExt}, io::{AsyncRead, AsyncReadExt},
sync::Mutex, sync::Mutex,
@@ -109,6 +110,7 @@ impl Ipv4Packet {
let ihl = (version_ihl & 0x0F) * 4; let ihl = (version_ihl & 0x0F) * 4;
if version != 4 || ihl < 20 { if version != 4 || ihl < 20 {
debug!("Got an invalid IPv4 header");
return Err(std::io::Error::new( return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData, std::io::ErrorKind::InvalidData,
"Invalid IPv4 header", "Invalid IPv4 header",
@@ -220,21 +222,34 @@ pub struct Ipv6Packet {
pub payload: Vec<u8>, pub payload: Vec<u8>,
} }
#[derive(Debug, Clone)]
pub(crate) enum IpParseError<T> {
Ok { packet: T, bytes_consumed: usize },
NotEnough,
Invalid,
}
impl Ipv6Packet { impl Ipv6Packet {
pub fn parse(packet: &[u8]) -> Option<Self> { pub(crate) fn parse(packet: &[u8]) -> IpParseError<Ipv6Packet> {
if packet.len() < 40 { if packet.len() < 40 {
return None; return IpParseError::NotEnough;
} }
let version = packet[0] >> 4; let version = packet[0] >> 4;
if version != 6 { if version != 6 {
return None; return IpParseError::Invalid;
} }
let traffic_class = ((packet[0] & 0x0F) << 4) | (packet[1] >> 4); let traffic_class = ((packet[0] & 0x0F) << 4) | (packet[1] >> 4);
let flow_label = let flow_label =
((packet[1] as u32 & 0x0F) << 16) | ((packet[2] as u32) << 8) | packet[3] as u32; ((packet[1] as u32 & 0x0F) << 16) | ((packet[2] as u32) << 8) | packet[3] as u32;
let payload_length = u16::from_be_bytes([packet[4], packet[5]]); let payload_length = u16::from_be_bytes([packet[4], packet[5]]);
let total_packet_len = 40 + payload_length as usize;
if packet.len() < total_packet_len {
return IpParseError::NotEnough;
}
let next_header = packet[6]; let next_header = packet[6];
let hop_limit = packet[7]; let hop_limit = packet[7];
let source = Ipv6Addr::new( let source = Ipv6Addr::new(
@@ -258,19 +273,22 @@ impl Ipv6Packet {
u16::from_be_bytes([packet[36], packet[37]]), u16::from_be_bytes([packet[36], packet[37]]),
u16::from_be_bytes([packet[38], packet[39]]), u16::from_be_bytes([packet[38], packet[39]]),
); );
let payload = packet[40..].to_vec(); let payload = packet[40..total_packet_len].to_vec();
Some(Self { IpParseError::Ok {
version, packet: Self {
traffic_class, version,
flow_label, traffic_class,
payload_length, flow_label,
next_header, payload_length,
hop_limit, next_header,
source, hop_limit,
destination, source,
payload, destination,
}) payload,
},
bytes_consumed: total_packet_len,
}
} }
pub async fn from_reader<R: AsyncRead + Unpin>( pub async fn from_reader<R: AsyncRead + Unpin>(
@@ -278,14 +296,17 @@ impl Ipv6Packet {
log: &Option<Arc<Mutex<tokio::fs::File>>>, log: &Option<Arc<Mutex<tokio::fs::File>>>,
) -> Result<Self, std::io::Error> { ) -> Result<Self, std::io::Error> {
let mut log_packet = Vec::new(); let mut log_packet = Vec::new();
let mut header = [0u8; 40]; // IPv6 header size is fixed at 40 bytes let mut header = [0u8; 40]; // IPv6 header size is fixed at 40 bytes
reader.read_exact(&mut header).await?; reader.read_exact(&mut header).await?;
if log.is_some() { if log.is_some() {
log_packet.extend_from_slice(&header); log_packet.extend_from_slice(&header);
} }
let version = header[0] >> 4; let version = header[0] >> 4;
if version != 6 { if version != 6 {
debug!("Got an invalid IPv6 header");
return Err(std::io::Error::new( return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData, std::io::ErrorKind::InvalidData,
"Invalid IPv6 header", "Invalid IPv6 header",
@@ -457,6 +478,7 @@ pub struct TcpPacket {
impl TcpPacket { impl TcpPacket {
pub fn parse(packet: &[u8]) -> Result<Self, std::io::Error> { pub fn parse(packet: &[u8]) -> Result<Self, std::io::Error> {
if packet.len() < 20 { if packet.len() < 20 {
debug!("Got an invalid TCP header");
return Err(std::io::Error::new( return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData, std::io::ErrorKind::InvalidData,
"Not enough bytes for TCP header", "Not enough bytes for TCP header",