Refactor diag thread to have full control over the QMDL store

Fixes #269. Refactor also pull diag thread logic out into state machine
object for better encapsulation and reuse.
This commit is contained in:
Sashanoraa
2025-07-27 16:11:23 -04:00
committed by Markus Unterwaditzer
parent 6b109a9d76
commit 398997af67
4 changed files with 257 additions and 140 deletions
+241 -123
View File
@@ -1,3 +1,4 @@
use std::ops::DerefMut;
use std::pin::pin;
use std::sync::Arc;
@@ -9,135 +10,261 @@ use axum::response::{IntoResponse, Response};
use futures::{StreamExt, TryStreamExt};
use log::{debug, error, info, warn};
use rayhunter::analysis::analyzer::AnalyzerConfig;
use rayhunter::diag::DataType;
use rayhunter::diag::{DataType, MessagesContainer};
use rayhunter::diag_device::DiagDevice;
use rayhunter::qmdl::QmdlWriter;
use tokio::fs::File;
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{RwLock, oneshot};
use tokio_util::io::ReaderStream;
use tokio_util::task::TaskTracker;
use crate::analysis::{AnalysisCtrlMessage, AnalysisWriter};
use crate::display;
use crate::qmdl_store::{EntryType, RecordingStore, RecordingStoreError};
use crate::qmdl_store::{RecordingStore, RecordingStoreError};
use crate::server::ServerState;
pub enum DiagDeviceCtrlMessage {
StopRecording,
StartRecording,
DeleteEntry {
name: String,
response_tx: oneshot::Sender<Result<(), RecordingStoreError>>,
},
DeleteAllEntries {
response_tx: oneshot::Sender<Result<(), RecordingStoreError>>,
},
Exit,
}
pub struct DiagTask {
ui_update_sender: Sender<display::DisplayState>,
analysis_sender: Sender<AnalysisCtrlMessage>,
analyzer_config: AnalyzerConfig,
state: DiagState,
}
enum DiagState {
Recording {
qmdl_writer: QmdlWriter<File>,
analysis_writer: Box<AnalysisWriter>,
},
Stopped,
}
impl DiagTask {
fn new(
ui_update_sender: Sender<display::DisplayState>,
analysis_sender: Sender<AnalysisCtrlMessage>,
analyzer_config: AnalyzerConfig,
) -> Self {
Self {
ui_update_sender,
analysis_sender,
analyzer_config,
state: DiagState::Stopped,
}
}
/// Start recording
async fn start(&mut self, qmdl_store: &mut RecordingStore) {
let (qmdl_file, analysis_file) = qmdl_store
.new_entry()
.await
.expect("failed creating QMDL file entry");
self.stop_current_recording().await;
let qmdl_writer = QmdlWriter::new(qmdl_file);
let analysis_writer = AnalysisWriter::new(analysis_file, &self.analyzer_config)
.await
.map(Box::new)
.expect("failed to write to analysis file");
self.state = DiagState::Recording {
qmdl_writer,
analysis_writer,
};
if let Err(e) = self
.ui_update_sender
.send(display::DisplayState::Recording)
.await
{
warn!("couldn't send ui update message: {e}");
}
}
/// Stop recording
async fn stop(&mut self, qmdl_store: &mut RecordingStore) {
self.stop_current_recording().await;
if let Some((_, entry)) = qmdl_store.get_current_entry() {
if let Err(e) = self
.analysis_sender
.send(AnalysisCtrlMessage::RecordingFinished(
entry.name.to_string(),
))
.await
{
warn!("couldn't send analysis message: {e}");
}
}
if let Err(e) = qmdl_store.close_current_entry().await {
error!("couldn't close current entry: {e}");
}
if let Err(e) = self
.ui_update_sender
.send(display::DisplayState::Paused)
.await
{
warn!("couldn't send ui update message: {e}");
}
}
async fn delete_entry(
&mut self,
qmdl_store: &mut RecordingStore,
name: &str,
) -> Result<(), RecordingStoreError> {
if qmdl_store.is_current_entry(name) {
self.stop(qmdl_store).await;
}
let res = qmdl_store.delete_entry(name).await;
if let Err(e) = res.as_ref() {
error!("Error deleting QMDL entry {e}");
}
res
}
async fn delete_all_entries(
&mut self,
qmdl_store: &mut RecordingStore,
) -> Result<(), RecordingStoreError> {
self.stop(qmdl_store).await;
let res = qmdl_store.delete_all_entries().await;
if let Err(e) = res.as_ref() {
error!("Error deleting QMDL entries {e}");
}
res
}
async fn stop_current_recording(&mut self) {
let mut state = DiagState::Stopped;
std::mem::swap(&mut self.state, &mut state);
if let DiagState::Recording {
analysis_writer, ..
} = state
{
analysis_writer
.close()
.await
.expect("failed to close analysis writer");
}
}
async fn process_container(
&mut self,
qmdl_store: &mut RecordingStore,
container: MessagesContainer,
) {
if container.data_type != DataType::UserSpace {
debug!("skipping non-userspace diag messages...");
return;
}
// keep track of how many bytes were written to the QMDL file so we can read
// a valid block of data from it in the HTTP server
if let DiagState::Recording {
qmdl_writer,
analysis_writer,
} = &mut self.state
{
qmdl_writer
.write_container(&container)
.await
.expect("failed to write to QMDL writer");
debug!(
"total QMDL bytes written: {}, updating manifest...",
qmdl_writer.total_written
);
let index = qmdl_store
.current_entry
.expect("DiagDevice had qmdl_writer, but QmdlStore didn't have current entry???");
qmdl_store
.update_entry_qmdl_size(index, qmdl_writer.total_written)
.await
.expect("failed to update qmdl file size");
debug!("done!");
let heuristic_warning = analysis_writer
.analyze(container)
.await
.expect("failed to analyze container");
if heuristic_warning {
info!("a heuristic triggered on this run!");
self.ui_update_sender
.send(display::DisplayState::WarningDetected)
.await
.expect("couldn't send ui update message: {}");
}
} else {
debug!("no qmdl_writer set, continuing...");
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn run_diag_read_thread(
task_tracker: &TaskTracker,
mut dev: DiagDevice,
mut qmdl_file_rx: Receiver<DiagDeviceCtrlMessage>,
qmdl_file_tx: Sender<DiagDeviceCtrlMessage>,
ui_update_sender: Sender<display::DisplayState>,
qmdl_store_lock: Arc<RwLock<RecordingStore>>,
analysis_sender: Sender<AnalysisCtrlMessage>,
analyzer_config: AnalyzerConfig,
) {
task_tracker.spawn(async move {
let (initial_qmdl_file, initial_analysis_file) = qmdl_store_lock.write().await.new_entry().await.expect("failed creating QMDL file entry");
let mut maybe_qmdl_writer: Option<QmdlWriter<File>> = Some(QmdlWriter::new(initial_qmdl_file));
let mut diag_stream = pin!(dev.as_stream().into_stream());
let mut maybe_analysis_writer = Some(AnalysisWriter::new(initial_analysis_file, &analyzer_config).await
.expect("failed to create analysis writer"));
let mut diag_task = DiagTask::new(ui_update_sender, analysis_sender, analyzer_config);
qmdl_file_tx
.send(DiagDeviceCtrlMessage::StartRecording)
.await
.unwrap();
loop {
tokio::select! {
msg = qmdl_file_rx.recv() => {
match msg {
Some(DiagDeviceCtrlMessage::StartRecording) => {
let mut qmdl_store = qmdl_store_lock.write().await;
let (qmdl_file, new_analysis_file) = match qmdl_store.new_entry().await {
Ok(x) => x,
Err(e) => {
error!("couldn't create new qmdl entry: {e}");
continue;
}
};
maybe_qmdl_writer = Some(QmdlWriter::new(qmdl_file));
if let Some(analysis_writer) = maybe_analysis_writer {
analysis_writer.close().await.expect("failed to close analysis writer");
}
maybe_analysis_writer = Some(AnalysisWriter::new(new_analysis_file, &analyzer_config).await
.expect("failed to write to analysis file"));
if let Err(e) = ui_update_sender.send(display::DisplayState::Recording).await {
warn!("couldn't send ui update message: {e}");
}
diag_task.start(qmdl_store.deref_mut()).await;
},
Some(DiagDeviceCtrlMessage::StopRecording) => {
let mut qmdl_store = qmdl_store_lock.write().await;
if let Some((_, entry)) = qmdl_store.get_current_entry() {
if let Err(e) = analysis_sender
.send(AnalysisCtrlMessage::RecordingFinished(
entry.name.to_string(),
))
.await {
warn!("couldn't send analysis message: {e}");
}
}
if let Err(e) = qmdl_store.close_current_entry().await {
error!("couldn't close current entry: {e}");
}
maybe_qmdl_writer = None;
if let Some(analysis_writer) = maybe_analysis_writer {
analysis_writer.close().await.expect("failed to close analysis writer");
}
maybe_analysis_writer = None;
if let Err(e) = ui_update_sender.send(display::DisplayState::Paused).await {
warn!("couldn't send ui update message: {e}");
}
diag_task.stop(qmdl_store.deref_mut()).await;
},
// None means all the Senders have been dropped, so it's
// time to go
Some(DiagDeviceCtrlMessage::Exit) | None => {
info!("Diag reader thread exiting...");
if let Some(analysis_writer) = maybe_analysis_writer {
analysis_writer.close().await.expect("failed to close analysis writer");
}
diag_task.stop_current_recording().await;
return Ok(())
},
Some(DiagDeviceCtrlMessage::DeleteEntry { name, response_tx }) => {
let mut qmdl_store = qmdl_store_lock.write().await;
let resp = diag_task.delete_entry(qmdl_store.deref_mut(), name.as_str()).await;
if response_tx.send(resp).is_err() {
error!("Failed to send delete entry respons, receiver dropped");
}
},
Some(DiagDeviceCtrlMessage::DeleteAllEntries { response_tx }) => {
let mut qmdl_store = qmdl_store_lock.write().await;
let resp = diag_task.delete_all_entries(qmdl_store.deref_mut()).await;
if response_tx.send(resp).is_err() {
error!("Failed to send delete all entries respons, receiver dropped");
}
},
}
}
maybe_container = diag_stream.next() => {
match maybe_container.unwrap() {
Ok(container) => {
if container.data_type != DataType::UserSpace {
debug!("skipping non-userspace diag messages...");
continue;
}
// keep track of how many bytes were written to the QMDL file so we can read
// a valid block of data from it in the HTTP server
if let Some(qmdl_writer) = maybe_qmdl_writer.as_mut() {
qmdl_writer.write_container(&container).await.expect("failed to write to QMDL writer");
debug!("total QMDL bytes written: {}, updating manifest...", qmdl_writer.total_written);
let mut qmdl_store = qmdl_store_lock.write().await;
let index = qmdl_store.current_entry.expect("DiagDevice had qmdl_writer, but QmdlStore didn't have current entry???");
qmdl_store.update_entry_qmdl_size(index, qmdl_writer.total_written).await
.expect("failed to update qmdl file size");
debug!("done!");
} else {
debug!("no qmdl_writer set, continuing...");
}
if let Some(analysis_writer) = maybe_analysis_writer.as_mut() {
let heuristic_warning = analysis_writer.analyze(container).await
.expect("failed to analyze container");
if heuristic_warning {
info!("a heuristic triggered on this run!");
ui_update_sender.send(display::DisplayState::WarningDetected).await
.expect("couldn't send ui update message: {}");
}
}
let mut qmdl_store = qmdl_store_lock.write().await;
diag_task.process_container(qmdl_store.deref_mut(), container).await
},
Err(err) => {
error!("error reading diag device: {err}");
@@ -150,6 +277,7 @@ pub fn run_diag_read_thread(
});
}
/// Start recording API for web thread
pub async fn start_recording(
State(state): State<Arc<ServerState>>,
) -> Result<(StatusCode, String), (StatusCode, String)> {
@@ -171,6 +299,7 @@ pub async fn start_recording(
Ok((StatusCode::ACCEPTED, "ok".to_string()))
}
/// Stop recording API for web thread
pub async fn stop_recording(
State(state): State<Arc<ServerState>>,
) -> Result<(StatusCode, String), (StatusCode, String)> {
@@ -197,8 +326,27 @@ pub async fn delete_recording(
if state.config.debug_mode {
return Err((StatusCode::FORBIDDEN, "server is in debug mode".to_string()));
}
let mut qmdl_store = state.qmdl_store_lock.write().await;
match qmdl_store.delete_entry(&qmdl_name).await {
let (response_tx, response_rx) = oneshot::channel();
state
.diag_device_ctrl_sender
.send(DiagDeviceCtrlMessage::DeleteEntry {
name: qmdl_name.clone(),
response_tx,
})
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("couldn't send delete entry message: {e}"),
)
})?;
match response_rx.await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to receive delete response: {e}"),
)
})? {
Ok(_) => Ok((StatusCode::ACCEPTED, "ok".to_string())),
Err(RecordingStoreError::NoSuchEntryError) => Err((
StatusCode::BAD_REQUEST,
format!("no recording with name {qmdl_name}"),
@@ -207,31 +355,6 @@ pub async fn delete_recording(
StatusCode::INTERNAL_SERVER_ERROR,
format!("couldn't delete recording: {e}"),
)),
Ok(entry_type) => {
if entry_type == EntryType::Current {
state
.diag_device_ctrl_sender
.send(DiagDeviceCtrlMessage::StopRecording)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("couldn't send stop recording message: {e}"),
)
})?;
state
.ui_update_sender
.send(display::DisplayState::Paused)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("couldn't send ui update message: {e}"),
)
})?;
}
Ok((StatusCode::ACCEPTED, "ok".to_string()))
}
}
}
@@ -241,34 +364,29 @@ pub async fn delete_all_recordings(
if state.config.debug_mode {
return Err((StatusCode::FORBIDDEN, "server is in debug mode".to_string()));
}
let (response_tx, response_rx) = oneshot::channel();
state
.diag_device_ctrl_sender
.send(DiagDeviceCtrlMessage::StopRecording)
.send(DiagDeviceCtrlMessage::DeleteAllEntries { response_tx })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("couldn't send stop recording message: {e}"),
format!("couldn't send delete all entries message: {e}"),
)
})?;
let mut qmdl_store = state.qmdl_store_lock.write().await;
qmdl_store.delete_all_entries().await.map_err(|e| {
match response_rx.await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("couldn't delete all recordings: {e}"),
format!("failed to receive delete all response: {e}"),
)
})?;
state
.ui_update_sender
.send(display::DisplayState::Paused)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("couldn't send ui update message: {e}"),
)
})?;
Ok((StatusCode::ACCEPTED, "ok".to_string()))
})? {
Ok(_) => Ok((StatusCode::ACCEPTED, "ok".to_string())),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("couldn't delete recordings: {e}"),
)),
}
}
pub async fn get_analysis_report(
+1 -1
View File
@@ -229,6 +229,7 @@ async fn run_with_config(
&task_tracker,
dev,
diag_rx,
diag_tx.clone(),
ui_update_tx.clone(),
qmdl_store_lock.clone(),
analysis_tx.clone(),
@@ -284,7 +285,6 @@ async fn run_with_config(
config,
qmdl_store_lock: qmdl_store_lock.clone(),
diag_device_ctrl_sender: diag_tx,
ui_update_sender: ui_update_tx,
analysis_status_lock,
analysis_sender: analysis_tx,
daemon_restart_tx: Arc::new(RwLock::new(Some(daemon_restart_tx))),
+14 -12
View File
@@ -56,12 +56,6 @@ pub struct ManifestEntry {
pub arch: Option<String>,
}
#[derive(PartialEq, Eq)]
pub enum EntryType {
Current,
Past,
}
impl ManifestEntry {
fn new() -> Self {
let now = Local::now();
@@ -347,23 +341,31 @@ impl RecordingStore {
Some((entry_index, &self.manifest.entries[entry_index]))
}
pub async fn delete_entry(&mut self, name: &str) -> Result<EntryType, RecordingStoreError> {
pub fn is_current_entry(&self, name: &str) -> bool {
match self.current_entry {
Some(idx) => match self.manifest.entries.get(idx) {
Some(entry) => entry.name == name,
None => false,
},
None => false,
}
}
pub async fn delete_entry(&mut self, name: &str) -> Result<(), RecordingStoreError> {
let entry_to_delete_idx = self
.manifest
.entries
.iter()
.position(|entry| entry.name == name)
.ok_or(RecordingStoreError::NoSuchEntryError)?;
let is_current = match self.current_entry {
match self.current_entry {
Some(current_entry) if current_entry == entry_to_delete_idx => {
self.close_current_entry().await?;
EntryType::Current
}
Some(current_entry) => {
self.current_entry = Some(current_entry - 1);
EntryType::Past
}
None => EntryType::Past,
None => {}
};
let entry_to_delete = self.manifest.entries.remove(entry_to_delete_idx);
self.write_manifest().await?;
@@ -375,7 +377,7 @@ impl RecordingStore {
remove_file_if_exists(&analysis_filepath)
.await
.map_err(RecordingStoreError::DeleteFileError)?;
Ok(is_current)
Ok(())
}
pub async fn delete_all_entries(&mut self) -> Result<(), RecordingStoreError> {
+1 -4
View File
@@ -18,18 +18,17 @@ use tokio::sync::{RwLock, oneshot};
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use tokio_util::io::ReaderStream;
use crate::DiagDeviceCtrlMessage;
use crate::analysis::{AnalysisCtrlMessage, AnalysisStatus};
use crate::config::Config;
use crate::pcap::generate_pcap_data;
use crate::qmdl_store::RecordingStore;
use crate::{DiagDeviceCtrlMessage, display};
pub struct ServerState {
pub config_path: String,
pub config: Config,
pub qmdl_store_lock: Arc<RwLock<RecordingStore>>,
pub diag_device_ctrl_sender: Sender<DiagDeviceCtrlMessage>,
pub ui_update_sender: Sender<display::DisplayState>,
pub analysis_status_lock: Arc<RwLock<AnalysisStatus>>,
pub analysis_sender: Sender<AnalysisCtrlMessage>,
pub daemon_restart_tx: Arc<RwLock<Option<oneshot::Sender<()>>>>,
@@ -293,7 +292,6 @@ mod tests {
store_lock: Arc<RwLock<crate::qmdl_store::RecordingStore>>,
) -> Arc<ServerState> {
let (tx, _rx) = tokio::sync::mpsc::channel(1);
let (ui_tx, _ui_rx) = tokio::sync::mpsc::channel(1);
let (analysis_tx, _analysis_rx) = tokio::sync::mpsc::channel(1);
let analysis_status = {
@@ -306,7 +304,6 @@ mod tests {
config: Config::default(),
qmdl_store_lock: store_lock,
diag_device_ctrl_sender: tx,
ui_update_sender: ui_tx,
analysis_status_lock: Arc::new(RwLock::new(analysis_status)),
analysis_sender: analysis_tx,
daemon_restart_tx: Arc::new(RwLock::new(None)),