From 6ce6777479af3c1b921a25a73ffa23d6f679180f Mon Sep 17 00:00:00 2001 From: Jackson Coxson Date: Wed, 20 Aug 2025 12:41:09 -0600 Subject: [PATCH] Read IP packets outside of tokio::select to avoid cancel --- idevice/src/tcp/adapter.rs | 108 ++++++++++++++++--------------- idevice/src/tcp/handle.rs | 126 ++++++++++++++++++++++++------------- idevice/src/tcp/packets.rs | 52 ++++++++++----- 3 files changed, 175 insertions(+), 111 deletions(-) diff --git a/idevice/src/tcp/adapter.rs b/idevice/src/tcp/adapter.rs index c8d85f1..287b59f 100644 --- a/idevice/src/tcp/adapter.rs +++ b/idevice/src/tcp/adapter.rs @@ -112,7 +112,7 @@ impl ConnectionState { #[derive(Debug)] pub struct Adapter { /// The underlying transport connection - peer: Box, + pub(crate) peer: Box, /// The local IP address host_ip: IpAddr, /// The remote peer's IP address @@ -538,63 +538,67 @@ impl Adapter { } pub(crate) async fn process_tcp_packet(&mut self) -> Result<(), std::io::Error> { - loop { - let ip_packet = self.read_ip_packet().await?; - let res = TcpPacket::parse(&ip_packet)?; - let mut ack_me = None; + let ip_packet = self.read_ip_packet().await?; + self.process_tcp_packet_from_payload(&ip_packet).await + } - if let Some(state) = self.states.get(&res.destination_port) { - // A keep-alive probe: ACK set, no payload, and seq == RCV.NXT - 1 - let is_keepalive = res.flags.ack - && res.payload.is_empty() - && res.sequence_number.wrapping_add(1) == state.ack; + pub(crate) async fn process_tcp_packet_from_payload( + &mut self, + payload: &[u8], + ) -> Result<(), std::io::Error> { + let res = TcpPacket::parse(payload)?; + let mut ack_me = None; - if is_keepalive { - // Don't update any seq/ack state; just ACK what we already expect. - debug!("responding to keep-alive probe"); - let port = res.destination_port; - self.ack(port).await?; - break; - } + if let Some(state) = self.states.get(&res.destination_port) { + // A keep-alive probe: ACK set, no payload, and seq == RCV.NXT - 1 + let is_keepalive = res.flags.ack + && res.payload.is_empty() + && res.sequence_number.wrapping_add(1) == state.ack; + + if is_keepalive { + // Don't update any seq/ack state; just ACK what we already expect. + debug!("responding to keep-alive probe"); + let port = res.destination_port; + self.ack(port).await?; + return Ok(()); + } + } + + if let Some(state) = self.states.get_mut(&res.destination_port) { + if state.peer_seq > res.sequence_number { + // ignore retransmission + return Ok(()); } - if let Some(state) = self.states.get_mut(&res.destination_port) { - if state.peer_seq > res.sequence_number { - // ignore retransmission - continue; - } - - state.peer_seq = res.sequence_number + res.payload.len() as u32; - state.ack = res.sequence_number - + if res.payload.is_empty() && state.status != ConnectionStatus::Connected { - 1 - } else { - res.payload.len() as u32 - }; - if res.flags.psh || !res.payload.is_empty() { - ack_me = Some(res.destination_port); - state.read_buffer.extend(res.payload); - } - if res.flags.rst { - warn!("stream rst"); - state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset); - } - if res.flags.fin { - ack_me = Some(res.destination_port); - state.status = ConnectionStatus::Error(ErrorKind::UnexpectedEof); - } - if res.flags.syn && res.flags.ack { - ack_me = Some(res.destination_port); - state.seq = state.seq.wrapping_add(1); - state.status = ConnectionStatus::Connected; - } + state.peer_seq = res.sequence_number + res.payload.len() as u32; + state.ack = res.sequence_number + + if res.payload.is_empty() && state.status != ConnectionStatus::Connected { + 1 + } else { + res.payload.len() as u32 + }; + if res.flags.psh || !res.payload.is_empty() { + ack_me = Some(res.destination_port); + state.read_buffer.extend(res.payload); } - - // we have to ack outside of the mutable state borrow - if let Some(a) = ack_me { - self.ack(a).await?; + if res.flags.rst { + warn!("stream rst"); + state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset); } - break; + if res.flags.fin { + ack_me = Some(res.destination_port); + state.status = ConnectionStatus::Error(ErrorKind::UnexpectedEof); + } + if res.flags.syn && res.flags.ack { + ack_me = Some(res.destination_port); + state.seq = state.seq.wrapping_add(1); + state.status = ConnectionStatus::Connected; + } + } + + // we have to ack outside of the mutable state borrow + if let Some(a) = ack_me { + self.ack(a).await?; } Ok(()) } diff --git a/idevice/src/tcp/handle.rs b/idevice/src/tcp/handle.rs index 3434f9f..b7049a9 100644 --- a/idevice/src/tcp/handle.rs +++ b/idevice/src/tcp/handle.rs @@ -8,13 +8,16 @@ use std::{collections::HashMap, path::PathBuf, sync::Mutex, task::Poll}; use crossfire::{AsyncRx, MTx, Tx, mpsc, spsc, stream::AsyncStream}; use futures::{StreamExt, stream::FuturesUnordered}; -use log::trace; +use log::{debug, trace}; use tokio::{ - io::{AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncReadExt, AsyncWrite}, sync::oneshot, }; -use crate::tcp::adapter::ConnectionStatus; +use crate::tcp::{ + adapter::ConnectionStatus, + packets::{IpParseError, Ipv6Packet}, +}; pub type ConnectToPortRes = oneshot::Sender, std::io::Error>>), std::io::Error>>; @@ -50,6 +53,9 @@ impl AdapterHandle { tokio::spawn(async move { let mut handles: HashMap, std::io::Error>>> = HashMap::new(); let mut tick = tokio::time::interval(std::time::Duration::from_millis(1)); + + let mut read_buf = [0u8; 4096]; + let mut bytes_in_buf = 0; loop { tokio::select! { // check for messages for us @@ -96,53 +102,85 @@ impl AdapterHandle { } } - r = adapter.process_tcp_packet() => { - if let Err(e) = r { - // propagate error to all streams; close them + result = adapter.peer.read(&mut read_buf[bytes_in_buf..]) => { + match result { + Ok(0) => { + debug!("Underlying stream closed (EOF)"); + break; // Exit the main actor loop + } + Ok(s) => { + bytes_in_buf += s; + loop { + match Ipv6Packet::parse(&read_buf[..bytes_in_buf]) { + IpParseError::Ok { packet, bytes_consumed } => { + // We got a full packet! Process it. + if let Err(e) = adapter.process_tcp_packet_from_payload(&packet.payload).await { + debug!("CRITICAL: Failed to process IP packet: {e:?}"); + } + + // And remove it from the buffer by shifting the remaining bytes + read_buf.copy_within(bytes_consumed..bytes_in_buf, 0); + bytes_in_buf -= bytes_consumed; + // Push any newly available bytes to per-conn channels + let mut dead = Vec::new(); + for (&hp, tx) in &handles { + match adapter.uncache_all(hp) { + Ok(buf) if !buf.is_empty() => { + if tx.send(Ok(buf)).is_err() { + dead.push(hp); + } + } + Err(e) => { + let _ = tx.send(Err(e)); + dead.push(hp); + } + _ => {} + } + } + for hp in dead { + handles.remove(&hp); + let _ = adapter.close(hp).await; + } + + let mut to_close = Vec::new(); + for (&hp, tx) in &handles { + if let Ok(ConnectionStatus::Error(kind)) = adapter.get_status(hp) { + if kind == std::io::ErrorKind::UnexpectedEof { + to_close.push(hp); + } else { + let _ = tx.send(Err(std::io::Error::from(kind))); + to_close.push(hp); + } + } + } + for hp in to_close { + handles.remove(&hp); + // Best-effort close. For RST this just tidies state on our side + let _ = adapter.close(hp).await; + } + } + IpParseError::NotEnough => { + // Buffer doesn't have a full packet, wait for the next read + break; + } + IpParseError::Invalid => { + // Corrupted data, close the connection + // ... (error handling) ... + return; + } + } + } + + } + Err(e) => { + debug!("Failed to read: {e:?}, closing stack"); for (hp, tx) in handles.drain() { let _ = tx.send(Err(e.kind().into())); // or clone/convert let _ = adapter.close(hp).await; } - break; - } - - // Push any newly available bytes to per-conn channels - let mut dead = Vec::new(); - for (&hp, tx) in &handles { - match adapter.uncache_all(hp) { - Ok(buf) if !buf.is_empty() => { - if tx.send(Ok(buf)).is_err() { - dead.push(hp); - } - } - Err(e) => { - let _ = tx.send(Err(e)); - dead.push(hp); - } - _ => {} + break; } } - for hp in dead { - handles.remove(&hp); - let _ = adapter.close(hp).await; - } - - let mut to_close = Vec::new(); - for (&hp, tx) in &handles { - if let Ok(ConnectionStatus::Error(kind)) = adapter.get_status(hp) { - if kind == std::io::ErrorKind::UnexpectedEof { - to_close.push(hp); - } else { - let _ = tx.send(Err(std::io::Error::from(kind))); - to_close.push(hp); - } - } - } - for hp in to_close { - handles.remove(&hp); - // Best-effort close. For RST this just tidies state on our side - let _ = adapter.close(hp).await; - } } _ = tick.tick() => { diff --git a/idevice/src/tcp/packets.rs b/idevice/src/tcp/packets.rs index fa9e94f..0186ebe 100644 --- a/idevice/src/tcp/packets.rs +++ b/idevice/src/tcp/packets.rs @@ -6,6 +6,7 @@ use std::{ sync::Arc, }; +use log::debug; use tokio::{ io::{AsyncRead, AsyncReadExt}, sync::Mutex, @@ -109,6 +110,7 @@ impl Ipv4Packet { let ihl = (version_ihl & 0x0F) * 4; if version != 4 || ihl < 20 { + debug!("Got an invalid IPv4 header"); return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "Invalid IPv4 header", @@ -220,21 +222,34 @@ pub struct Ipv6Packet { pub payload: Vec, } +#[derive(Debug, Clone)] +pub(crate) enum IpParseError { + Ok { packet: T, bytes_consumed: usize }, + NotEnough, + Invalid, +} + impl Ipv6Packet { - pub fn parse(packet: &[u8]) -> Option { + pub(crate) fn parse(packet: &[u8]) -> IpParseError { if packet.len() < 40 { - return None; + return IpParseError::NotEnough; } let version = packet[0] >> 4; if version != 6 { - return None; + return IpParseError::Invalid; } let traffic_class = ((packet[0] & 0x0F) << 4) | (packet[1] >> 4); let flow_label = ((packet[1] as u32 & 0x0F) << 16) | ((packet[2] as u32) << 8) | packet[3] as u32; let payload_length = u16::from_be_bytes([packet[4], packet[5]]); + let total_packet_len = 40 + payload_length as usize; + + if packet.len() < total_packet_len { + return IpParseError::NotEnough; + } + let next_header = packet[6]; let hop_limit = packet[7]; let source = Ipv6Addr::new( @@ -258,19 +273,22 @@ impl Ipv6Packet { u16::from_be_bytes([packet[36], packet[37]]), u16::from_be_bytes([packet[38], packet[39]]), ); - let payload = packet[40..].to_vec(); + let payload = packet[40..total_packet_len].to_vec(); - Some(Self { - version, - traffic_class, - flow_label, - payload_length, - next_header, - hop_limit, - source, - destination, - payload, - }) + IpParseError::Ok { + packet: Self { + version, + traffic_class, + flow_label, + payload_length, + next_header, + hop_limit, + source, + destination, + payload, + }, + bytes_consumed: total_packet_len, + } } pub async fn from_reader( @@ -278,14 +296,17 @@ impl Ipv6Packet { log: &Option>>, ) -> Result { let mut log_packet = Vec::new(); + let mut header = [0u8; 40]; // IPv6 header size is fixed at 40 bytes reader.read_exact(&mut header).await?; + if log.is_some() { log_packet.extend_from_slice(&header); } let version = header[0] >> 4; if version != 6 { + debug!("Got an invalid IPv6 header"); return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "Invalid IPv6 header", @@ -457,6 +478,7 @@ pub struct TcpPacket { impl TcpPacket { pub fn parse(packet: &[u8]) -> Result { if packet.len() < 20 { + debug!("Got an invalid TCP header"); return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "Not enough bytes for TCP header",