Rewrite TCP stack for multiple streams

This commit is contained in:
Jackson Coxson
2025-05-23 01:20:09 -06:00
parent 6bf32afe82
commit 525136662e
7 changed files with 470 additions and 314 deletions

View File

@@ -50,7 +50,7 @@ x509-cert = { version = "0.2", optional = true, features = [
], default-features = false } ], default-features = false }
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.43", features = ["fs"] } tokio = { version = "1.43", features = ["full"] }
tun-rs = { version = "2.0.8", features = ["async_tokio"] } tun-rs = { version = "2.0.8", features = ["async_tokio"] }
bytes = "1.10.1" bytes = "1.10.1"

View File

@@ -20,6 +20,8 @@ pub mod xpc;
pub mod services; pub mod services;
pub use services::*; pub use services::*;
#[cfg(feature = "xpc")]
pub use xpc::RemoteXpcClient; pub use xpc::RemoteXpcClient;
use log::{debug, error, trace}; use log::{debug, error, trace};

View File

@@ -61,22 +61,45 @@
//! This implementation makes significant simplifications and should not be used //! This implementation makes significant simplifications and should not be used
//! in production environments or with unreliable network transports. //! in production environments or with unreliable network transports.
use std::{future::Future, net::IpAddr, path::Path, sync::Arc, task::Poll}; use std::{collections::HashMap, io::ErrorKind, net::IpAddr, path::Path, sync::Arc};
use log::trace; use log::trace;
use tokio::{ use tokio::{io::AsyncWriteExt, sync::Mutex};
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
sync::Mutex,
};
use crate::ReadWrite; use crate::ReadWrite;
use super::packets::{Ipv4Packet, Ipv6Packet, ProtocolNumber, TcpFlags, TcpPacket}; use super::packets::{Ipv4Packet, Ipv6Packet, ProtocolNumber, TcpFlags, TcpPacket};
#[derive(Clone, Debug, PartialEq)] #[derive(Debug, Clone)]
enum AdapterState { struct ConnectionState {
seq: u32,
ack: u32,
host_port: u16,
peer_port: u16,
read_buffer: Vec<u8>,
write_buffer: Vec<u8>,
status: ConnectionStatus,
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum ConnectionStatus {
WaitingForSyn,
Connected, Connected,
None, Error(ErrorKind),
}
impl ConnectionState {
fn new(host_port: u16, peer_port: u16) -> Self {
Self {
seq: rand::random(),
ack: 0,
host_port,
peer_port,
read_buffer: Vec::new(),
write_buffer: Vec::new(),
status: ConnectionStatus::WaitingForSyn,
}
}
} }
/// A simplified TCP network stack implementation. /// A simplified TCP network stack implementation.
@@ -96,24 +119,8 @@ pub struct Adapter {
host_ip: IpAddr, host_ip: IpAddr,
/// The remote peer's IP address /// The remote peer's IP address
peer_ip: IpAddr, peer_ip: IpAddr,
/// Current connection state
state: AdapterState,
// TCP state states: HashMap<u16, ConnectionState>, // host port by state
/// Current sequence number
seq: u32,
/// Current acknowledgement number
ack: u32,
/// Local port number
host_port: u16,
/// Remote port number
peer_port: u16,
// Read buffer to cache unused bytes
/// Buffer for storing unread received data
read_buffer: Vec<u8>,
/// Buffer for storing data to be sent
write_buffer: Vec<u8>,
// Logging // Logging
/// Optional PCAP file for packet logging /// Optional PCAP file for packet logging
@@ -135,13 +142,7 @@ impl Adapter {
peer, peer,
host_ip, host_ip,
peer_ip, peer_ip,
state: AdapterState::None, states: HashMap::new(),
seq: 0,
ack: 0,
host_port: 1024,
peer_port: 1024,
read_buffer: Vec::new(),
write_buffer: Vec::new(),
pcap: None, pcap: None,
} }
} }
@@ -158,26 +159,25 @@ impl Adapter {
/// # Errors /// # Errors
/// * Returns `InvalidData` if the SYN-ACK response is invalid /// * Returns `InvalidData` if the SYN-ACK response is invalid
/// * Returns other IO errors if underlying transport fails /// * Returns other IO errors if underlying transport fails
pub async fn connect(&mut self, port: u16) -> Result<(), std::io::Error> { pub(crate) async fn connect(&mut self, port: u16) -> Result<u16, std::io::Error> {
self.read_buffer = Vec::new(); let host_port = loop {
self.write_buffer = Vec::new(); let host_port: u16 = rand::random();
if self.states.contains_key(&host_port) {
// Randomize seq continue;
self.seq = rand::random(); } else {
self.ack = 0; break host_port;
}
// Choose a random port };
self.host_port = rand::random(); let state = ConnectionState::new(host_port, port);
self.peer_port = port;
// Create the TCP packet // Create the TCP packet
let tcp_packet = TcpPacket::create( let tcp_packet = TcpPacket::create(
self.host_ip, self.host_ip,
self.peer_ip, self.peer_ip,
self.host_port, state.host_port,
self.peer_port, state.peer_port,
self.seq, state.seq,
self.ack, state.ack,
TcpFlags { TcpFlags {
syn: true, syn: true,
..Default::default() ..Default::default()
@@ -187,24 +187,28 @@ impl Adapter {
); );
let ip_packet = self.ip_wrap(&tcp_packet); let ip_packet = self.ip_wrap(&tcp_packet);
self.peer.write_all(&ip_packet).await?; self.peer.write_all(&ip_packet).await?;
self.log_packet(&ip_packet).await?; self.log_packet(&ip_packet)?;
// Wait for the syn ack // Wait for the syn ack
let res = self.read_tcp_packet().await?; self.states.insert(host_port, state);
if !(res.flags.syn && res.flags.ack) { loop {
log::error!("Didn't get syn ack: {res:#?}, {self:#?}"); self.process_tcp_packet().await?;
return Err(std::io::Error::new( if let Some(s) = self.states.get(&host_port) {
std::io::ErrorKind::InvalidData, match s.status {
"No syn ack", ConnectionStatus::Connected => {
)); break;
}
ConnectionStatus::Error(e) => {
return Err(std::io::Error::new(e, "failed to connect"))
}
ConnectionStatus::WaitingForSyn => {
continue;
}
}
}
} }
self.seq = self.seq.wrapping_add(1);
// Ack back Ok(host_port)
self.ack().await?;
self.state = AdapterState::Connected;
Ok(())
} }
/// Enables packet capture to a PCAP file. /// Enables packet capture to a PCAP file.
@@ -232,9 +236,9 @@ impl Adapter {
Ok(()) Ok(())
} }
async fn log_packet(&mut self, packet: &[u8]) -> Result<(), std::io::Error> { fn log_packet(&self, packet: &[u8]) -> Result<(), std::io::Error> {
if let Some(file) = &self.pcap { if let Some(file) = &self.pcap {
super::log_packet(file, packet).await; super::log_packet(file, packet);
} }
Ok(()) Ok(())
} }
@@ -247,14 +251,15 @@ impl Adapter {
/// ///
/// # Errors /// # Errors
/// * Returns IO errors if underlying transport fails during close /// * Returns IO errors if underlying transport fails during close
pub async fn close(&mut self) -> Result<(), std::io::Error> { pub(crate) async fn close(&mut self, host_port: u16) -> Result<(), std::io::Error> {
if let Some(state) = self.states.remove(&host_port) {
let tcp_packet = TcpPacket::create( let tcp_packet = TcpPacket::create(
self.host_ip, self.host_ip,
self.peer_ip, self.peer_ip,
self.host_port, state.host_port,
self.peer_port, state.peer_port,
self.seq, state.seq,
self.ack, state.ack,
TcpFlags { TcpFlags {
fin: true, fin: true,
ack: true, ack: true,
@@ -265,31 +270,26 @@ impl Adapter {
); );
let ip_packet = self.ip_wrap(&tcp_packet); let ip_packet = self.ip_wrap(&tcp_packet);
self.peer.write_all(&ip_packet).await?; self.peer.write_all(&ip_packet).await?;
self.log_packet(&ip_packet).await?; self.log_packet(&ip_packet)?;
loop {
let res = self.read_tcp_packet().await?;
if res.flags.psh || !res.payload.is_empty() {
self.ack().await?;
continue;
}
if res.flags.ack || res.flags.fin || res.flags.rst {
break;
}
}
self.state = AdapterState::None;
Ok(()) Ok(())
} else {
Err(std::io::Error::new(
ErrorKind::NotConnected,
"not connected",
))
}
} }
async fn ack(&mut self) -> Result<(), std::io::Error> { async fn ack(&mut self, host_port: u16) -> Result<(), std::io::Error> {
if let Some(state) = self.states.get_mut(&host_port) {
let tcp_packet = TcpPacket::create( let tcp_packet = TcpPacket::create(
self.host_ip, self.host_ip,
self.peer_ip, self.peer_ip,
self.host_port, state.host_port,
self.peer_port, state.peer_port,
self.seq, state.seq,
self.ack, state.ack,
TcpFlags { TcpFlags {
ack: true, ack: true,
..Default::default() ..Default::default()
@@ -299,9 +299,15 @@ impl Adapter {
); );
let ip_packet = self.ip_wrap(&tcp_packet); let ip_packet = self.ip_wrap(&tcp_packet);
self.peer.write_all(&ip_packet).await?; self.peer.write_all(&ip_packet).await?;
self.log_packet(&ip_packet).await?; self.log_packet(&ip_packet)?;
Ok(()) Ok(())
} else {
Err(std::io::Error::new(
ErrorKind::NotConnected,
"not connected",
))
}
} }
/// Sends a TCP packet with PSH flag set (pushing data). /// Sends a TCP packet with PSH flag set (pushing data).
@@ -315,15 +321,20 @@ impl Adapter {
/// ///
/// # Errors /// # Errors
/// * Returns IO errors if underlying transport fails /// * Returns IO errors if underlying transport fails
pub async fn psh(&mut self, data: &[u8]) -> Result<(), std::io::Error> { async fn psh(&mut self, data: &[u8], host_port: u16) -> Result<(), std::io::Error> {
let data_len = if let Some(state) = self.states.get(&host_port) {
// Check to make sure we haven't closed since last operation
if let ConnectionStatus::Error(e) = state.status {
return Err(std::io::Error::new(e, "socket error"));
}
trace!("pshing {} bytes", data.len()); trace!("pshing {} bytes", data.len());
let tcp_packet = TcpPacket::create( let tcp_packet = TcpPacket::create(
self.host_ip, self.host_ip,
self.peer_ip, self.peer_ip,
self.host_port, state.host_port,
self.peer_port, state.peer_port,
self.seq, state.seq,
self.ack, state.ack,
TcpFlags { TcpFlags {
psh: true, psh: true,
ack: true, ack: true,
@@ -334,25 +345,109 @@ impl Adapter {
); );
let ip_packet = self.ip_wrap(&tcp_packet); let ip_packet = self.ip_wrap(&tcp_packet);
self.peer.write_all(&ip_packet).await?; self.peer.write_all(&ip_packet).await?;
self.log_packet(&ip_packet).await?; self.log_packet(&ip_packet)?;
data.len() as u32
} else {
return Err(std::io::Error::new(
ErrorKind::NotConnected,
"not connected",
));
};
self.seq = self.seq.wrapping_add(data.len() as u32); // We have to re-borrow, since we're mutating state
if let Some(state) = self.states.get_mut(&host_port) {
state.seq = state.seq.wrapping_add(data_len);
}
Ok(()) Ok(())
} }
/// Flushes the packets /// Flushes the packets
async fn write_buffer_flush(&mut self) -> Result<(), std::io::Error> { pub(crate) async fn write_buffer_flush(&mut self) -> Result<(), std::io::Error> {
if self.write_buffer.is_empty() { for (_, state) in self.states.clone() {
return Ok(()); let writer_buffer = state.write_buffer.clone();
if writer_buffer.is_empty() {
continue;
}
println!("flushing...");
self.psh(&writer_buffer, state.host_port).await.ok(); // don't care
println!("flushed {} bytes", writer_buffer.len());
// we have to borrow mutably after self.psh
if let Some(state) = self.states.get_mut(&state.host_port) {
state.write_buffer.clear();
}
} }
trace!("Flushing {} bytes", self.write_buffer.len());
let write_buffer = self.write_buffer.clone();
self.psh(&write_buffer).await?;
self.write_buffer = Vec::new();
Ok(()) Ok(())
} }
pub(crate) fn queue_send(
&mut self,
payload: &[u8],
host_port: u16,
) -> Result<(), std::io::Error> {
if let Some(state) = self.states.get_mut(&host_port) {
state.write_buffer.extend_from_slice(payload);
} else {
return Err(std::io::Error::new(
ErrorKind::NotConnected,
"not connected",
));
}
Ok(())
}
pub(crate) fn uncache(
&mut self,
to_copy: usize,
host_port: u16,
) -> Result<Vec<u8>, std::io::Error> {
if let Some(state) = self.states.get_mut(&host_port) {
let to_copy = if to_copy > state.read_buffer.len() {
state.read_buffer.len()
} else {
to_copy
};
let res = state.read_buffer[..to_copy].to_vec();
state.read_buffer = state.read_buffer[to_copy..].to_vec();
Ok(res)
} else {
Err(std::io::Error::new(
ErrorKind::NotConnected,
"not connected",
))
}
}
pub(crate) fn cache_read(
&mut self,
payload: &[u8],
host_port: u16,
) -> Result<(), std::io::Error> {
if let Some(state) = self.states.get_mut(&host_port) {
state.read_buffer.extend_from_slice(payload);
Ok(())
} else {
Err(std::io::Error::new(
ErrorKind::NotConnected,
"not connected",
))
}
}
pub(crate) fn get_status(&self, host_port: u16) -> Result<ConnectionStatus, std::io::Error> {
if let Some(state) = self.states.get(&host_port) {
Ok(state.status.clone())
} else {
Err(std::io::Error::new(
ErrorKind::NotConnected,
"not connected",
))
}
}
/// Receives data from the connection. /// Receives data from the connection.
/// ///
/// # Returns /// # Returns
@@ -362,31 +457,26 @@ impl Adapter {
/// # Errors /// # Errors
/// * Returns `ConnectionReset` if connection was reset or closed /// * Returns `ConnectionReset` if connection was reset or closed
/// * Returns other IO errors if underlying transport fails /// * Returns other IO errors if underlying transport fails
pub async fn recv(&mut self) -> Result<Vec<u8>, std::io::Error> { pub(crate) async fn recv(&mut self, host_port: u16) -> Result<Vec<u8>, std::io::Error> {
loop { loop {
let res = self.read_tcp_packet().await?; // Check to see if we already have some cached
if res.destination_port != self.host_port || res.source_port != self.peer_port { if let Some(state) = self.states.get_mut(&host_port) {
continue; if !state.read_buffer.is_empty() {
let res = state.read_buffer.clone();
state.read_buffer = Vec::new();
return Ok(res);
} }
if res.flags.psh || !res.payload.is_empty() { if let ConnectionStatus::Error(e) = state.status {
self.ack().await?; return Err(std::io::Error::new(e, "socket io error"));
break Ok(res.payload);
} }
if res.flags.rst { } else {
self.state = AdapterState::None; return Err(std::io::Error::new(
break Err(std::io::Error::new( ErrorKind::NotConnected,
std::io::ErrorKind::ConnectionReset, "not connected",
"Connection reset",
));
}
if res.flags.fin {
self.ack().await?;
self.state = AdapterState::None;
break Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"Connection reset",
)); ));
} }
self.process_tcp_packet().await?;
} }
} }
@@ -413,24 +503,40 @@ impl Adapter {
}) })
} }
async fn read_tcp_packet(&mut self) -> Result<TcpPacket, std::io::Error> { async fn process_tcp_packet(&mut self) -> Result<(), std::io::Error> {
loop {
let ip_packet = self.read_ip_packet().await?; let ip_packet = self.read_ip_packet().await?;
let tcp_packet = TcpPacket::parse(&ip_packet)?; let res = TcpPacket::parse(&ip_packet)?;
if tcp_packet.destination_port != self.host_port let mut ack_me = None;
|| tcp_packet.source_port != self.peer_port if let Some(state) = self.states.get_mut(&res.destination_port) {
{ state.ack = res.sequence_number
continue; + if res.payload.is_empty() {
}
trace!("TCP packet: {tcp_packet:#?}");
self.ack = tcp_packet.sequence_number
+ if tcp_packet.payload.is_empty() {
1 1
} else { } else {
tcp_packet.payload.len() as u32 res.payload.len() as u32
}; };
break Ok(tcp_packet); if res.flags.psh || !res.payload.is_empty() {
ack_me = Some(res.destination_port);
state.read_buffer.extend(res.payload)
} }
if res.flags.rst {
state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset);
}
if res.flags.fin {
ack_me = Some(res.destination_port);
state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset);
}
if res.flags.syn && res.flags.ack {
ack_me = Some(res.destination_port);
state.seq = state.seq.wrapping_add(1);
state.status = ConnectionStatus::Connected;
}
}
// we have to ack outside of the mutable state borrow
if let Some(a) = ack_me {
self.ack(a).await?;
}
Ok(())
} }
fn ip_wrap(&self, packet: &[u8]) -> Vec<u8> { fn ip_wrap(&self, packet: &[u8]) -> Vec<u8> {
@@ -454,125 +560,3 @@ impl Adapter {
} }
} }
} }
impl AsyncRead for Adapter {
/// Attempts to read from the connection into the provided buffer.
///
/// Uses an internal read buffer to cache any extra received data.
///
/// # Returns
/// * `Poll::Ready(Ok(()))` if data was read successfully
/// * `Poll::Ready(Err(e))` if an error occurred
/// * `Poll::Pending` if operation would block
///
/// # Errors
/// * Returns `NotConnected` if adapter isn't connected
/// * Propagates any underlying transport errors
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
// First, check if we have any cached data
if !self.read_buffer.is_empty() {
let to_copy = std::cmp::min(buf.remaining(), self.read_buffer.len());
buf.put_slice(&self.read_buffer[..to_copy]);
// Keep any remaining data in the buffer
if to_copy < self.read_buffer.len() {
self.read_buffer = self.read_buffer[to_copy..].to_vec();
} else {
self.read_buffer.clear();
}
return std::task::Poll::Ready(Ok(()));
}
// If no cached data and not connected, return error
if self.state != AdapterState::Connected {
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"Adapter not connected",
)));
}
// If no cached data, try to receive new data
let future = async {
match self.recv().await {
Ok(data) => {
let len = std::cmp::min(buf.remaining(), data.len());
buf.put_slice(&data[..len]);
// If we received more data than needed, cache the rest
if len < data.len() {
self.read_buffer = data[len..].to_vec();
}
Ok(())
}
Err(e) => Err(e),
}
};
// Pin the future and poll it
futures::pin_mut!(future);
future.poll(cx)
}
}
impl AsyncWrite for Adapter {
/// Attempts to write data to the connection.
///
/// Data is buffered internally until flushed.
///
/// # Returns
/// * `Poll::Ready(Ok(n))` with number of bytes written
/// * `Poll::Ready(Err(e))` if an error occurred
/// * `Poll::Pending` if operation would block
///
/// # Errors
/// * Returns `NotConnected` if adapter isn't connected
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
trace!("poll psh {}", buf.len());
if self.state != AdapterState::Connected {
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"Adapter not connected",
)));
}
self.write_buffer.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let future = async {
match self.write_buffer_flush().await {
Ok(_) => Ok(()),
Err(e) => Err(e),
}
};
// Pin the future and poll it
futures::pin_mut!(future);
future.poll(cx)
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
// Create a future that can be polled
let future = async { self.close().await };
// Pin the future and poll it
futures::pin_mut!(future);
future.poll(cx)
}
}

