diff --git a/idevice/Cargo.toml b/idevice/Cargo.toml index a550591..344a64f 100644 --- a/idevice/Cargo.toml +++ b/idevice/Cargo.toml @@ -50,7 +50,7 @@ x509-cert = { version = "0.2", optional = true, features = [ ], default-features = false } [dev-dependencies] -tokio = { version = "1.43", features = ["fs"] } +tokio = { version = "1.43", features = ["full"] } tun-rs = { version = "2.0.8", features = ["async_tokio"] } bytes = "1.10.1" diff --git a/idevice/src/lib.rs b/idevice/src/lib.rs index d6a54dd..b7ead1d 100644 --- a/idevice/src/lib.rs +++ b/idevice/src/lib.rs @@ -20,6 +20,8 @@ pub mod xpc; pub mod services; pub use services::*; + +#[cfg(feature = "xpc")] pub use xpc::RemoteXpcClient; use log::{debug, error, trace}; diff --git a/idevice/src/tcp/adapter.rs b/idevice/src/tcp/adapter.rs index 4b0eb8d..5222c5f 100644 --- a/idevice/src/tcp/adapter.rs +++ b/idevice/src/tcp/adapter.rs @@ -61,22 +61,45 @@ //! This implementation makes significant simplifications and should not be used //! in production environments or with unreliable network transports. -use std::{future::Future, net::IpAddr, path::Path, sync::Arc, task::Poll}; +use std::{collections::HashMap, io::ErrorKind, net::IpAddr, path::Path, sync::Arc}; use log::trace; -use tokio::{ - io::{AsyncRead, AsyncWrite, AsyncWriteExt}, - sync::Mutex, -}; +use tokio::{io::AsyncWriteExt, sync::Mutex}; use crate::ReadWrite; use super::packets::{Ipv4Packet, Ipv6Packet, ProtocolNumber, TcpFlags, TcpPacket}; -#[derive(Clone, Debug, PartialEq)] -enum AdapterState { +#[derive(Debug, Clone)] +struct ConnectionState { + seq: u32, + ack: u32, + host_port: u16, + peer_port: u16, + read_buffer: Vec, + write_buffer: Vec, + status: ConnectionStatus, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) enum ConnectionStatus { + WaitingForSyn, Connected, - None, + Error(ErrorKind), +} + +impl ConnectionState { + fn new(host_port: u16, peer_port: u16) -> Self { + Self { + seq: rand::random(), + ack: 0, + host_port, + peer_port, + read_buffer: Vec::new(), + write_buffer: Vec::new(), + status: ConnectionStatus::WaitingForSyn, + } + } } /// A simplified TCP network stack implementation. @@ -96,24 +119,8 @@ pub struct Adapter { host_ip: IpAddr, /// The remote peer's IP address peer_ip: IpAddr, - /// Current connection state - state: AdapterState, - // TCP state - /// Current sequence number - seq: u32, - /// Current acknowledgement number - ack: u32, - /// Local port number - host_port: u16, - /// Remote port number - peer_port: u16, - - // Read buffer to cache unused bytes - /// Buffer for storing unread received data - read_buffer: Vec, - /// Buffer for storing data to be sent - write_buffer: Vec, + states: HashMap, // host port by state // Logging /// Optional PCAP file for packet logging @@ -135,13 +142,7 @@ impl Adapter { peer, host_ip, peer_ip, - state: AdapterState::None, - seq: 0, - ack: 0, - host_port: 1024, - peer_port: 1024, - read_buffer: Vec::new(), - write_buffer: Vec::new(), + states: HashMap::new(), pcap: None, } } @@ -158,26 +159,25 @@ impl Adapter { /// # Errors /// * Returns `InvalidData` if the SYN-ACK response is invalid /// * Returns other IO errors if underlying transport fails - pub async fn connect(&mut self, port: u16) -> Result<(), std::io::Error> { - self.read_buffer = Vec::new(); - self.write_buffer = Vec::new(); - - // Randomize seq - self.seq = rand::random(); - self.ack = 0; - - // Choose a random port - self.host_port = rand::random(); - self.peer_port = port; + pub(crate) async fn connect(&mut self, port: u16) -> Result { + let host_port = loop { + let host_port: u16 = rand::random(); + if self.states.contains_key(&host_port) { + continue; + } else { + break host_port; + } + }; + let state = ConnectionState::new(host_port, port); // Create the TCP packet let tcp_packet = TcpPacket::create( self.host_ip, self.peer_ip, - self.host_port, - self.peer_port, - self.seq, - self.ack, + state.host_port, + state.peer_port, + state.seq, + state.ack, TcpFlags { syn: true, ..Default::default() @@ -187,24 +187,28 @@ impl Adapter { ); let ip_packet = self.ip_wrap(&tcp_packet); self.peer.write_all(&ip_packet).await?; - self.log_packet(&ip_packet).await?; + self.log_packet(&ip_packet)?; // Wait for the syn ack - let res = self.read_tcp_packet().await?; - if !(res.flags.syn && res.flags.ack) { - log::error!("Didn't get syn ack: {res:#?}, {self:#?}"); - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "No syn ack", - )); + self.states.insert(host_port, state); + loop { + self.process_tcp_packet().await?; + if let Some(s) = self.states.get(&host_port) { + match s.status { + ConnectionStatus::Connected => { + break; + } + ConnectionStatus::Error(e) => { + return Err(std::io::Error::new(e, "failed to connect")) + } + ConnectionStatus::WaitingForSyn => { + continue; + } + } + } } - self.seq = self.seq.wrapping_add(1); - // Ack back - self.ack().await?; - - self.state = AdapterState::Connected; - Ok(()) + Ok(host_port) } /// Enables packet capture to a PCAP file. @@ -232,9 +236,9 @@ impl Adapter { Ok(()) } - async fn log_packet(&mut self, packet: &[u8]) -> Result<(), std::io::Error> { + fn log_packet(&self, packet: &[u8]) -> Result<(), std::io::Error> { if let Some(file) = &self.pcap { - super::log_packet(file, packet).await; + super::log_packet(file, packet); } Ok(()) } @@ -247,61 +251,63 @@ impl Adapter { /// /// # Errors /// * Returns IO errors if underlying transport fails during close - pub async fn close(&mut self) -> Result<(), std::io::Error> { - let tcp_packet = TcpPacket::create( - self.host_ip, - self.peer_ip, - self.host_port, - self.peer_port, - self.seq, - self.ack, - TcpFlags { - fin: true, - ack: true, - ..Default::default() - }, - u16::MAX - 1, - &[], - ); - let ip_packet = self.ip_wrap(&tcp_packet); - self.peer.write_all(&ip_packet).await?; - self.log_packet(&ip_packet).await?; + pub(crate) async fn close(&mut self, host_port: u16) -> Result<(), std::io::Error> { + if let Some(state) = self.states.remove(&host_port) { + let tcp_packet = TcpPacket::create( + self.host_ip, + self.peer_ip, + state.host_port, + state.peer_port, + state.seq, + state.ack, + TcpFlags { + fin: true, + ack: true, + ..Default::default() + }, + u16::MAX - 1, + &[], + ); + let ip_packet = self.ip_wrap(&tcp_packet); + self.peer.write_all(&ip_packet).await?; + self.log_packet(&ip_packet)?; - loop { - let res = self.read_tcp_packet().await?; - if res.flags.psh || !res.payload.is_empty() { - self.ack().await?; - continue; - } - - if res.flags.ack || res.flags.fin || res.flags.rst { - break; - } + Ok(()) + } else { + Err(std::io::Error::new( + ErrorKind::NotConnected, + "not connected", + )) } - self.state = AdapterState::None; - Ok(()) } - async fn ack(&mut self) -> Result<(), std::io::Error> { - let tcp_packet = TcpPacket::create( - self.host_ip, - self.peer_ip, - self.host_port, - self.peer_port, - self.seq, - self.ack, - TcpFlags { - ack: true, - ..Default::default() - }, - u16::MAX - 1, - &[], - ); - let ip_packet = self.ip_wrap(&tcp_packet); - self.peer.write_all(&ip_packet).await?; - self.log_packet(&ip_packet).await?; + async fn ack(&mut self, host_port: u16) -> Result<(), std::io::Error> { + if let Some(state) = self.states.get_mut(&host_port) { + let tcp_packet = TcpPacket::create( + self.host_ip, + self.peer_ip, + state.host_port, + state.peer_port, + state.seq, + state.ack, + TcpFlags { + ack: true, + ..Default::default() + }, + u16::MAX - 1, + &[], + ); + let ip_packet = self.ip_wrap(&tcp_packet); + self.peer.write_all(&ip_packet).await?; + self.log_packet(&ip_packet)?; - Ok(()) + Ok(()) + } else { + Err(std::io::Error::new( + ErrorKind::NotConnected, + "not connected", + )) + } } /// Sends a TCP packet with PSH flag set (pushing data). @@ -315,44 +321,133 @@ impl Adapter { /// /// # Errors /// * Returns IO errors if underlying transport fails - pub async fn psh(&mut self, data: &[u8]) -> Result<(), std::io::Error> { - trace!("pshing {} bytes", data.len()); - let tcp_packet = TcpPacket::create( - self.host_ip, - self.peer_ip, - self.host_port, - self.peer_port, - self.seq, - self.ack, - TcpFlags { - psh: true, - ack: true, - ..Default::default() - }, - u16::MAX - 1, - data, - ); - let ip_packet = self.ip_wrap(&tcp_packet); - self.peer.write_all(&ip_packet).await?; - self.log_packet(&ip_packet).await?; + async fn psh(&mut self, data: &[u8], host_port: u16) -> Result<(), std::io::Error> { + let data_len = if let Some(state) = self.states.get(&host_port) { + // Check to make sure we haven't closed since last operation + if let ConnectionStatus::Error(e) = state.status { + return Err(std::io::Error::new(e, "socket error")); + } + trace!("pshing {} bytes", data.len()); + let tcp_packet = TcpPacket::create( + self.host_ip, + self.peer_ip, + state.host_port, + state.peer_port, + state.seq, + state.ack, + TcpFlags { + psh: true, + ack: true, + ..Default::default() + }, + u16::MAX - 1, + data, + ); + let ip_packet = self.ip_wrap(&tcp_packet); + self.peer.write_all(&ip_packet).await?; + self.log_packet(&ip_packet)?; + data.len() as u32 + } else { + return Err(std::io::Error::new( + ErrorKind::NotConnected, + "not connected", + )); + }; - self.seq = self.seq.wrapping_add(data.len() as u32); + // We have to re-borrow, since we're mutating state + if let Some(state) = self.states.get_mut(&host_port) { + state.seq = state.seq.wrapping_add(data_len); + } Ok(()) } /// Flushes the packets - async fn write_buffer_flush(&mut self) -> Result<(), std::io::Error> { - if self.write_buffer.is_empty() { - return Ok(()); + pub(crate) async fn write_buffer_flush(&mut self) -> Result<(), std::io::Error> { + for (_, state) in self.states.clone() { + let writer_buffer = state.write_buffer.clone(); + if writer_buffer.is_empty() { + continue; + } + + println!("flushing..."); + self.psh(&writer_buffer, state.host_port).await.ok(); // don't care + println!("flushed {} bytes", writer_buffer.len()); + + // we have to borrow mutably after self.psh + if let Some(state) = self.states.get_mut(&state.host_port) { + state.write_buffer.clear(); + } } - trace!("Flushing {} bytes", self.write_buffer.len()); - let write_buffer = self.write_buffer.clone(); - self.psh(&write_buffer).await?; - self.write_buffer = Vec::new(); Ok(()) } + pub(crate) fn queue_send( + &mut self, + payload: &[u8], + host_port: u16, + ) -> Result<(), std::io::Error> { + if let Some(state) = self.states.get_mut(&host_port) { + state.write_buffer.extend_from_slice(payload); + } else { + return Err(std::io::Error::new( + ErrorKind::NotConnected, + "not connected", + )); + } + Ok(()) + } + + pub(crate) fn uncache( + &mut self, + to_copy: usize, + host_port: u16, + ) -> Result, std::io::Error> { + if let Some(state) = self.states.get_mut(&host_port) { + let to_copy = if to_copy > state.read_buffer.len() { + state.read_buffer.len() + } else { + to_copy + }; + + let res = state.read_buffer[..to_copy].to_vec(); + state.read_buffer = state.read_buffer[to_copy..].to_vec(); + Ok(res) + } else { + Err(std::io::Error::new( + ErrorKind::NotConnected, + "not connected", + )) + } + } + + pub(crate) fn cache_read( + &mut self, + payload: &[u8], + host_port: u16, + ) -> Result<(), std::io::Error> { + if let Some(state) = self.states.get_mut(&host_port) { + state.read_buffer.extend_from_slice(payload); + Ok(()) + } else { + Err(std::io::Error::new( + ErrorKind::NotConnected, + "not connected", + )) + } + } + + pub(crate) fn get_status(&self, host_port: u16) -> Result { + if let Some(state) = self.states.get(&host_port) { + Ok(state.status.clone()) + } else { + Err(std::io::Error::new( + ErrorKind::NotConnected, + "not connected", + )) + } + } + /// Receives data from the connection. /// /// # Returns @@ -362,31 +457,26 @@ impl Adapter { /// # Errors /// * Returns `ConnectionReset` if connection was reset or closed /// * Returns other IO errors if underlying transport fails - pub async fn recv(&mut self) -> Result, std::io::Error> { + pub(crate) async fn recv(&mut self, host_port: u16) -> Result, std::io::Error> { loop { - let res = self.read_tcp_packet().await?; - if res.destination_port != self.host_port || res.source_port != self.peer_port { - continue; - } - if res.flags.psh || !res.payload.is_empty() { - self.ack().await?; - break Ok(res.payload); - } - if res.flags.rst { - self.state = AdapterState::None; - break Err(std::io::Error::new( - std::io::ErrorKind::ConnectionReset, - "Connection reset", - )); - } - if res.flags.fin { - self.ack().await?; - self.state = AdapterState::None; - break Err(std::io::Error::new( - std::io::ErrorKind::ConnectionReset, - "Connection reset", + // Check to see if we already have some cached + if let Some(state) = self.states.get_mut(&host_port) { + if !state.read_buffer.is_empty() { + let res = state.read_buffer.clone(); + state.read_buffer = Vec::new(); + return Ok(res); + } + if let ConnectionStatus::Error(e) = state.status { + return Err(std::io::Error::new(e, "socket io error")); + } + } else { + return Err(std::io::Error::new( + ErrorKind::NotConnected, + "not connected", )); } + + self.process_tcp_packet().await?; } } @@ -413,24 +503,40 @@ impl Adapter { }) } - async fn read_tcp_packet(&mut self) -> Result { - loop { - let ip_packet = self.read_ip_packet().await?; - let tcp_packet = TcpPacket::parse(&ip_packet)?; - if tcp_packet.destination_port != self.host_port - || tcp_packet.source_port != self.peer_port - { - continue; - } - trace!("TCP packet: {tcp_packet:#?}"); - self.ack = tcp_packet.sequence_number - + if tcp_packet.payload.is_empty() { + async fn process_tcp_packet(&mut self) -> Result<(), std::io::Error> { + let ip_packet = self.read_ip_packet().await?; + let res = TcpPacket::parse(&ip_packet)?; + let mut ack_me = None; + if let Some(state) = self.states.get_mut(&res.destination_port) { + state.ack = res.sequence_number + + if res.payload.is_empty() { 1 } else { - tcp_packet.payload.len() as u32 + res.payload.len() as u32 }; - break Ok(tcp_packet); + if res.flags.psh || !res.payload.is_empty() { + ack_me = Some(res.destination_port); + state.read_buffer.extend(res.payload) + } + if res.flags.rst { + state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset); + } + if res.flags.fin { + ack_me = Some(res.destination_port); + state.status = ConnectionStatus::Error(ErrorKind::ConnectionReset); + } + if res.flags.syn && res.flags.ack { + ack_me = Some(res.destination_port); + state.seq = state.seq.wrapping_add(1); + state.status = ConnectionStatus::Connected; + } } + + // we have to ack outside of the mutable state borrow + if let Some(a) = ack_me { + self.ack(a).await?; + } + Ok(()) } fn ip_wrap(&self, packet: &[u8]) -> Vec { @@ -454,125 +560,3 @@ impl Adapter { } } } - -impl AsyncRead for Adapter { - /// Attempts to read from the connection into the provided buffer. - /// - /// Uses an internal read buffer to cache any extra received data. - /// - /// # Returns - /// * `Poll::Ready(Ok(()))` if data was read successfully - /// * `Poll::Ready(Err(e))` if an error occurred - /// * `Poll::Pending` if operation would block - /// - /// # Errors - /// * Returns `NotConnected` if adapter isn't connected - /// * Propagates any underlying transport errors - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - // First, check if we have any cached data - if !self.read_buffer.is_empty() { - let to_copy = std::cmp::min(buf.remaining(), self.read_buffer.len()); - buf.put_slice(&self.read_buffer[..to_copy]); - - // Keep any remaining data in the buffer - if to_copy < self.read_buffer.len() { - self.read_buffer = self.read_buffer[to_copy..].to_vec(); - } else { - self.read_buffer.clear(); - } - - return std::task::Poll::Ready(Ok(())); - } - - // If no cached data and not connected, return error - if self.state != AdapterState::Connected { - return std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "Adapter not connected", - ))); - } - - // If no cached data, try to receive new data - let future = async { - match self.recv().await { - Ok(data) => { - let len = std::cmp::min(buf.remaining(), data.len()); - buf.put_slice(&data[..len]); - - // If we received more data than needed, cache the rest - if len < data.len() { - self.read_buffer = data[len..].to_vec(); - } - - Ok(()) - } - Err(e) => Err(e), - } - }; - - // Pin the future and poll it - futures::pin_mut!(future); - future.poll(cx) - } -} - -impl AsyncWrite for Adapter { - /// Attempts to write data to the connection. - /// - /// Data is buffered internally until flushed. - /// - /// # Returns - /// * `Poll::Ready(Ok(n))` with number of bytes written - /// * `Poll::Ready(Err(e))` if an error occurred - /// * `Poll::Pending` if operation would block - /// - /// # Errors - /// * Returns `NotConnected` if adapter isn't connected - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - trace!("poll psh {}", buf.len()); - if self.state != AdapterState::Connected { - return std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "Adapter not connected", - ))); - } - self.write_buffer.extend_from_slice(buf); - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let future = async { - match self.write_buffer_flush().await { - Ok(_) => Ok(()), - Err(e) => Err(e), - } - }; - - // Pin the future and poll it - futures::pin_mut!(future); - future.poll(cx) - } - - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - // Create a future that can be polled - let future = async { self.close().await }; - - // Pin the future and poll it - futures::pin_mut!(future); - future.poll(cx) - } -} diff --git a/idevice/src/tcp/mod.rs b/idevice/src/tcp/mod.rs index b1da728..c8fda3e 100644 --- a/idevice/src/tcp/mod.rs +++ b/idevice/src/tcp/mod.rs @@ -10,8 +10,9 @@ use tokio::io::AsyncWriteExt; pub mod adapter; pub mod packets; +pub mod stream; -pub(crate) async fn log_packet(file: &Arc>, packet: &[u8]) { +pub(crate) fn log_packet(file: &Arc>, packet: &[u8]) { debug!("Logging {} byte packet", packet.len()); let packet = packet.to_vec(); let file = file.to_owned(); @@ -49,6 +50,7 @@ mod tests { pin::Pin, task::{Context, Poll}, }; + use stream::AdapterStream; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tun_rs::DeviceBuilder; @@ -186,31 +188,35 @@ mod tests { let mut buf = Vec::new(); let _ = tokio::io::stdin().read(&mut buf).await.unwrap(); - if let Err(e) = adapter.connect(SERVER_PORT).await { - println!("no connect: {e:?}"); - } + let mut stream = match AdapterStream::connect(&mut adapter, SERVER_PORT).await { + Ok(s) => s, + Err(e) => { + println!("no connect: {e:?}"); + return; + } + }; - if let Err(e) = adapter.write_all(&[1, 2, 3, 4, 5]).await { + if let Err(e) = stream.write_all(&[1, 2, 3, 4, 5]).await { println!("no send: {e:?}"); } else { let mut buf = [0u8; 4]; - match adapter.read_exact(&mut buf).await { + match stream.read_exact(&mut buf).await { Ok(_) => println!("recv'd {buf:?}"), Err(e) => println!("no recv: {e:?}"), } } - if let Err(e) = adapter.write_all(&[69, 69, 42, 0, 1]).await { + if let Err(e) = stream.write_all(&[69, 69, 42, 0, 1]).await { println!("no send: {e:?}"); } else { let mut buf = [0u8; 6]; - match adapter.read_exact(&mut buf).await { + match stream.read_exact(&mut buf).await { Ok(_) => println!("recv'd {buf:?}"), Err(e) => println!("no recv: {e:?}"), } } - if let Err(e) = adapter.close().await { + if let Err(e) = stream.close().await { println!("no close: {e:?}"); } diff --git a/idevice/src/tcp/packets.rs b/idevice/src/tcp/packets.rs index 83f64eb..fa9e94f 100644 --- a/idevice/src/tcp/packets.rs +++ b/idevice/src/tcp/packets.rs @@ -143,7 +143,7 @@ impl Ipv4Packet { reader.read_exact(&mut payload).await?; if let Some(log) = log { log_packet.extend_from_slice(&payload); - super::log_packet(log, &log_packet).await; + super::log_packet(log, &log_packet); } Ok(Self { @@ -324,7 +324,7 @@ impl Ipv6Packet { reader.read_exact(&mut payload).await?; if let Some(log) = log { log_packet.extend_from_slice(&payload); - super::log_packet(log, &log_packet).await; + super::log_packet(log, &log_packet); } Ok(Self { diff --git a/idevice/src/tcp/stream.rs b/idevice/src/tcp/stream.rs new file mode 100644 index 0000000..2ada436 --- /dev/null +++ b/idevice/src/tcp/stream.rs @@ -0,0 +1,158 @@ +// Jackson Coxson + +use std::{future::Future, task::Poll}; + +use log::trace; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::tcp::adapter::ConnectionStatus; + +use super::adapter::Adapter; + +#[derive(Debug)] +pub struct AdapterStream<'a> { + pub(crate) adapter: &'a mut Adapter, + pub host_port: u16, + pub peer_port: u16, +} + +impl<'a> AdapterStream<'a> { + pub async fn connect(adapter: &'a mut Adapter, port: u16) -> Result { + let host_port = adapter.connect(port).await?; + Ok(Self { + adapter, + host_port, + peer_port: port, + }) + } + + pub async fn close(&mut self) -> Result<(), std::io::Error> { + self.adapter.close(self.host_port).await + } +} + +impl AsyncRead for AdapterStream<'_> { + /// Attempts to read from the connection into the provided buffer. + /// + /// Uses an internal read buffer to cache any extra received data. + /// + /// # Returns + /// * `Poll::Ready(Ok(()))` if data was read successfully + /// * `Poll::Ready(Err(e))` if an error occurred + /// * `Poll::Pending` if operation would block + /// + /// # Errors + /// * Returns `NotConnected` if adapter isn't connected + /// * Propagates any underlying transport errors + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.adapter.get_status(self.host_port) { + Ok(ConnectionStatus::Error(e)) => { + return std::task::Poll::Ready(Err(std::io::Error::new(e, "io error"))); + } + Err(e) => { + return std::task::Poll::Ready(Err(e)); + } + _ => {} + } + + // First, check if we have any cached data + let p = self.host_port; + let cache = match self.adapter.uncache(buf.remaining(), p) { + Ok(c) => c, + Err(e) => return std::task::Poll::Ready(Err(e)), + }; + if !cache.is_empty() { + buf.put_slice(&cache); + return std::task::Poll::Ready(Ok(())); + } + + // If no cached data, try to receive new data + let future = async { + match self.adapter.recv(p).await { + Ok(data) => { + let len = std::cmp::min(buf.remaining(), data.len()); + buf.put_slice(&data[..len]); + + // If we received more data than needed, cache the rest + if len < data.len() { + self.adapter.cache_read(&data[len..], p)? + } + + Ok(()) + } + Err(e) => Err(e), + } + }; + + // Pin the future and poll it + futures::pin_mut!(future); + future.poll(cx) + } +} + +impl AsyncWrite for AdapterStream<'_> { + /// Attempts to write data to the connection. + /// + /// Data is buffered internally until flushed. + /// + /// # Returns + /// * `Poll::Ready(Ok(n))` with number of bytes written + /// * `Poll::Ready(Err(e))` if an error occurred + /// * `Poll::Pending` if operation would block + /// + /// # Errors + /// * Returns `NotConnected` if adapter isn't connected + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + trace!("poll psh {}", buf.len()); + match self.adapter.get_status(self.host_port) { + Ok(ConnectionStatus::Error(e)) => { + return std::task::Poll::Ready(Err(std::io::Error::new(e, "io error"))); + } + Err(e) => { + return std::task::Poll::Ready(Err(e)); + } + _ => {} + } + let p = self.host_port; + match self.adapter.queue_send(buf, p) { + Ok(_) => Poll::Ready(Ok(buf.len())), + Err(e) => Poll::Ready(Err(e)), + } + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let future = async { + match self.adapter.write_buffer_flush().await { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + }; + + // Pin the future and poll it + futures::pin_mut!(future); + future.poll(cx) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Create a future that can be polled + let future = async { self.close().await }; + + // Pin the future and poll it + futures::pin_mut!(future); + future.poll(cx) + } +} diff --git a/tools/src/remotexpc.rs b/tools/src/remotexpc.rs index 9aa5c4c..7f79180 100644 --- a/tools/src/remotexpc.rs +++ b/tools/src/remotexpc.rs @@ -2,7 +2,10 @@ // Print out all the RemoteXPC services use clap::{Arg, Command}; -use idevice::{core_device_proxy::CoreDeviceProxy, xpc::RemoteXpcClient, IdeviceService}; +use idevice::{ + core_device_proxy::CoreDeviceProxy, tcp::stream::AdapterStream, xpc::RemoteXpcClient, + IdeviceService, +}; mod common; @@ -63,10 +66,13 @@ async fn main() { let rsd_port = proxy.handshake.server_rsd_port; let mut adapter = proxy.create_software_tunnel().expect("no software tunnel"); - adapter.connect(rsd_port).await.expect("no RSD connect"); + adapter.pcap("new_xpc.pcap").await.unwrap(); + let conn = AdapterStream::connect(&mut adapter, rsd_port) + .await + .expect("no RSD connect"); // Make the connection to RemoteXPC - let mut client = RemoteXpcClient::new(Box::new(adapter)).await.unwrap(); + let mut client = RemoteXpcClient::new(Box::new(conn)).await.unwrap(); println!("{:#?}", client.do_handshake().await); }