From 5bd2c3f653432f792fa4438fb2bd451edf8b9737 Mon Sep 17 00:00:00 2001 From: Will Greenberg Date: Tue, 14 Nov 2023 20:22:39 -0800 Subject: [PATCH] various cleanups, fix client disconnect flow --- src/main.rs | 83 +++++++++++++++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/src/main.rs b/src/main.rs index ad6af29..9772a33 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,7 @@ use std::fs::File; -use std::io::Cursor; -use std::io::Read; -use std::io::Write; -use std::mem; -use std::io; -use std::net::TcpListener; -use std::sync::mpsc; +use std::io::{Cursor, Read, Write}; +use std::net::{TcpListener, TcpStream}; +use std::sync::{Arc, Mutex}; use bytes::{Buf, BufMut}; use std::os::fd::AsRawFd; use std::thread; @@ -22,7 +18,7 @@ const DIAG_IOCTL_SWITCH_LOGGING: u32 = 7; #[derive(Error, Debug)] enum DiagDeviceError { #[error("IO error {0}")] - IO(#[from] io::Error), + IO(#[from] std::io::Error), #[error("Failed to initialize /dev/diag: {0}")] InitializationFailed(String), #[error("Failed to read diag device: {0}")] @@ -41,7 +37,7 @@ fn enable_frame_readwrite(fd: i32, mode: i32) -> DiagResult<()> { fd, DIAG_IOCTL_SWITCH_LOGGING, &mut [mode, -1, 0] as *mut _, // diag_logging_mode_param_t - mem::size_of::<[i32; 3]>(), 0, 0, 0, 0 + std::mem::size_of::<[i32; 3]>(), 0, 0, 0, 0 ); if ret < 0 { let msg = format!("DIAG_IOCTL_SWITCH_LOGGING ioctl failed with error code {}", ret); @@ -133,51 +129,62 @@ impl DiagDevice { } } -fn main() -> io::Result<()> { +fn main() -> std::io::Result<()> { println!("Starting server"); let listener = TcpListener::bind("0.0.0.0:43555")?; - loop { - println!("waiting for client..."); - let (mut client_reader, _) = listener.accept()?; - let mut client_writer = client_reader.try_clone()?; + let client_mutex: Arc>> = Arc::new(Mutex::new(None)); - println!("client connected, initializing diag device..."); - let mut dev_reader = DiagDevice::new().unwrap(); - let mut dev_writer = dev_reader.try_clone().unwrap(); + let mut dev_reader = DiagDevice::new().unwrap(); + let mut dev_writer = dev_reader.try_clone().unwrap(); - let (reader_exit_tx, reader_exit_rx) = mpsc::channel::(); - let reader_handle = thread::spawn(move || { - loop { - if reader_exit_rx.try_recv().is_ok() { - return; - } - match dev_reader.read_response() { - Ok(Some(msgs)) => { - println!("writing {} messages to client...", msgs.len()); + let client_mutex_clone = client_mutex.clone(); + // Spawn a thread to continuously read from the diag device, sending + // messages to the client (if any) + thread::spawn(move || { + loop { + match dev_reader.read_response() { + Ok(Some(msgs)) => { + if let Some(client_writer) = client_mutex_clone.lock().unwrap().as_mut() { + println!("> Writing {} diag messages to client", msgs.len()); for msg in msgs { client_writer.write_all(&msg).unwrap(); } - }, - Ok(None) => {}, - Err(err) => { - println!("dev reader thread err: {}", err); - return; - }, - } + } + }, + Ok(None) => {}, + Err(err) => { + println!("Unable to read from /dev/diag: {}", err); + return; + }, } - }); + } + }); + + // Accept connections from clients, writing any data received to the diag device + loop { + println!("Waiting for client"); + let (mut client_reader, _) = listener.accept()?; + + println!("Client connected"); + let client_writer = client_reader.try_clone()?; + { + let mut client_writer_mutex = client_mutex.lock().unwrap(); + *client_writer_mutex = Some(client_writer); + } let mut buf = vec![0; BUFFER_LEN]; loop { let bytes_read = client_reader.read(&mut buf).unwrap(); if bytes_read == 0 { - println!("client disconnected, waiting for thread to exit..."); - reader_exit_tx.send(true).unwrap(); - reader_handle.join().unwrap(); + println!("Client disconnected"); + { + let mut client_writer_mutex = client_mutex.lock().unwrap(); + *client_writer_mutex = None; + } break; } - println!("writing {} bytes to diag device...", bytes_read); + println!("< Got {} bytes from client", bytes_read); dev_writer.write_request(&buf[0..bytes_read]).unwrap(); } }