mirror of
https://github.com/jkcoxson/idevice.git
synced 2026-03-02 22:46:14 +01:00
Read IP packets outside of tokio::select to avoid cancel
This commit is contained in:
@@ -112,7 +112,7 @@ impl ConnectionState {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Adapter {
|
pub struct Adapter {
|
||||||
/// The underlying transport connection
|
/// The underlying transport connection
|
||||||
peer: Box<dyn ReadWrite>,
|
pub(crate) peer: Box<dyn ReadWrite>,
|
||||||
/// The local IP address
|
/// The local IP address
|
||||||
host_ip: IpAddr,
|
host_ip: IpAddr,
|
||||||
/// The remote peer's IP address
|
/// The remote peer's IP address
|
||||||
@@ -538,63 +538,67 @@ impl Adapter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn process_tcp_packet(&mut self) -> Result<(), std::io::Error> {
|
pub(crate) async fn process_tcp_packet(&mut self) -> Result<(), std::io::Error> {
|
||||||
loop {
|
let ip_packet = self.read_ip_packet().await?;
|
||||||
let ip_packet = self.read_ip_packet().await?;
|
self.process_tcp_packet_from_payload(&ip_packet).await
|
||||||
let res = TcpPacket::parse(&ip_packet)?;
|
}
|
||||||
let mut ack_me = None;
|
|
||||||
|
|
||||||
if let Some(state) = self.states.get(&res.destination_port) {
|
pub(crate) async fn process_tcp_packet_from_payload(
|
||||||
// A keep-alive probe: ACK set, no payload, and seq == RCV.NXT - 1
|
&mut self,
|
||||||
let is_keepalive = res.flags.ack
|
payload: &[u8],
|
||||||
&& res.payload.is_empty()
|
) -> Result<(), std::io::Error> {
|
||||||
&& res.sequence_number.wrapping_add(1) == state.ack;
|
let res = TcpPacket::parse(payload)?;
|
||||||
|
let mut ack_me = None;
|
||||||
|
|
||||||
if is_keepalive {
|
if let Some(state) = self.states.get(&res.destination_port) {
|
||||||
// Don't update any seq/ack state; just ACK what we already expect.
|
// A keep-alive probe: ACK set, no payload, and seq == RCV.NXT - 1
|
||||||
debug!("responding to keep-alive probe");
|
let is_keepalive = res.flags.ack
|
||||||
let port = res.destination_port;
|
&& res.payload.is_empty()
|
||||||
self.ack(port).await?;
|
&& res.sequence_number.wrapping_add(1) == state.ack;
|
||||||
break;
|
|
||||||
}
|
if is_keepalive {
|
||||||
|
// Don't update any seq/ack state; just ACK what we already expect.
|
||||||
|
debug!("responding to keep-alive probe");
|
||||||
|
let port = res.destination_port;
|
||||||
|
self.ack(port).await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(state) = self.states.get_mut(&res.destination_port) {
|
||||||
|
if state.peer_seq > res.sequence_number {
|
||||||
|
// ignore retransmission
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(state) = self.states.get_mut(&res.destination_port) {
|
state.peer_seq = res.sequence_number + res.payload.len() as u32;
|
||||||
if state.peer_seq > res.sequence_number {
|
state.ack = res.sequence_number
|
||||||
// ignore retransmission
|
+ if res.payload.is_empty() && state.status != ConnectionStatus::Connected {
|
||||||
continue;
|
1
|
||||||
}
|
} else {
|
||||||
|
res.payload.len() as u32
|
||||||
state.peer_seq = res.sequence_number + res.payload.len() as u32;
|
};
|
||||||
state.ack = res.sequence_number
|
if res.flags.psh || !res.payload.is_empty() {
|
||||||
+ if res.payload.is_empty() && state.status != ConnectionStatus::Connected {
|
ack_me = Some(res.destination_port);
|
||||||
1
|
state.read_buffer.extend(res.payload);
|
||||||
} else {
|
|
||||||
res.payload.len() as u32
|
|
||||||
};
|
|
||||||
if res.flags.psh || !res.payload.is_empty() {
|
|
||||||
ack_me = Some(res.destination_port);
|
|
||||||
state.read_buffer.extend(res.payload);
|
|
||||||
}
|
|
||||||
if res.flags.rst {
|
|
||||||
warn!("stream rst");
|
|
||||||
state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset);
|
|
||||||
}
|
|
||||||
if res.flags.fin {
|
|
||||||
ack_me = Some(res.destination_port);
|
|
||||||
state.status = ConnectionStatus::Error(ErrorKind::UnexpectedEof);
|
|
||||||
}
|
|
||||||
if res.flags.syn && res.flags.ack {
|
|
||||||
ack_me = Some(res.destination_port);
|
|
||||||
state.seq = state.seq.wrapping_add(1);
|
|
||||||
state.status = ConnectionStatus::Connected;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if res.flags.rst {
|
||||||
// we have to ack outside of the mutable state borrow
|
warn!("stream rst");
|
||||||
if let Some(a) = ack_me {
|
state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset);
|
||||||
self.ack(a).await?;
|
|
||||||
}
|
}
|
||||||
break;
|
if res.flags.fin {
|
||||||
|
ack_me = Some(res.destination_port);
|
||||||
|
state.status = ConnectionStatus::Error(ErrorKind::UnexpectedEof);
|
||||||
|
}
|
||||||
|
if res.flags.syn && res.flags.ack {
|
||||||
|
ack_me = Some(res.destination_port);
|
||||||
|
state.seq = state.seq.wrapping_add(1);
|
||||||
|
state.status = ConnectionStatus::Connected;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we have to ack outside of the mutable state borrow
|
||||||
|
if let Some(a) = ack_me {
|
||||||
|
self.ack(a).await?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,13 +8,16 @@ use std::{collections::HashMap, path::PathBuf, sync::Mutex, task::Poll};
|
|||||||
|
|
||||||
use crossfire::{AsyncRx, MTx, Tx, mpsc, spsc, stream::AsyncStream};
|
use crossfire::{AsyncRx, MTx, Tx, mpsc, spsc, stream::AsyncStream};
|
||||||
use futures::{StreamExt, stream::FuturesUnordered};
|
use futures::{StreamExt, stream::FuturesUnordered};
|
||||||
use log::trace;
|
use log::{debug, trace};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncRead, AsyncWrite},
|
io::{AsyncRead, AsyncReadExt, AsyncWrite},
|
||||||
sync::oneshot,
|
sync::oneshot,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::tcp::adapter::ConnectionStatus;
|
use crate::tcp::{
|
||||||
|
adapter::ConnectionStatus,
|
||||||
|
packets::{IpParseError, Ipv6Packet},
|
||||||
|
};
|
||||||
|
|
||||||
pub type ConnectToPortRes =
|
pub type ConnectToPortRes =
|
||||||
oneshot::Sender<Result<(u16, AsyncRx<Result<Vec<u8>, std::io::Error>>), std::io::Error>>;
|
oneshot::Sender<Result<(u16, AsyncRx<Result<Vec<u8>, std::io::Error>>), std::io::Error>>;
|
||||||
@@ -50,6 +53,9 @@ impl AdapterHandle {
|
|||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut handles: HashMap<u16, Tx<Result<Vec<u8>, std::io::Error>>> = HashMap::new();
|
let mut handles: HashMap<u16, Tx<Result<Vec<u8>, std::io::Error>>> = HashMap::new();
|
||||||
let mut tick = tokio::time::interval(std::time::Duration::from_millis(1));
|
let mut tick = tokio::time::interval(std::time::Duration::from_millis(1));
|
||||||
|
|
||||||
|
let mut read_buf = [0u8; 4096];
|
||||||
|
let mut bytes_in_buf = 0;
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
// check for messages for us
|
// check for messages for us
|
||||||
@@ -96,53 +102,85 @@ impl AdapterHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r = adapter.process_tcp_packet() => {
|
result = adapter.peer.read(&mut read_buf[bytes_in_buf..]) => {
|
||||||
if let Err(e) = r {
|
match result {
|
||||||
// propagate error to all streams; close them
|
Ok(0) => {
|
||||||
|
debug!("Underlying stream closed (EOF)");
|
||||||
|
break; // Exit the main actor loop
|
||||||
|
}
|
||||||
|
Ok(s) => {
|
||||||
|
bytes_in_buf += s;
|
||||||
|
loop {
|
||||||
|
match Ipv6Packet::parse(&read_buf[..bytes_in_buf]) {
|
||||||
|
IpParseError::Ok { packet, bytes_consumed } => {
|
||||||
|
// We got a full packet! Process it.
|
||||||
|
if let Err(e) = adapter.process_tcp_packet_from_payload(&packet.payload).await {
|
||||||
|
debug!("CRITICAL: Failed to process IP packet: {e:?}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// And remove it from the buffer by shifting the remaining bytes
|
||||||
|
read_buf.copy_within(bytes_consumed..bytes_in_buf, 0);
|
||||||
|
bytes_in_buf -= bytes_consumed;
|
||||||
|
// Push any newly available bytes to per-conn channels
|
||||||
|
let mut dead = Vec::new();
|
||||||
|
for (&hp, tx) in &handles {
|
||||||
|
match adapter.uncache_all(hp) {
|
||||||
|
Ok(buf) if !buf.is_empty() => {
|
||||||
|
if tx.send(Ok(buf)).is_err() {
|
||||||
|
dead.push(hp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(e));
|
||||||
|
dead.push(hp);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for hp in dead {
|
||||||
|
handles.remove(&hp);
|
||||||
|
let _ = adapter.close(hp).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut to_close = Vec::new();
|
||||||
|
for (&hp, tx) in &handles {
|
||||||
|
if let Ok(ConnectionStatus::Error(kind)) = adapter.get_status(hp) {
|
||||||
|
if kind == std::io::ErrorKind::UnexpectedEof {
|
||||||
|
to_close.push(hp);
|
||||||
|
} else {
|
||||||
|
let _ = tx.send(Err(std::io::Error::from(kind)));
|
||||||
|
to_close.push(hp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for hp in to_close {
|
||||||
|
handles.remove(&hp);
|
||||||
|
// Best-effort close. For RST this just tidies state on our side
|
||||||
|
let _ = adapter.close(hp).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
IpParseError::NotEnough => {
|
||||||
|
// Buffer doesn't have a full packet, wait for the next read
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
IpParseError::Invalid => {
|
||||||
|
// Corrupted data, close the connection
|
||||||
|
// ... (error handling) ...
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to read: {e:?}, closing stack");
|
||||||
for (hp, tx) in handles.drain() {
|
for (hp, tx) in handles.drain() {
|
||||||
let _ = tx.send(Err(e.kind().into())); // or clone/convert
|
let _ = tx.send(Err(e.kind().into())); // or clone/convert
|
||||||
let _ = adapter.close(hp).await;
|
let _ = adapter.close(hp).await;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
|
|
||||||
// Push any newly available bytes to per-conn channels
|
|
||||||
let mut dead = Vec::new();
|
|
||||||
for (&hp, tx) in &handles {
|
|
||||||
match adapter.uncache_all(hp) {
|
|
||||||
Ok(buf) if !buf.is_empty() => {
|
|
||||||
if tx.send(Ok(buf)).is_err() {
|
|
||||||
dead.push(hp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let _ = tx.send(Err(e));
|
|
||||||
dead.push(hp);
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for hp in dead {
|
|
||||||
handles.remove(&hp);
|
|
||||||
let _ = adapter.close(hp).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut to_close = Vec::new();
|
|
||||||
for (&hp, tx) in &handles {
|
|
||||||
if let Ok(ConnectionStatus::Error(kind)) = adapter.get_status(hp) {
|
|
||||||
if kind == std::io::ErrorKind::UnexpectedEof {
|
|
||||||
to_close.push(hp);
|
|
||||||
} else {
|
|
||||||
let _ = tx.send(Err(std::io::Error::from(kind)));
|
|
||||||
to_close.push(hp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for hp in to_close {
|
|
||||||
handles.remove(&hp);
|
|
||||||
// Best-effort close. For RST this just tidies state on our side
|
|
||||||
let _ = adapter.close(hp).await;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = tick.tick() => {
|
_ = tick.tick() => {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ use std::{
|
|||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use log::debug;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncRead, AsyncReadExt},
|
io::{AsyncRead, AsyncReadExt},
|
||||||
sync::Mutex,
|
sync::Mutex,
|
||||||
@@ -109,6 +110,7 @@ impl Ipv4Packet {
|
|||||||
let ihl = (version_ihl & 0x0F) * 4;
|
let ihl = (version_ihl & 0x0F) * 4;
|
||||||
|
|
||||||
if version != 4 || ihl < 20 {
|
if version != 4 || ihl < 20 {
|
||||||
|
debug!("Got an invalid IPv4 header");
|
||||||
return Err(std::io::Error::new(
|
return Err(std::io::Error::new(
|
||||||
std::io::ErrorKind::InvalidData,
|
std::io::ErrorKind::InvalidData,
|
||||||
"Invalid IPv4 header",
|
"Invalid IPv4 header",
|
||||||
@@ -220,21 +222,34 @@ pub struct Ipv6Packet {
|
|||||||
pub payload: Vec<u8>,
|
pub payload: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) enum IpParseError<T> {
|
||||||
|
Ok { packet: T, bytes_consumed: usize },
|
||||||
|
NotEnough,
|
||||||
|
Invalid,
|
||||||
|
}
|
||||||
|
|
||||||
impl Ipv6Packet {
|
impl Ipv6Packet {
|
||||||
pub fn parse(packet: &[u8]) -> Option<Self> {
|
pub(crate) fn parse(packet: &[u8]) -> IpParseError<Ipv6Packet> {
|
||||||
if packet.len() < 40 {
|
if packet.len() < 40 {
|
||||||
return None;
|
return IpParseError::NotEnough;
|
||||||
}
|
}
|
||||||
|
|
||||||
let version = packet[0] >> 4;
|
let version = packet[0] >> 4;
|
||||||
if version != 6 {
|
if version != 6 {
|
||||||
return None;
|
return IpParseError::Invalid;
|
||||||
}
|
}
|
||||||
|
|
||||||
let traffic_class = ((packet[0] & 0x0F) << 4) | (packet[1] >> 4);
|
let traffic_class = ((packet[0] & 0x0F) << 4) | (packet[1] >> 4);
|
||||||
let flow_label =
|
let flow_label =
|
||||||
((packet[1] as u32 & 0x0F) << 16) | ((packet[2] as u32) << 8) | packet[3] as u32;
|
((packet[1] as u32 & 0x0F) << 16) | ((packet[2] as u32) << 8) | packet[3] as u32;
|
||||||
let payload_length = u16::from_be_bytes([packet[4], packet[5]]);
|
let payload_length = u16::from_be_bytes([packet[4], packet[5]]);
|
||||||
|
let total_packet_len = 40 + payload_length as usize;
|
||||||
|
|
||||||
|
if packet.len() < total_packet_len {
|
||||||
|
return IpParseError::NotEnough;
|
||||||
|
}
|
||||||
|
|
||||||
let next_header = packet[6];
|
let next_header = packet[6];
|
||||||
let hop_limit = packet[7];
|
let hop_limit = packet[7];
|
||||||
let source = Ipv6Addr::new(
|
let source = Ipv6Addr::new(
|
||||||
@@ -258,19 +273,22 @@ impl Ipv6Packet {
|
|||||||
u16::from_be_bytes([packet[36], packet[37]]),
|
u16::from_be_bytes([packet[36], packet[37]]),
|
||||||
u16::from_be_bytes([packet[38], packet[39]]),
|
u16::from_be_bytes([packet[38], packet[39]]),
|
||||||
);
|
);
|
||||||
let payload = packet[40..].to_vec();
|
let payload = packet[40..total_packet_len].to_vec();
|
||||||
|
|
||||||
Some(Self {
|
IpParseError::Ok {
|
||||||
version,
|
packet: Self {
|
||||||
traffic_class,
|
version,
|
||||||
flow_label,
|
traffic_class,
|
||||||
payload_length,
|
flow_label,
|
||||||
next_header,
|
payload_length,
|
||||||
hop_limit,
|
next_header,
|
||||||
source,
|
hop_limit,
|
||||||
destination,
|
source,
|
||||||
payload,
|
destination,
|
||||||
})
|
payload,
|
||||||
|
},
|
||||||
|
bytes_consumed: total_packet_len,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn from_reader<R: AsyncRead + Unpin>(
|
pub async fn from_reader<R: AsyncRead + Unpin>(
|
||||||
@@ -278,14 +296,17 @@ impl Ipv6Packet {
|
|||||||
log: &Option<Arc<Mutex<tokio::fs::File>>>,
|
log: &Option<Arc<Mutex<tokio::fs::File>>>,
|
||||||
) -> Result<Self, std::io::Error> {
|
) -> Result<Self, std::io::Error> {
|
||||||
let mut log_packet = Vec::new();
|
let mut log_packet = Vec::new();
|
||||||
|
|
||||||
let mut header = [0u8; 40]; // IPv6 header size is fixed at 40 bytes
|
let mut header = [0u8; 40]; // IPv6 header size is fixed at 40 bytes
|
||||||
reader.read_exact(&mut header).await?;
|
reader.read_exact(&mut header).await?;
|
||||||
|
|
||||||
if log.is_some() {
|
if log.is_some() {
|
||||||
log_packet.extend_from_slice(&header);
|
log_packet.extend_from_slice(&header);
|
||||||
}
|
}
|
||||||
|
|
||||||
let version = header[0] >> 4;
|
let version = header[0] >> 4;
|
||||||
if version != 6 {
|
if version != 6 {
|
||||||
|
debug!("Got an invalid IPv6 header");
|
||||||
return Err(std::io::Error::new(
|
return Err(std::io::Error::new(
|
||||||
std::io::ErrorKind::InvalidData,
|
std::io::ErrorKind::InvalidData,
|
||||||
"Invalid IPv6 header",
|
"Invalid IPv6 header",
|
||||||
@@ -457,6 +478,7 @@ pub struct TcpPacket {
|
|||||||
impl TcpPacket {
|
impl TcpPacket {
|
||||||
pub fn parse(packet: &[u8]) -> Result<Self, std::io::Error> {
|
pub fn parse(packet: &[u8]) -> Result<Self, std::io::Error> {
|
||||||
if packet.len() < 20 {
|
if packet.len() < 20 {
|
||||||
|
debug!("Got an invalid TCP header");
|
||||||
return Err(std::io::Error::new(
|
return Err(std::io::Error::new(
|
||||||
std::io::ErrorKind::InvalidData,
|
std::io::ErrorKind::InvalidData,
|
||||||
"Not enough bytes for TCP header",
|
"Not enough bytes for TCP header",
|
||||||
|
|||||||
Reference in New Issue
Block a user