diff --git a/src/main.rs b/src/main.rs index ea297cb..ad6af29 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,15 @@ +use std::fs::File; use std::io::Cursor; +use std::io::Read; +use std::io::Write; use std::mem; use std::io; +use std::net::TcpListener; +use std::sync::mpsc; +use bytes::{Buf, BufMut}; use std::os::fd::AsRawFd; -use std::sync::Arc; +use std::thread; use thiserror::Error; -use tokio::fs::File; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpListener; -use tokio::net::tcp::OwnedWriteHalf; -use tokio::sync::Mutex; type DiagResult = Result; @@ -20,7 +21,7 @@ const DIAG_IOCTL_SWITCH_LOGGING: u32 = 7; #[derive(Error, Debug)] enum DiagDeviceError { - #[error("IO error: {0}")] + #[error("IO error {0}")] IO(#[from] io::Error), #[error("Failed to initialize /dev/diag: {0}")] InitializationFailed(String), @@ -63,11 +64,11 @@ fn determine_use_mdm(fd: i32) -> DiagResult { } impl DiagDevice { - pub async fn new() -> DiagResult { + pub fn new() -> DiagResult { let file = File::options() .read(true) .write(true) - .open("/dev/diag").await?; + .open("/dev/diag")?; let fd = file.as_raw_fd(); enable_frame_readwrite(fd, MEMORY_DEVICE_MODE)?; @@ -79,97 +80,105 @@ impl DiagDevice { }) } - pub async fn read_response(&mut self) -> DiagResult>>> { + pub fn try_clone(&self) -> DiagResult { + Ok(DiagDevice { + file: self.file.try_clone()?, + use_mdm: self.use_mdm, + }) + } + + pub fn read_response(&mut self) -> DiagResult>>> { let mut buf = vec![0; BUFFER_LEN]; - let bytes_read = self.file.read(&mut buf).await?; + let bytes_read = self.file.read(&mut buf)?; if bytes_read < 4 { let msg = format!("read {} bytes from diag device, expected > 4", bytes_read); return Err(DiagDeviceError::DeviceReadFailed(msg)); } let mut reader = Cursor::new(buf); - // is this a USER_SPACE_DATA_TYPE? - if reader.read_i32().await? != USER_SPACE_DATA_TYPE { + if reader.get_i32_le() != USER_SPACE_DATA_TYPE { return Ok(None); } - let num_messages = reader.read_u32().await?; + let num_messages = reader.get_u32_le(); let mut messages = Vec::new(); for _ in 0..num_messages { - let msg_len = reader.read_u32().await? as usize; + let msg_len = reader.get_u32_le() as usize; let mut msg = vec![0; msg_len]; - reader.read_exact(&mut msg).await?; + reader.read_exact(&mut msg)?; messages.push(msg); } Ok(Some(messages)) } - pub async fn write_request(&mut self, req: &[u8]) -> DiagResult<()> { - let mut buf: Vec = Vec::with_capacity(req.len()); - buf.write_i32(USER_SPACE_DATA_TYPE).await?; + pub fn write_request(&mut self, req: &[u8]) -> DiagResult<()> { + let mut buf: Vec = vec![]; + buf.put_i32_le(USER_SPACE_DATA_TYPE); if self.use_mdm > 0 { - buf.write_u32(0xffffffff).await?; + buf.put_i32_le(-1); } buf.extend_from_slice(req); - self.file.write_all(&buf).await?; + unsafe { + let fd = self.file.as_raw_fd(); + let buf_ptr = buf.as_ptr() as *const libc::c_void; + let ret = libc::write(fd, buf_ptr, buf.len()); + if ret < 0 { + let msg = format!("write failed with error code {}", ret); + return Err(DiagDeviceError::DeviceReadFailed(msg)); + } + } Ok(()) } } -#[tokio::main] -async fn main() -> io::Result<()> { - println!("Initializing DIAG"); - let dev = Arc::new(Mutex::new(DiagDevice::new().await.unwrap())); - let clients: Arc>> = Arc::new(Mutex::new(Vec::new())); - - let dev_clone = dev.clone(); - let clients_clone = clients.clone(); - tokio::spawn(async move { - loop { - let mut dev_ = dev_clone.lock().await; - if let Some(msg) = dev_.read_response().await.unwrap() { - let mut clients_ = clients_clone.lock().await; - for client in clients_.iter_mut() { - for buf in &msg { - let _ = client.write(buf).await.unwrap(); - } - } - } - } - }); - +fn main() -> io::Result<()> { println!("Starting server"); - let listener = TcpListener::bind("0.0.0.0:1312").await?; + let listener = TcpListener::bind("0.0.0.0:43555")?; - // handle incoming clients loop { - let (socket, _) = listener.accept().await?; - let (mut read, write) = socket.into_split(); - let client_idx: usize; - { - let mut clients_ = clients.lock().await; - clients_.push(write); - client_idx = clients_.len(); - } - let dev_clone = dev.clone(); - let clients_clone = clients.clone(); - tokio::spawn(async move { - let mut buf = vec![0; BUFFER_LEN]; + println!("waiting for client..."); + let (mut client_reader, _) = listener.accept()?; + let mut client_writer = client_reader.try_clone()?; + + println!("client connected, initializing diag device..."); + let mut dev_reader = DiagDevice::new().unwrap(); + let mut dev_writer = dev_reader.try_clone().unwrap(); + + let (reader_exit_tx, reader_exit_rx) = mpsc::channel::(); + let reader_handle = thread::spawn(move || { loop { - let bytes_read = read.read(&mut buf).await.unwrap(); - if bytes_read == 0 { - let mut clients_ = clients_clone.lock().await; - clients_.remove(client_idx); - println!("client {} disconnected", client_idx); - break; + if reader_exit_rx.try_recv().is_ok() { + return; + } + match dev_reader.read_response() { + Ok(Some(msgs)) => { + println!("writing {} messages to client...", msgs.len()); + for msg in msgs { + client_writer.write_all(&msg).unwrap(); + } + }, + Ok(None) => {}, + Err(err) => { + println!("dev reader thread err: {}", err); + return; + }, } - println!("waiting to write {} byte diag request...", bytes_read); - let mut dev_ = dev_clone.lock().await; - dev_.write_request(&buf[0..bytes_read]).await.unwrap(); - println!("diag request complete"); } }); + + let mut buf = vec![0; BUFFER_LEN]; + loop { + let bytes_read = client_reader.read(&mut buf).unwrap(); + if bytes_read == 0 { + println!("client disconnected, waiting for thread to exit..."); + reader_exit_tx.send(true).unwrap(); + reader_handle.join().unwrap(); + break; + } + println!("writing {} bytes to diag device...", bytes_read); + dev_writer.write_request(&buf[0..bytes_read]).unwrap(); + } } }