diff --git a/daemon/src/diag.rs b/daemon/src/diag.rs index 0faaa33..508a1b9 100644 --- a/daemon/src/diag.rs +++ b/daemon/src/diag.rs @@ -1,3 +1,4 @@ +use std::ops::DerefMut; use std::pin::pin; use std::sync::Arc; @@ -9,135 +10,261 @@ use axum::response::{IntoResponse, Response}; use futures::{StreamExt, TryStreamExt}; use log::{debug, error, info, warn}; use rayhunter::analysis::analyzer::AnalyzerConfig; -use rayhunter::diag::DataType; +use rayhunter::diag::{DataType, MessagesContainer}; use rayhunter::diag_device::DiagDevice; use rayhunter::qmdl::QmdlWriter; use tokio::fs::File; -use tokio::sync::RwLock; use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::{RwLock, oneshot}; use tokio_util::io::ReaderStream; use tokio_util::task::TaskTracker; use crate::analysis::{AnalysisCtrlMessage, AnalysisWriter}; use crate::display; -use crate::qmdl_store::{EntryType, RecordingStore, RecordingStoreError}; +use crate::qmdl_store::{RecordingStore, RecordingStoreError}; use crate::server::ServerState; pub enum DiagDeviceCtrlMessage { StopRecording, StartRecording, + DeleteEntry { + name: String, + response_tx: oneshot::Sender>, + }, + DeleteAllEntries { + response_tx: oneshot::Sender>, + }, Exit, } +pub struct DiagTask { + ui_update_sender: Sender, + analysis_sender: Sender, + analyzer_config: AnalyzerConfig, + state: DiagState, +} + +enum DiagState { + Recording { + qmdl_writer: QmdlWriter, + analysis_writer: Box, + }, + Stopped, +} + +impl DiagTask { + fn new( + ui_update_sender: Sender, + analysis_sender: Sender, + analyzer_config: AnalyzerConfig, + ) -> Self { + Self { + ui_update_sender, + analysis_sender, + analyzer_config, + state: DiagState::Stopped, + } + } + + /// Start recording + async fn start(&mut self, qmdl_store: &mut RecordingStore) { + let (qmdl_file, analysis_file) = qmdl_store + .new_entry() + .await + .expect("failed creating QMDL file entry"); + self.stop_current_recording().await; + let qmdl_writer = QmdlWriter::new(qmdl_file); + let analysis_writer = AnalysisWriter::new(analysis_file, &self.analyzer_config) + .await + .map(Box::new) + .expect("failed to write to analysis file"); + self.state = DiagState::Recording { + qmdl_writer, + analysis_writer, + }; + if let Err(e) = self + .ui_update_sender + .send(display::DisplayState::Recording) + .await + { + warn!("couldn't send ui update message: {e}"); + } + } + + /// Stop recording + async fn stop(&mut self, qmdl_store: &mut RecordingStore) { + self.stop_current_recording().await; + if let Some((_, entry)) = qmdl_store.get_current_entry() { + if let Err(e) = self + .analysis_sender + .send(AnalysisCtrlMessage::RecordingFinished( + entry.name.to_string(), + )) + .await + { + warn!("couldn't send analysis message: {e}"); + } + } + if let Err(e) = qmdl_store.close_current_entry().await { + error!("couldn't close current entry: {e}"); + } + if let Err(e) = self + .ui_update_sender + .send(display::DisplayState::Paused) + .await + { + warn!("couldn't send ui update message: {e}"); + } + } + + async fn delete_entry( + &mut self, + qmdl_store: &mut RecordingStore, + name: &str, + ) -> Result<(), RecordingStoreError> { + if qmdl_store.is_current_entry(name) { + self.stop(qmdl_store).await; + } + let res = qmdl_store.delete_entry(name).await; + if let Err(e) = res.as_ref() { + error!("Error deleting QMDL entry {e}"); + } + res + } + + async fn delete_all_entries( + &mut self, + qmdl_store: &mut RecordingStore, + ) -> Result<(), RecordingStoreError> { + self.stop(qmdl_store).await; + let res = qmdl_store.delete_all_entries().await; + if let Err(e) = res.as_ref() { + error!("Error deleting QMDL entries {e}"); + } + res + } + + async fn stop_current_recording(&mut self) { + let mut state = DiagState::Stopped; + std::mem::swap(&mut self.state, &mut state); + if let DiagState::Recording { + analysis_writer, .. + } = state + { + analysis_writer + .close() + .await + .expect("failed to close analysis writer"); + } + } + + async fn process_container( + &mut self, + qmdl_store: &mut RecordingStore, + container: MessagesContainer, + ) { + if container.data_type != DataType::UserSpace { + debug!("skipping non-userspace diag messages..."); + return; + } + // keep track of how many bytes were written to the QMDL file so we can read + // a valid block of data from it in the HTTP server + if let DiagState::Recording { + qmdl_writer, + analysis_writer, + } = &mut self.state + { + qmdl_writer + .write_container(&container) + .await + .expect("failed to write to QMDL writer"); + debug!( + "total QMDL bytes written: {}, updating manifest...", + qmdl_writer.total_written + ); + let index = qmdl_store + .current_entry + .expect("DiagDevice had qmdl_writer, but QmdlStore didn't have current entry???"); + qmdl_store + .update_entry_qmdl_size(index, qmdl_writer.total_written) + .await + .expect("failed to update qmdl file size"); + debug!("done!"); + let heuristic_warning = analysis_writer + .analyze(container) + .await + .expect("failed to analyze container"); + if heuristic_warning { + info!("a heuristic triggered on this run!"); + self.ui_update_sender + .send(display::DisplayState::WarningDetected) + .await + .expect("couldn't send ui update message: {}"); + } + } else { + debug!("no qmdl_writer set, continuing..."); + } + } +} + #[allow(clippy::too_many_arguments)] pub fn run_diag_read_thread( task_tracker: &TaskTracker, mut dev: DiagDevice, mut qmdl_file_rx: Receiver, + qmdl_file_tx: Sender, ui_update_sender: Sender, qmdl_store_lock: Arc>, analysis_sender: Sender, analyzer_config: AnalyzerConfig, ) { task_tracker.spawn(async move { - let (initial_qmdl_file, initial_analysis_file) = qmdl_store_lock.write().await.new_entry().await.expect("failed creating QMDL file entry"); - let mut maybe_qmdl_writer: Option> = Some(QmdlWriter::new(initial_qmdl_file)); let mut diag_stream = pin!(dev.as_stream().into_stream()); - let mut maybe_analysis_writer = Some(AnalysisWriter::new(initial_analysis_file, &analyzer_config).await - .expect("failed to create analysis writer")); + let mut diag_task = DiagTask::new(ui_update_sender, analysis_sender, analyzer_config); + qmdl_file_tx + .send(DiagDeviceCtrlMessage::StartRecording) + .await + .unwrap(); loop { tokio::select! { msg = qmdl_file_rx.recv() => { match msg { Some(DiagDeviceCtrlMessage::StartRecording) => { let mut qmdl_store = qmdl_store_lock.write().await; - let (qmdl_file, new_analysis_file) = match qmdl_store.new_entry().await { - Ok(x) => x, - Err(e) => { - error!("couldn't create new qmdl entry: {e}"); - continue; - } - }; - - maybe_qmdl_writer = Some(QmdlWriter::new(qmdl_file)); - - if let Some(analysis_writer) = maybe_analysis_writer { - analysis_writer.close().await.expect("failed to close analysis writer"); - } - - maybe_analysis_writer = Some(AnalysisWriter::new(new_analysis_file, &analyzer_config).await - .expect("failed to write to analysis file")); - - if let Err(e) = ui_update_sender.send(display::DisplayState::Recording).await { - warn!("couldn't send ui update message: {e}"); - } + diag_task.start(qmdl_store.deref_mut()).await; }, Some(DiagDeviceCtrlMessage::StopRecording) => { let mut qmdl_store = qmdl_store_lock.write().await; - if let Some((_, entry)) = qmdl_store.get_current_entry() { - if let Err(e) = analysis_sender - .send(AnalysisCtrlMessage::RecordingFinished( - entry.name.to_string(), - )) - .await { - warn!("couldn't send analysis message: {e}"); - } - } - if let Err(e) = qmdl_store.close_current_entry().await { - error!("couldn't close current entry: {e}"); - } - - maybe_qmdl_writer = None; - if let Some(analysis_writer) = maybe_analysis_writer { - analysis_writer.close().await.expect("failed to close analysis writer"); - } - maybe_analysis_writer = None; - - if let Err(e) = ui_update_sender.send(display::DisplayState::Paused).await { - warn!("couldn't send ui update message: {e}"); - } + diag_task.stop(qmdl_store.deref_mut()).await; }, // None means all the Senders have been dropped, so it's // time to go Some(DiagDeviceCtrlMessage::Exit) | None => { info!("Diag reader thread exiting..."); - if let Some(analysis_writer) = maybe_analysis_writer { - analysis_writer.close().await.expect("failed to close analysis writer"); - } + diag_task.stop_current_recording().await; return Ok(()) }, + Some(DiagDeviceCtrlMessage::DeleteEntry { name, response_tx }) => { + let mut qmdl_store = qmdl_store_lock.write().await; + let resp = diag_task.delete_entry(qmdl_store.deref_mut(), name.as_str()).await; + if response_tx.send(resp).is_err() { + error!("Failed to send delete entry respons, receiver dropped"); + } + }, + Some(DiagDeviceCtrlMessage::DeleteAllEntries { response_tx }) => { + let mut qmdl_store = qmdl_store_lock.write().await; + let resp = diag_task.delete_all_entries(qmdl_store.deref_mut()).await; + if response_tx.send(resp).is_err() { + error!("Failed to send delete all entries respons, receiver dropped"); + } + }, } } maybe_container = diag_stream.next() => { match maybe_container.unwrap() { Ok(container) => { - if container.data_type != DataType::UserSpace { - debug!("skipping non-userspace diag messages..."); - continue; - } - // keep track of how many bytes were written to the QMDL file so we can read - // a valid block of data from it in the HTTP server - if let Some(qmdl_writer) = maybe_qmdl_writer.as_mut() { - qmdl_writer.write_container(&container).await.expect("failed to write to QMDL writer"); - debug!("total QMDL bytes written: {}, updating manifest...", qmdl_writer.total_written); - let mut qmdl_store = qmdl_store_lock.write().await; - let index = qmdl_store.current_entry.expect("DiagDevice had qmdl_writer, but QmdlStore didn't have current entry???"); - qmdl_store.update_entry_qmdl_size(index, qmdl_writer.total_written).await - .expect("failed to update qmdl file size"); - debug!("done!"); - } else { - debug!("no qmdl_writer set, continuing..."); - } - - if let Some(analysis_writer) = maybe_analysis_writer.as_mut() { - let heuristic_warning = analysis_writer.analyze(container).await - .expect("failed to analyze container"); - if heuristic_warning { - info!("a heuristic triggered on this run!"); - ui_update_sender.send(display::DisplayState::WarningDetected).await - .expect("couldn't send ui update message: {}"); - } - } + let mut qmdl_store = qmdl_store_lock.write().await; + diag_task.process_container(qmdl_store.deref_mut(), container).await }, Err(err) => { error!("error reading diag device: {err}"); @@ -150,6 +277,7 @@ pub fn run_diag_read_thread( }); } +/// Start recording API for web thread pub async fn start_recording( State(state): State>, ) -> Result<(StatusCode, String), (StatusCode, String)> { @@ -171,6 +299,7 @@ pub async fn start_recording( Ok((StatusCode::ACCEPTED, "ok".to_string())) } +/// Stop recording API for web thread pub async fn stop_recording( State(state): State>, ) -> Result<(StatusCode, String), (StatusCode, String)> { @@ -197,8 +326,27 @@ pub async fn delete_recording( if state.config.debug_mode { return Err((StatusCode::FORBIDDEN, "server is in debug mode".to_string())); } - let mut qmdl_store = state.qmdl_store_lock.write().await; - match qmdl_store.delete_entry(&qmdl_name).await { + let (response_tx, response_rx) = oneshot::channel(); + state + .diag_device_ctrl_sender + .send(DiagDeviceCtrlMessage::DeleteEntry { + name: qmdl_name.clone(), + response_tx, + }) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("couldn't send delete entry message: {e}"), + ) + })?; + match response_rx.await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("failed to receive delete response: {e}"), + ) + })? { + Ok(_) => Ok((StatusCode::ACCEPTED, "ok".to_string())), Err(RecordingStoreError::NoSuchEntryError) => Err(( StatusCode::BAD_REQUEST, format!("no recording with name {qmdl_name}"), @@ -207,31 +355,6 @@ pub async fn delete_recording( StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't delete recording: {e}"), )), - Ok(entry_type) => { - if entry_type == EntryType::Current { - state - .diag_device_ctrl_sender - .send(DiagDeviceCtrlMessage::StopRecording) - .await - .map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("couldn't send stop recording message: {e}"), - ) - })?; - state - .ui_update_sender - .send(display::DisplayState::Paused) - .await - .map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("couldn't send ui update message: {e}"), - ) - })?; - } - Ok((StatusCode::ACCEPTED, "ok".to_string())) - } } } @@ -241,34 +364,29 @@ pub async fn delete_all_recordings( if state.config.debug_mode { return Err((StatusCode::FORBIDDEN, "server is in debug mode".to_string())); } + let (response_tx, response_rx) = oneshot::channel(); state .diag_device_ctrl_sender - .send(DiagDeviceCtrlMessage::StopRecording) + .send(DiagDeviceCtrlMessage::DeleteAllEntries { response_tx }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, - format!("couldn't send stop recording message: {e}"), + format!("couldn't send delete all entries message: {e}"), ) })?; - let mut qmdl_store = state.qmdl_store_lock.write().await; - qmdl_store.delete_all_entries().await.map_err(|e| { + match response_rx.await.map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, - format!("couldn't delete all recordings: {e}"), + format!("failed to receive delete all response: {e}"), ) - })?; - state - .ui_update_sender - .send(display::DisplayState::Paused) - .await - .map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("couldn't send ui update message: {e}"), - ) - })?; - Ok((StatusCode::ACCEPTED, "ok".to_string())) + })? { + Ok(_) => Ok((StatusCode::ACCEPTED, "ok".to_string())), + Err(e) => Err(( + StatusCode::INTERNAL_SERVER_ERROR, + format!("couldn't delete recordings: {e}"), + )), + } } pub async fn get_analysis_report( diff --git a/daemon/src/main.rs b/daemon/src/main.rs index 9da6b81..31c4ec5 100644 --- a/daemon/src/main.rs +++ b/daemon/src/main.rs @@ -229,6 +229,7 @@ async fn run_with_config( &task_tracker, dev, diag_rx, + diag_tx.clone(), ui_update_tx.clone(), qmdl_store_lock.clone(), analysis_tx.clone(), @@ -284,7 +285,6 @@ async fn run_with_config( config, qmdl_store_lock: qmdl_store_lock.clone(), diag_device_ctrl_sender: diag_tx, - ui_update_sender: ui_update_tx, analysis_status_lock, analysis_sender: analysis_tx, daemon_restart_tx: Arc::new(RwLock::new(Some(daemon_restart_tx))), diff --git a/daemon/src/qmdl_store.rs b/daemon/src/qmdl_store.rs index 4fbc700..0c6d036 100644 --- a/daemon/src/qmdl_store.rs +++ b/daemon/src/qmdl_store.rs @@ -56,12 +56,6 @@ pub struct ManifestEntry { pub arch: Option, } -#[derive(PartialEq, Eq)] -pub enum EntryType { - Current, - Past, -} - impl ManifestEntry { fn new() -> Self { let now = Local::now(); @@ -347,23 +341,31 @@ impl RecordingStore { Some((entry_index, &self.manifest.entries[entry_index])) } - pub async fn delete_entry(&mut self, name: &str) -> Result { + pub fn is_current_entry(&self, name: &str) -> bool { + match self.current_entry { + Some(idx) => match self.manifest.entries.get(idx) { + Some(entry) => entry.name == name, + None => false, + }, + None => false, + } + } + + pub async fn delete_entry(&mut self, name: &str) -> Result<(), RecordingStoreError> { let entry_to_delete_idx = self .manifest .entries .iter() .position(|entry| entry.name == name) .ok_or(RecordingStoreError::NoSuchEntryError)?; - let is_current = match self.current_entry { + match self.current_entry { Some(current_entry) if current_entry == entry_to_delete_idx => { self.close_current_entry().await?; - EntryType::Current } Some(current_entry) => { self.current_entry = Some(current_entry - 1); - EntryType::Past } - None => EntryType::Past, + None => {} }; let entry_to_delete = self.manifest.entries.remove(entry_to_delete_idx); self.write_manifest().await?; @@ -375,7 +377,7 @@ impl RecordingStore { remove_file_if_exists(&analysis_filepath) .await .map_err(RecordingStoreError::DeleteFileError)?; - Ok(is_current) + Ok(()) } pub async fn delete_all_entries(&mut self) -> Result<(), RecordingStoreError> { diff --git a/daemon/src/server.rs b/daemon/src/server.rs index 8008c17..b733b4a 100644 --- a/daemon/src/server.rs +++ b/daemon/src/server.rs @@ -18,18 +18,17 @@ use tokio::sync::{RwLock, oneshot}; use tokio_util::compat::FuturesAsyncWriteCompatExt; use tokio_util::io::ReaderStream; +use crate::DiagDeviceCtrlMessage; use crate::analysis::{AnalysisCtrlMessage, AnalysisStatus}; use crate::config::Config; use crate::pcap::generate_pcap_data; use crate::qmdl_store::RecordingStore; -use crate::{DiagDeviceCtrlMessage, display}; pub struct ServerState { pub config_path: String, pub config: Config, pub qmdl_store_lock: Arc>, pub diag_device_ctrl_sender: Sender, - pub ui_update_sender: Sender, pub analysis_status_lock: Arc>, pub analysis_sender: Sender, pub daemon_restart_tx: Arc>>>, @@ -293,7 +292,6 @@ mod tests { store_lock: Arc>, ) -> Arc { let (tx, _rx) = tokio::sync::mpsc::channel(1); - let (ui_tx, _ui_rx) = tokio::sync::mpsc::channel(1); let (analysis_tx, _analysis_rx) = tokio::sync::mpsc::channel(1); let analysis_status = { @@ -306,7 +304,6 @@ mod tests { config: Config::default(), qmdl_store_lock: store_lock, diag_device_ctrl_sender: tx, - ui_update_sender: ui_tx, analysis_status_lock: Arc::new(RwLock::new(analysis_status)), analysis_sender: analysis_tx, daemon_restart_tx: Arc::new(RwLock::new(None)),