View File

@@ -10,8 +10,9 @@ use tokio::io::AsyncWriteExt;
pub mod adapter; pub mod adapter;
pub mod packets; pub mod packets;
pub mod stream;
pub(crate) async fn log_packet(file: &Arc<tokio::sync::Mutex<tokio::fs::File>>, packet: &[u8]) { pub(crate) fn log_packet(file: &Arc<tokio::sync::Mutex<tokio::fs::File>>, packet: &[u8]) {
debug!("Logging {} byte packet", packet.len()); debug!("Logging {} byte packet", packet.len());
let packet = packet.to_vec(); let packet = packet.to_vec();
let file = file.to_owned(); let file = file.to_owned();
@@ -49,6 +50,7 @@ mod tests {
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use stream::AdapterStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tun_rs::DeviceBuilder; use tun_rs::DeviceBuilder;
@@ -186,31 +188,35 @@ mod tests {
let mut buf = Vec::new(); let mut buf = Vec::new();
let _ = tokio::io::stdin().read(&mut buf).await.unwrap(); let _ = tokio::io::stdin().read(&mut buf).await.unwrap();
if let Err(e) = adapter.connect(SERVER_PORT).await { let mut stream = match AdapterStream::connect(&mut adapter, SERVER_PORT).await {
Ok(s) => s,
Err(e) => {
println!("no connect: {e:?}"); println!("no connect: {e:?}");
return;
} }
};
if let Err(e) = adapter.write_all(&[1, 2, 3, 4, 5]).await { if let Err(e) = stream.write_all(&[1, 2, 3, 4, 5]).await {
println!("no send: {e:?}"); println!("no send: {e:?}");
} else { } else {
let mut buf = [0u8; 4]; let mut buf = [0u8; 4];
match adapter.read_exact(&mut buf).await { match stream.read_exact(&mut buf).await {
Ok(_) => println!("recv'd {buf:?}"), Ok(_) => println!("recv'd {buf:?}"),
Err(e) => println!("no recv: {e:?}"), Err(e) => println!("no recv: {e:?}"),
} }
} }
if let Err(e) = adapter.write_all(&[69, 69, 42, 0, 1]).await { if let Err(e) = stream.write_all(&[69, 69, 42, 0, 1]).await {
println!("no send: {e:?}"); println!("no send: {e:?}");
} else { } else {
let mut buf = [0u8; 6]; let mut buf = [0u8; 6];
match adapter.read_exact(&mut buf).await { match stream.read_exact(&mut buf).await {
Ok(_) => println!("recv'd {buf:?}"), Ok(_) => println!("recv'd {buf:?}"),
Err(e) => println!("no recv: {e:?}"), Err(e) => println!("no recv: {e:?}"),
} }
} }
if let Err(e) = adapter.close().await { if let Err(e) = stream.close().await {
println!("no close: {e:?}"); println!("no close: {e:?}");
} }

View File

@@ -143,7 +143,7 @@ impl Ipv4Packet {
reader.read_exact(&mut payload).await?; reader.read_exact(&mut payload).await?;
if let Some(log) = log { if let Some(log) = log {
log_packet.extend_from_slice(&payload); log_packet.extend_from_slice(&payload);
super::log_packet(log, &log_packet).await; super::log_packet(log, &log_packet);
} }
Ok(Self { Ok(Self {
@@ -324,7 +324,7 @@ impl Ipv6Packet {
reader.read_exact(&mut payload).await?; reader.read_exact(&mut payload).await?;
if let Some(log) = log { if let Some(log) = log {
log_packet.extend_from_slice(&payload); log_packet.extend_from_slice(&payload);
super::log_packet(log, &log_packet).await; super::log_packet(log, &log_packet);
} }
Ok(Self { Ok(Self {

158
idevice/src/tcp/stream.rs Normal file
View File

@@ -0,0 +1,158 @@
// Jackson Coxson
use std::{future::Future, task::Poll};
use log::trace;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::tcp::adapter::ConnectionStatus;
use super::adapter::Adapter;
#[derive(Debug)]
pub struct AdapterStream<'a> {
pub(crate) adapter: &'a mut Adapter,
pub host_port: u16,
pub peer_port: u16,
}
impl<'a> AdapterStream<'a> {
pub async fn connect(adapter: &'a mut Adapter, port: u16) -> Result<Self, std::io::Error> {
let host_port = adapter.connect(port).await?;
Ok(Self {
adapter,
host_port,
peer_port: port,
})
}
pub async fn close(&mut self) -> Result<(), std::io::Error> {
self.adapter.close(self.host_port).await
}
}
impl AsyncRead for AdapterStream<'_> {
/// Attempts to read from the connection into the provided buffer.
///
/// Uses an internal read buffer to cache any extra received data.
///
/// # Returns
/// * `Poll::Ready(Ok(()))` if data was read successfully
/// * `Poll::Ready(Err(e))` if an error occurred
/// * `Poll::Pending` if operation would block
///
/// # Errors
/// * Returns `NotConnected` if adapter isn't connected
/// * Propagates any underlying transport errors
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.adapter.get_status(self.host_port) {
Ok(ConnectionStatus::Error(e)) => {
return std::task::Poll::Ready(Err(std::io::Error::new(e, "io error")));
}
Err(e) => {
return std::task::Poll::Ready(Err(e));
}
_ => {}
}
// First, check if we have any cached data
let p = self.host_port;
let cache = match self.adapter.uncache(buf.remaining(), p) {
Ok(c) => c,
Err(e) => return std::task::Poll::Ready(Err(e)),
};
if !cache.is_empty() {
buf.put_slice(&cache);
return std::task::Poll::Ready(Ok(()));
}
// If no cached data, try to receive new data
let future = async {
match self.adapter.recv(p).await {
Ok(data) => {
let len = std::cmp::min(buf.remaining(), data.len());
buf.put_slice(&data[..len]);
// If we received more data than needed, cache the rest
if len < data.len() {
self.adapter.cache_read(&data[len..], p)?
}
Ok(())
}
Err(e) => Err(e),
}
};
// Pin the future and poll it
futures::pin_mut!(future);
future.poll(cx)
}
}
impl AsyncWrite for AdapterStream<'_> {
/// Attempts to write data to the connection.
///
/// Data is buffered internally until flushed.
///
/// # Returns
/// * `Poll::Ready(Ok(n))` with number of bytes written
/// * `Poll::Ready(Err(e))` if an error occurred
/// * `Poll::Pending` if operation would block
///
/// # Errors
/// * Returns `NotConnected` if adapter isn't connected
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
trace!("poll psh {}", buf.len());
match self.adapter.get_status(self.host_port) {
Ok(ConnectionStatus::Error(e)) => {
return std::task::Poll::Ready(Err(std::io::Error::new(e, "io error")));
}
Err(e) => {
return std::task::Poll::Ready(Err(e));
}
_ => {}
}
let p = self.host_port;
match self.adapter.queue_send(buf, p) {
Ok(_) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let future = async {
match self.adapter.write_buffer_flush().await {
Ok(_) => Ok(()),
Err(e) => Err(e),
}
};
// Pin the future and poll it
futures::pin_mut!(future);
future.poll(cx)
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
// Create a future that can be polled
let future = async { self.close().await };
// Pin the future and poll it
futures::pin_mut!(future);
future.poll(cx)
}
}

View File

@@ -2,7 +2,10 @@
// Print out all the RemoteXPC services // Print out all the RemoteXPC services
use clap::{Arg, Command}; use clap::{Arg, Command};
use idevice::{core_device_proxy::CoreDeviceProxy, xpc::RemoteXpcClient, IdeviceService}; use idevice::{
core_device_proxy::CoreDeviceProxy, tcp::stream::AdapterStream, xpc::RemoteXpcClient,
IdeviceService,
};
mod common; mod common;
@@ -63,10 +66,13 @@ async fn main() {
let rsd_port = proxy.handshake.server_rsd_port; let rsd_port = proxy.handshake.server_rsd_port;
let mut adapter = proxy.create_software_tunnel().expect("no software tunnel"); let mut adapter = proxy.create_software_tunnel().expect("no software tunnel");
adapter.connect(rsd_port).await.expect("no RSD connect"); adapter.pcap("new_xpc.pcap").await.unwrap();
let conn = AdapterStream::connect(&mut adapter, rsd_port)
.await
.expect("no RSD connect");
// Make the connection to RemoteXPC // Make the connection to RemoteXPC
let mut client = RemoteXpcClient::new(Box::new(adapter)).await.unwrap(); let mut client = RemoteXpcClient::new(Box::new(conn)).await.unwrap();
println!("{:#?}", client.do_handshake().await); println!("{:#?}", client.do_handshake().await);
} }