Refactors in response to review comments

A few minor refactors, and a more major one that renames
RecordingStore's update_entry_qmdl_size to
update_current_entry_qmdl_size, since the only time we're ever updating
an entry's QMDL size is when it's the current one.
This commit is contained in:
Will Greenberg
2026-06-22 11:04:10 -07:00
committed by Brad Warren
parent 10f560b5e4
commit 9627cec737
6 changed files with 57 additions and 53 deletions
+13 -11
View File
@@ -185,7 +185,7 @@ impl DiagTask {
.await
.map_err(RecordingStoreError::WriteFileError)?;
}
self.stop_current_recording().await;
self.stop_current_recording(qmdl_store).await;
let qmdl_writer = Box::new(QmdlWriter::new(qmdl_gz_file));
let analysis_writer = AnalysisWriter::new(analysis_file, &self.analyzer_config)
.await
@@ -207,7 +207,7 @@ impl DiagTask {
/// Stop recording, optionally annotating the entry with a reason.
async fn stop(&mut self, qmdl_store: &mut RecordingStore, reason: Option<String>) {
self.stop_current_recording().await;
self.stop_current_recording(qmdl_store).await;
if let Some(reason) = reason
&& let Err(e) = qmdl_store.set_current_stop_reason(reason).await
{
@@ -294,7 +294,7 @@ impl DiagTask {
}
}
async fn stop_current_recording(&mut self) {
async fn stop_current_recording(&mut self, qmdl_store: &mut RecordingStore) {
let mut state = DiagState::Stopped;
std::mem::swap(&mut self.state, &mut state);
if let DiagState::Recording {
@@ -304,13 +304,17 @@ impl DiagTask {
} = state
{
match (qmdl_writer.close().await, analysis_writer.close().await) {
(Ok(()), Ok(())) => {}
(Ok(size), Ok(())) => {
if let Err(err) = qmdl_store.update_current_entry_qmdl_size(size).await {
error!("failed to update QMDL entry size while closing it: {err:?}");
}
}
(qmdl_result, analysis_result) => {
if let Err(err) = qmdl_result {
error!("failed to close QmdlWriter: {:?}", err);
error!("failed to close QmdlWriter: {err:?}");
}
if let Err(err) = analysis_result {
error!("failed to close AnalysisWriter: {:?}", err);
error!("failed to close AnalysisWriter: {err:?}");
}
panic!();
}
@@ -387,10 +391,7 @@ impl DiagTask {
"total QMDL bytes written: {}, updating manifest...",
file_size
);
let index = qmdl_store.current_entry.expect(
"DiagDevice had qmdl_writer, but QmdlStore didn't have current entry???",
);
if let Err(e) = qmdl_store.update_entry_qmdl_size(index, file_size).await {
if let Err(e) = qmdl_store.update_current_entry_qmdl_size(file_size).await {
let reason = format!("failed to update manifest (disk full?): {e}");
error!("{reason}");
self.stop(qmdl_store, Some(reason)).await;
@@ -508,7 +509,8 @@ pub fn run_diag_read_thread(
// time to go
Some(DiagDeviceCtrlMessage::Exit) | None => {
info!("Diag reader thread exiting...");
diag_task.stop_current_recording().await;
let mut qmdl_store = qmdl_store_lock.write().await;
diag_task.stop_current_recording(qmdl_store.deref_mut()).await;
return Ok(())
},
Some(DiagDeviceCtrlMessage::DeleteEntry { name, response_tx }) => {
+7 -12
View File
@@ -388,12 +388,14 @@ impl RecordingStore {
}
}
// Sets the given entry's size and updates the last_message_time to now, updating the manifest
pub async fn update_entry_qmdl_size(
// Sets the current entry's size and updates the last_message_time to now, updating the manifest
pub async fn update_current_entry_qmdl_size(
&mut self,
entry_index: usize,
size_bytes: usize,
) -> Result<(), RecordingStoreError> {
let Some(entry_index) = self.current_entry else {
return Err(RecordingStoreError::NoCurrentEntry);
};
self.manifest.entries[entry_index].qmdl_size_bytes = size_bytes;
self.manifest.entries[entry_index].last_message_time =
Some(rayhunter::clock::get_adjusted_now());
@@ -594,10 +596,7 @@ mod tests {
.is_none()
);
store
.update_entry_qmdl_size(entry_index, 1000)
.await
.unwrap();
store.update_current_entry_qmdl_size(1000).await.unwrap();
let (entry_index, entry) = store
.entry_for_name(&store.manifest.entries[entry_index].name)
.unwrap();
@@ -620,11 +619,7 @@ mod tests {
let dir = make_temp_dir();
let mut store = RecordingStore::create(dir.path()).await.unwrap();
let _ = store.new_entry(GpsMode::Disabled).await.unwrap();
let entry_index = store.current_entry.unwrap();
store
.update_entry_qmdl_size(entry_index, 1000)
.await
.unwrap();
store.update_current_entry_qmdl_size(1000).await.unwrap();
let store = RecordingStore::create(dir.path()).await.unwrap();
assert_eq!(store.manifest.entries.len(), 0);
}
+26 -15
View File
@@ -6,20 +6,22 @@ use axum::Json;
use axum::body::Body;
use axum::extract::Path;
use axum::extract::State;
use axum::http::header::{self, CONTENT_LENGTH, CONTENT_TYPE};
use axum::http::header::{self, CONTENT_TYPE};
use axum::http::{HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use chrono::{DateTime, Local};
use futures::TryStreamExt;
use log::{error, warn};
use rayhunter::qmdl::QmdlMessageReader;
use serde::{Deserialize, Serialize};
use std::pin::pin;
use std::sync::Arc;
use tokio::fs::write;
use tokio::io::AsyncReadExt;
use tokio::io::copy;
use tokio::io::duplex;
use tokio::sync::RwLock;
use tokio::sync::mpsc::Sender;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use tokio_util::io::ReaderStream;
use tokio_util::sync::CancellationToken;
@@ -70,7 +72,7 @@ pub async fn get_qmdl(
) -> Result<Response, (StatusCode, String)> {
let qmdl_idx = qmdl_name.trim_end_matches(".qmdl");
let qmdl_store = state.qmdl_store_lock.read().await;
let (entry_index, entry) = qmdl_store.entry_for_name(qmdl_idx).ok_or((
let (entry_index, _) = qmdl_store.entry_for_name(qmdl_idx).ok_or((
StatusCode::NOT_FOUND,
format!("couldn't find qmdl file with name {qmdl_idx}"),
))?;
@@ -91,10 +93,7 @@ pub async fn get_qmdl(
)
})?;
let headers = [
(CONTENT_TYPE, "application/octet-stream"),
(CONTENT_LENGTH, &entry.qmdl_size_bytes.to_string()),
];
let headers = [(CONTENT_TYPE, "application/octet-stream")];
let body = Body::from_stream(qmdl_reader.into_qmdl_stream());
Ok((headers, body).into_response())
}
@@ -341,7 +340,7 @@ pub async fn get_zip(
Path(entry_name): Path<String>,
) -> Result<Response, (StatusCode, String)> {
let qmdl_idx = entry_name.trim_end_matches(".zip").to_owned();
let (entry_index, compressed, qmdl_file_size) = {
let entry_index = {
let qmdl_store = state.qmdl_store_lock.read().await;
let (entry_index, entry) = qmdl_store.entry_for_name(&qmdl_idx).ok_or((
StatusCode::NOT_FOUND,
@@ -355,7 +354,7 @@ pub async fn get_zip(
));
}
(entry_index, entry.compressed, entry.qmdl_size_bytes)
entry_index
};
let qmdl_store_lock = state.qmdl_store_lock.clone();
@@ -384,8 +383,19 @@ pub async fn get_zip(
continue;
};
/*
* `qmdl_compressed` is always false here because even if the
* QMDL was already compressed, we decompress it before zipping.
* This is for two reasons
* 1. If this is the current entry, it's still being written and
* lacks a GZIP footer. If we zipped up this partial .gz
* file, some software might consider it damaged and refuse to
* extract it.
* 2. Zipping an already-GZIP'd file is redundant and
* inconvenient for the user.
*/
let zip_entry = ZipEntryBuilder::new(
file_kind.get_filename(&qmdl_idx, compressed).into(),
file_kind.get_filename(&qmdl_idx, false).into(),
Compression::Stored,
);
// FuturesAsyncWriteCompatExt::compat_write because async-zip's entrystream does
@@ -393,10 +403,11 @@ pub async fn get_zip(
// once https://github.com/Majored/rs-async-zip/pull/160 is released.
let mut entry_writer = zip.write_entry_stream(zip_entry).await?.compat_write();
// Truncating to qmdl_size_bytes is an attempt to ignore partial writes by the diag
// thread.
if file_kind == FileKind::Qmdl {
copy(&mut file.take(qmdl_file_size as u64), &mut entry_writer).await?;
let reader = QmdlMessageReader::new(&mut file).await?;
let stream = reader.into_qmdl_stream();
let mut reader = pin!(stream.into_async_read().compat());
copy(&mut reader, &mut entry_writer).await?;
} else {
copy(&mut file, &mut entry_writer).await?;
}
@@ -575,7 +586,7 @@ mod tests {
let entry_name = entry.name.clone();
store
.update_entry_qmdl_size(current_entry, qmdl_file_size)
.update_current_entry_qmdl_size(qmdl_file_size)
.await
.unwrap();
entry_name
@@ -656,7 +667,7 @@ mod tests {
assert_eq!(
filenames,
vec![
format!("{entry_name}.qmdl.gz"),
format!("{entry_name}.qmdl"),
format!("{entry_name}-gps.ndjson"),
format!("{entry_name}.pcapng"),
]
+1 -1
View File
@@ -300,7 +300,7 @@ mod tests {
analysis_file.flush().await.unwrap();
let entry_index = store.current_entry.unwrap();
let name = store.manifest.entries[entry_index].name.clone();
store.update_entry_qmdl_size(entry_index, 17).await.unwrap();
store.update_current_entry_qmdl_size(17).await.unwrap();
store.close_current_entry().await.unwrap();
(Arc::new(RwLock::new(store)), name)
}
-7
View File
@@ -1,6 +1,5 @@
//! Diag protocol serialization/deserialization
use bytes::Bytes;
use chrono::{DateTime, FixedOffset};
use crc::{Algorithm, Crc};
use deku::prelude::*;
@@ -97,12 +96,6 @@ impl MessagesContainer {
}
}
impl From<MessagesContainer> for Bytes {
fn from(value: MessagesContainer) -> Self {
value.to_bytes().unwrap().into()
}
}
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct HdlcEncapsulatedMessage {
pub len: u32,
+10 -7
View File
@@ -45,14 +45,14 @@ where
pub async fn write_container(&mut self, container: &MessagesContainer) -> std::io::Result<()> {
for msg in &container.messages {
self.writer.write_all(&msg.data).await?;
self.writer.flush().await?;
}
self.writer.flush().await?;
Ok(())
}
pub async fn close(mut self) -> std::io::Result<()> {
pub async fn close(mut self) -> std::io::Result<usize> {
self.writer.shutdown().await?;
Ok(())
self.size().await
}
}
@@ -298,7 +298,7 @@ mod test {
hdlcs.iter().map(|hdlc| hdlc.data.len()).collect(),
)
};
for truncated_hdlc_i in 1..hdlcs.len() - 1 {
for truncated_hdlc_i in 1..hdlcs.len() {
let whole_bytes: usize = message_lengths.iter().take(truncated_hdlc_i).sum();
for truncated_byte in 1..message_lengths[truncated_hdlc_i] {
let mut truncated_bytes = Cursor::new(&bytes[0..whole_bytes + truncated_byte]);
@@ -360,15 +360,18 @@ mod test {
async fn run_compressed_reading_and_writing_tests(do_close: bool) {
let containers = get_test_containers();
let mut buf = Cursor::new(Vec::new());
{
let writer_size = {
let mut writer = QmdlWriter::new(&mut buf);
for container in &containers {
writer.write_container(&container).await.unwrap();
}
if do_close {
writer.close().await.unwrap();
writer.close().await.unwrap()
} else {
writer.size().await.unwrap()
}
}
};
assert_eq!(buf.position() as usize, writer_size);
buf.set_position(0);
let mut reader = QmdlMessageReader::new(buf).await.unwrap();
assert!(reader.is_compressed());