From 0cd70ad73c862ea043f6ff3e918cc81764b894c3 Mon Sep 17 00:00:00 2001 From: Will Greenberg Date: Thu, 7 May 2026 16:46:59 -0700 Subject: [PATCH] Refactor and simplify QmdlReader In the past, QmdlReader was written to share a trait with DiagDevice, so it had to pretend to be reading MessagesContainers. This needlessly complicated both its code and that of its consumers. Instead, QmdlReader now returns a stream of diag Messages. QmdlReader also automatically detects if it's reading a compressed QMDL stream or not. Additionally, QmdlReader can no longer be bounded by a file size limit, and instead relies on HDLC message framing to detect file truncation. This works for both compressed and uncompressed QMDL files. --- check/src/main.rs | 34 ++--- daemon/src/analysis.rs | 33 +++-- daemon/src/diag.rs | 2 +- daemon/src/pcap.rs | 29 ++-- daemon/src/qmdl_store.rs | 8 +- daemon/src/server.rs | 112 +++++++++++--- lib/src/analysis/analyzer.rs | 92 ++++++------ lib/src/diag.rs | 48 +++--- lib/src/qmdl.rs | 280 ++++++++++++++++++++--------------- 9 files changed, 368 insertions(+), 270 deletions(-) diff --git a/check/src/main.rs b/check/src/main.rs index cf7c02a..34501b0 100644 --- a/check/src/main.rs +++ b/check/src/main.rs @@ -1,15 +1,13 @@ use clap::Parser; -use futures::TryStreamExt; use log::{debug, error, info, warn}; use pcap_file_tokio::pcapng::{Block, PcapNgReader}; use rayhunter::{ analysis::analyzer::{AnalysisRow, AnalyzerConfig, EventType, Harness}, - diag::DataType, gsmtap_parser, pcap::GsmtapPcapWriter, qmdl::QmdlReader, }; -use std::{collections::HashMap, future, path::PathBuf, pin::pin}; +use std::{collections::HashMap, path::PathBuf}; use tokio::fs::File; use walkdir::WalkDir; @@ -113,22 +111,14 @@ async fn analyze_pcap(pcap_path: &str, show_skipped: bool) { async fn analyze_qmdl(qmdl_path: &str, show_skipped: bool) { let mut harness = Harness::new_with_config(&AnalyzerConfig::default()); let qmdl_file = &mut File::open(&qmdl_path).await.expect("failed to 
open file"); - let compressed = qmdl_path.ends_with(".gz"); - let qmdl_reader = QmdlReader::new(qmdl_file, compressed, None); - let mut qmdl_stream = pin!( - qmdl_reader - .as_stream() - .try_filter(|container| future::ready(container.data_type == DataType::UserSpace)) - ); + let mut qmdl_reader = QmdlReader::new(qmdl_file).await.expect("failed to open QmdlReader"); let mut report = Report::new(qmdl_path); - while let Some(container) = qmdl_stream - .try_next() + while let Some(maybe_message) = qmdl_reader + .get_next_message() .await - .expect("failed getting QMDL container") + .expect("failed to get message") { - for row in harness.analyze_qmdl_messages(container) { - report.process_row(row); - } + report.process_row(harness.analyze_qmdl_message(maybe_message)); } report.print_summary(show_skipped); } @@ -137,9 +127,7 @@ async fn pcapify(qmdl_path: &PathBuf) { let qmdl_file = &mut File::open(&qmdl_path) .await .expect("failed to open qmdl file"); - let compressed = qmdl_path.ends_with(".gz"); - let qmdl_file_size = qmdl_file.metadata().await.unwrap().len(); - let mut qmdl_reader = QmdlReader::new(qmdl_file, compressed, Some(qmdl_file_size as usize)); + let mut qmdl_reader = QmdlReader::new(qmdl_file).await.expect("failed to open QmdlReader"); let mut pcap_path = qmdl_path.clone(); pcap_path.set_extension("pcapng"); let pcap_file = &mut File::create(&pcap_path) @@ -147,12 +135,12 @@ async fn pcapify(qmdl_path: &PathBuf) { .expect("failed to open pcap file"); let mut pcap_writer = GsmtapPcapWriter::new(pcap_file).await.unwrap(); pcap_writer.write_iface_header().await.unwrap(); - while let Some(container) = qmdl_reader - .get_next_messages_container() + while let Some(maybe_message) = qmdl_reader + .get_next_message() .await - .expect("failed to get container") + .expect("failed to get message") { - for msg in container.into_messages().into_iter().flatten() { + if let Ok(msg) = maybe_message { if let Ok(Some((timestamp, parsed))) = gsmtap_parser::parse(msg) { 
pcap_writer .write_gsmtap_message(parsed, timestamp) diff --git a/daemon/src/analysis.rs b/daemon/src/analysis.rs index e7c9a78..ddae901 100644 --- a/daemon/src/analysis.rs +++ b/daemon/src/analysis.rs @@ -1,15 +1,14 @@ use std::sync::Arc; -use std::{cmp, future, pin}; +use std::cmp; use axum::Json; use axum::{ extract::{Path, State}, http::StatusCode, }; -use futures::TryStreamExt; use log::{error, info}; use rayhunter::analysis::analyzer::{AnalyzerConfig, EventType, Harness}; -use rayhunter::diag::{DataType, MessagesContainer}; +use rayhunter::diag::{DiagParsingError, Message, MessagesContainer}; use serde::Serialize; use tokio::fs::File; use tokio::io::{AsyncWriteExt, BufWriter}; @@ -46,7 +45,7 @@ impl AnalysisWriter { // Runs the analysis harness on the given container, serializing the results // to the analysis file, returning the whether any warnings were detected - pub async fn analyze( + pub async fn analyze_container( &mut self, container: MessagesContainer, ) -> Result { @@ -61,6 +60,17 @@ impl AnalysisWriter { Ok(max_type) } + pub async fn analyze_message( + &mut self, + maybe_qmdl_msg: Result, + ) -> Result { + let row = self.harness.analyze_qmdl_message(maybe_qmdl_msg); + if !row.is_empty() { + self.write(&row).await?; + } + Ok(row.get_max_event_type()) + } + async fn write(&mut self, value: &T) -> Result<(), std::io::Error> { let mut value_str = serde_json::to_string(value).unwrap(); value_str.push('\n'); @@ -134,7 +144,7 @@ async fn perform_analysis( analyzer_config: &AnalyzerConfig, ) -> Result<(), String> { info!("Opening QMDL and analysis file for {name}..."); - let (analysis_file, qmdl_reader) = { + let (analysis_file, mut qmdl_reader) = { let mut qmdl_store = qmdl_store_lock.write().await; let (entry_index, _) = qmdl_store .entry_for_name(name) @@ -154,20 +164,15 @@ async fn perform_analysis( let mut analysis_writer = AnalysisWriter::new(analysis_file, analyzer_config) .await .map_err(|e| format!("{e:?}"))?; - let mut qmdl_stream = pin::pin!( - 
qmdl_reader - .as_stream() - .try_filter(|container| future::ready(container.data_type == DataType::UserSpace)) - ); info!("Starting analysis for {name}..."); - while let Some(container) = qmdl_stream - .try_next() + while let Some(maybe_message) = qmdl_reader + .get_next_message() .await - .expect("failed getting QMDL container") + .expect("failed to get message") { let _ = analysis_writer - .analyze(container) + .analyze_message(maybe_message) .await .map_err(|e| format!("{e:?}"))?; } diff --git a/daemon/src/diag.rs b/daemon/src/diag.rs index e5da9db..a0dfe3f 100644 --- a/daemon/src/diag.rs +++ b/daemon/src/diag.rs @@ -342,7 +342,7 @@ impl DiagTask { debug!("done!"); let container_bytes: usize = container.messages.iter().map(|m| m.data.len()).sum(); self.bytes_since_space_check += container_bytes; - let max_type = match analysis_writer.analyze(container).await { + let max_type = match analysis_writer.analyze_container(container).await { Ok(t) => t, Err(e) => { warn!("failed to analyze container: {e}"); diff --git a/daemon/src/pcap.rs b/daemon/src/pcap.rs index 93b29cc..88e14d5 100644 --- a/daemon/src/pcap.rs +++ b/daemon/src/pcap.rs @@ -7,12 +7,11 @@ use axum::http::StatusCode; use axum::http::header::CONTENT_TYPE; use axum::response::{IntoResponse, Response}; use log::error; -use rayhunter::diag::DataType; use rayhunter::gsmtap_parser; use rayhunter::pcap::GsmtapPcapWriter; use rayhunter::qmdl::QmdlReader; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite, duplex}; +use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, duplex}; use tokio_util::io::ReaderStream; // Streams a pcap file chunk-by-chunk to the client by reading the QMDL data @@ -71,28 +70,22 @@ pub async fn get_pcap( pub async fn generate_pcap_data(writer: W, mut reader: QmdlReader) -> Result<(), Error> where W: AsyncWrite + Unpin + Send, - R: AsyncRead + Unpin, + R: AsyncRead + AsyncSeek + Unpin, { let mut pcap_writer = GsmtapPcapWriter::new(writer).await?; 
pcap_writer.write_iface_header().await?; - while let Some(container) = reader.get_next_messages_container().await? { - if container.data_type != DataType::UserSpace { - continue; - } - - for maybe_msg in container.into_messages() { - match maybe_msg { - Ok(msg) => { - let maybe_gsmtap_msg = gsmtap_parser::parse(msg)?; - if let Some((timestamp, gsmtap_msg)) = maybe_gsmtap_msg { - pcap_writer - .write_gsmtap_message(gsmtap_msg, timestamp) - .await?; - } + while let Some(maybe_msg) = reader.get_next_message().await? { + match maybe_msg { + Ok(msg) => { + let maybe_gsmtap_msg = gsmtap_parser::parse(msg)?; + if let Some((timestamp, gsmtap_msg)) = maybe_gsmtap_msg { + pcap_writer + .write_gsmtap_message(gsmtap_msg, timestamp) + .await?; } - Err(e) => error!("error parsing message: {e:?}"), } + Err(e) => error!("error parsing message: {e:?}"), } } diff --git a/daemon/src/qmdl_store.rs b/daemon/src/qmdl_store.rs index ac52182..0ade581 100644 --- a/daemon/src/qmdl_store.rs +++ b/daemon/src/qmdl_store.rs @@ -288,11 +288,9 @@ impl RecordingStore { let file = File::open(entry.get_qmdl_filepath(&self.path)) .await .map_err(RecordingStoreError::ReadFileError)?; - Ok(QmdlReader::new( - file, - entry.compressed, - Some(entry.uncompressed_qmdl_size_bytes), - )) + QmdlReader::new(file) + .await + .map_err(RecordingStoreError::ReadFileError) } // Returns the corresponding QMDL file for a given entry diff --git a/daemon/src/server.rs b/daemon/src/server.rs index 17dbb36..1566336 100644 --- a/daemon/src/server.rs +++ b/daemon/src/server.rs @@ -82,7 +82,7 @@ pub async fn get_qmdl( &entry.uncompressed_qmdl_size_bytes.to_string(), ), ]; - let body = Body::from_stream(qmdl_reader.as_stream()); + let body = Body::from_stream(qmdl_reader.as_qmdl_stream()); Ok((headers, body).into_response()) } @@ -310,7 +310,7 @@ pub async fn get_zip( Path(entry_name): Path, ) -> Result { let qmdl_idx = entry_name.trim_end_matches(".zip").to_owned(); - let (entry_index, compressed) = { + let (entry_index, _) 
= { let qmdl_store = state.qmdl_store_lock.read().await; let (entry_index, entry) = qmdl_store.entry_for_name(&qmdl_idx).ok_or(( StatusCode::NOT_FOUND, @@ -338,7 +338,7 @@ pub async fn get_zip( // Add QMDL file { let entry = ZipEntryBuilder::new( - format!("{qmdl_idx}.qmdl").into(), + format!("{qmdl_idx}.qmdl.gz").into(), Compression::Stored, ); // FuturesAsyncWriteCompatExt::compat_write because async-zip's @@ -425,9 +425,13 @@ pub async fn debug_set_display_state( #[cfg(test)] mod tests { + use std::io::Cursor; + use super::*; use async_zip::base::read::mem::ZipFileReader; use axum::extract::{Path, State}; + use futures::AsyncReadExt; + use rayhunter::{diag::{DataType, HdlcEncapsulatedMessage, Message, MessagesContainer}, qmdl::{QmdlReader, QmdlWriter}}; use tempfile::TempDir; async fn create_test_qmdl_store() -> (TempDir, Arc>) { @@ -441,24 +445,23 @@ mod tests { async fn create_test_entry_with_data( store_lock: &Arc>, - test_data: &[u8], + test_data: &MessagesContainer, ) -> String { let entry_name = { let mut store = store_lock.write().await; - let (mut qmdl_file, _analysis_file) = store.new_entry().await.unwrap(); + let (qmdl_gz_file, _analysis_file) = store.new_entry().await.unwrap(); - if !test_data.is_empty() { - use tokio::io::AsyncWriteExt; - qmdl_file.write_all(test_data).await.unwrap(); - qmdl_file.flush().await.unwrap(); - } + let mut writer = QmdlWriter::new(qmdl_gz_file); + writer.write_container(test_data).await.unwrap(); + let test_data_len = writer.total_uncompressed_bytes; + writer.close().await.unwrap(); let current_entry = store.current_entry.unwrap(); let entry = &store.manifest.entries[current_entry]; let entry_name = entry.name.clone(); store - .update_entry_qmdl_size(current_entry, test_data.len()) + .update_entry_qmdl_size(current_entry, test_data_len) .await .unwrap(); entry_name @@ -492,17 +495,69 @@ mod tests { }) } + // valid HDLC encapsulated diag message generated from + // rayhunter::diag::test::get_test_message + fn 
create_test_container() -> MessagesContainer { + MessagesContainer { + data_type: DataType::UserSpace, + num_messages: 1, + messages: vec![ + HdlcEncapsulatedMessage { + len: 39, + data: vec![ + 16, + 0, + 32, + 0, + 32, + 0, + 192, + 176, + 26, + 165, + 245, + 135, + 118, + 35, + 2, + 1, + 20, + 14, + 48, + 0, + 160, + 0, + 2, + 8, + 0, + 0, + 217, + 15, + 5, + 0, + 0, + 0, + 0, + 1, + 0, + 10, + 13, + 196, + 126, + ], + }, + ], + } + } + #[tokio::test] async fn test_get_zip_success() { let (_temp_dir, store_lock) = create_test_qmdl_store().await; - let test_qmdl_data = vec![0x7E, 0x00, 0x00, 0x00, 0x10, 0x00, 0x7E]; + let test_qmdl_data = create_test_container(); let entry_name = create_test_entry_with_data(&store_lock, &test_qmdl_data).await; let state = create_test_server_state(store_lock); - let result = get_zip(State(state), Path(entry_name.clone())).await; - - assert!(result.is_ok()); - let response = result.unwrap(); + let response = get_zip(State(state), Path(entry_name.clone())).await.unwrap(); let headers = response.headers(); assert_eq!(headers.get("content-type").unwrap(), "application/zip"); @@ -511,14 +566,11 @@ mod tests { let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap(); let zip_reader = ZipFileReader::new(body_bytes.to_vec()).await.unwrap(); - - let filenames = zip_reader - .file() - .entries() + let zip_reader_file = zip_reader.file(); + let filenames: Vec = zip_reader_file.entries() .iter() - .map(|entry| entry.filename().as_str().unwrap().to_owned()) - .collect::>(); - + .map(|entry| entry.filename().as_str().unwrap().to_string()) + .collect(); assert_eq!( filenames, vec![ @@ -526,5 +578,19 @@ mod tests { format!("{entry_name}.pcapng"), ] ); + + let mut qmdl_body = Vec::with_capacity(128); + zip_reader.reader_without_entry(0) + .await + .unwrap() + .read_to_end(&mut qmdl_body) + .await + .unwrap(); + let mut qmdl_reader = QmdlReader::new(Cursor::new(qmdl_body)).await.unwrap(); + let expected_message = 
Message::from_hdlc(&test_qmdl_data.messages[0].data).unwrap(); + assert_eq!( + qmdl_reader.get_next_message().await.unwrap(), + Some(Ok(expected_message)), + ); } } diff --git a/lib/src/analysis/analyzer.rs b/lib/src/analysis/analyzer.rs index d903775..3f8db27 100644 --- a/lib/src/analysis/analyzer.rs +++ b/lib/src/analysis/analyzer.rs @@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize}; use std::borrow::Cow; use crate::analysis::diagnostic::DiagnosticAnalyzer; +use crate::diag::{DiagParsingError, Message}; use crate::gsmtap::{GsmtapHeader, GsmtapMessage, GsmtapType}; use crate::util::RuntimeMetadata; use crate::{diag::MessagesContainer, gsmtap_parser}; @@ -231,6 +232,14 @@ pub struct AnalysisRow { } impl AnalysisRow { + pub fn new() -> Self { + Self { + packet_timestamp: None, + skipped_message_reason: None, + events: vec![], + } + } + pub fn is_empty(&self) -> bool { self.skipped_message_reason.is_none() && !self.contains_warnings() } @@ -412,50 +421,47 @@ impl Harness { row } + pub fn analyze_qmdl_message(&mut self, maybe_qmdl_message: Result) -> AnalysisRow { + let mut row = AnalysisRow::new(); + self.packet_num += 1; + + let qmdl_message = match maybe_qmdl_message { + Ok(msg) => msg, + Err(err) => { + row.skipped_message_reason = Some(format!("{err:?}")); + return row; + } + }; + let gsmtap_message = match gsmtap_parser::parse(qmdl_message) { + Ok(msg) => msg, + Err(err) => { + row.skipped_message_reason = Some(format!("{err:?}")); + return row; + } + }; + + let Some((timestamp, gsmtap_msg)) = gsmtap_message else { + return row; + }; + row.packet_timestamp = Some(timestamp.to_datetime()); + + let element = match InformationElement::try_from(&gsmtap_msg) { + Ok(element) => element, + Err(err) => { + row.skipped_message_reason = Some(format!("{err:?}")); + return row; + } + }; + + row.events = self.analyze_information_element(&element); + row + } + pub fn analyze_qmdl_messages(&mut self, container: MessagesContainer) -> Vec { - let mut rows = Vec::new(); - for 
maybe_qmdl_message in container.into_messages() { - self.packet_num += 1; - - rows.push(AnalysisRow { - packet_timestamp: None, - skipped_message_reason: None, - events: Vec::new(), - }); - // unwrap is safe here since we just pushed a value - let row = rows.last_mut().unwrap(); - let qmdl_message = match maybe_qmdl_message { - Ok(msg) => msg, - Err(err) => { - row.skipped_message_reason = Some(format!("{err:?}")); - continue; - } - }; - - let gsmtap_message = match gsmtap_parser::parse(qmdl_message) { - Ok(msg) => msg, - Err(err) => { - row.skipped_message_reason = Some(format!("{err:?}")); - continue; - } - }; - - let Some((timestamp, gsmtap_msg)) = gsmtap_message else { - continue; - }; - row.packet_timestamp = Some(timestamp.to_datetime()); - - let element = match InformationElement::try_from(&gsmtap_msg) { - Ok(element) => element, - Err(err) => { - row.skipped_message_reason = Some(format!("{err:?}")); - continue; - } - }; - - row.events = self.analyze_information_element(&element); - } - rows + container.into_messages() + .drain(..) 
+ .map(|maybe_message| self.analyze_qmdl_message(maybe_message)) + .collect() } fn analyze_information_element(&mut self, ie: &InformationElement) -> Vec> { diff --git a/lib/src/diag.rs b/lib/src/diag.rs index 2cd2e0b..723aeb0 100644 --- a/lib/src/diag.rs +++ b/lib/src/diag.rs @@ -90,24 +90,7 @@ impl MessagesContainer { let mut result = Vec::new(); for msg in self.messages { for sub_msg in msg.data.split_inclusive(|&b| b == MESSAGE_TERMINATOR) { - match hdlc_decapsulate(sub_msg, &CRC_CCITT) { - Ok(data) => match Message::from_bytes((&data, 0)) { - Ok(((leftover_bytes, _), res)) => { - if !leftover_bytes.is_empty() { - warn!( - "warning: {} leftover bytes when parsing Message", - leftover_bytes.len() - ); - } - result.push(Ok(res)); - } - Err(e) => result.push(Err(DiagParsingError::MessageParsingError(e, data))), - }, - Err(err) => result.push(Err(DiagParsingError::HdlcDecapsulationError( - err, - sub_msg.to_vec(), - ))), - } + result.push(Message::from_hdlc(sub_msg)); } } result @@ -159,6 +142,29 @@ pub enum Message { }, } +impl Message { + pub fn from_hdlc(data: &[u8]) -> Result { + match hdlc_decapsulate(data, &CRC_CCITT) { + Ok(data) => match Message::from_bytes((&data, 0)) { + Ok(((leftover_bytes, _), res)) => { + if !leftover_bytes.is_empty() { + warn!( + "warning: {} leftover bytes when parsing Message", + leftover_bytes.len() + ); + } + Ok(res) + } + Err(e) => Err(DiagParsingError::MessageParsingError(e, data)), + }, + Err(err) => Err(DiagParsingError::HdlcDecapsulationError( + err, + data.to_vec(), + )), + } + } +} + #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[deku(ctx = "log_type: u16, hdr_len: u16", id = "log_type")] pub enum LogBody { @@ -418,7 +424,7 @@ pub fn build_log_mask_request( } #[cfg(test)] -mod test { +pub(crate) mod test { use super::*; // Just about all of these test cases from manually parsing diag packets w/ QCSuper @@ -532,7 +538,7 @@ mod test { // this log is based on one captured on a real device -- if it fails to // 
serialize or deserialize, that's probably a problem with this mock, not // the DiagReader implementation - fn get_test_message(payload: &[u8]) -> (HdlcEncapsulatedMessage, Message) { + pub fn get_test_message(payload: &[u8]) -> (HdlcEncapsulatedMessage, Message) { let length_with_payload = 31 + payload.len() as u16; let message = Message::Log { pending_msgs: 0, @@ -566,6 +572,8 @@ mod test { len: encapsulated_data.len() as u32, data: encapsulated_data, }; + // sanity check + assert_eq!(&Message::from_hdlc(&encapsulated.data).unwrap(), &message); (encapsulated, message) } diff --git a/lib/src/qmdl.rs b/lib/src/qmdl.rs index 1115ebf..86bd040 100644 --- a/lib/src/qmdl.rs +++ b/lib/src/qmdl.rs @@ -3,17 +3,18 @@ //! QmdlReader and QmdlWriter can read and write MessagesContainers to and from //! QMDL files. -use std::io::{Cursor, ErrorKind}; +use std::io::ErrorKind; use std::pin::Pin; use std::task::Poll; -use crate::diag::{DataType, HdlcEncapsulatedMessage, MESSAGE_TERMINATOR, MessagesContainer}; +use crate::diag::{DiagParsingError, MESSAGE_TERMINATOR, Message, MessagesContainer}; use async_compression::tokio::bufread::GzipDecoder; use async_compression::tokio::write::GzipEncoder; use futures::TryStream; -use log::error; -use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufReader}; + +const GZIP_MAGIC_NUMBER: u16 = 0x1f8b; pub struct QmdlWriter where @@ -64,15 +65,13 @@ enum QmdlReaderSource { #[derive(Debug)] struct QmdlAsyncReader { source: QmdlReaderSource, - uncompressed_bytes_read: usize, - max_uncompressed_bytes: Option, } impl QmdlAsyncReader where T: AsyncRead, { - pub fn new(reader: T, compressed: bool, max_uncompressed_bytes: Option) -> Self { + pub fn new(reader: T, compressed: bool) -> Self { let source = if compressed { QmdlReaderSource::Compressed { reader: GzipDecoder::new(BufReader::new(reader)), @@ -83,8 +82,6 
@@ where }; Self { source, - uncompressed_bytes_read: 0, - max_uncompressed_bytes, } } } @@ -98,23 +95,7 @@ where cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { - // if we've already read beyond the byte limit, return without reading - // into the buffer, essentially signalling EOF - if let Some(max_bytes) = self.max_uncompressed_bytes - && self.uncompressed_bytes_read >= max_bytes - { - if self.uncompressed_bytes_read > max_bytes { - error!( - "warning: {} bytes read, but max_bytes was {}", - self.uncompressed_bytes_read, max_bytes - ); - } - return Poll::Ready(Ok(())); - } - - let before = buf.filled().len(); - let this = self.get_mut(); - let res = match &mut this.source { + let res = match &mut self.get_mut().source { QmdlReaderSource::Compressed { reader, eof } => { // if we already determined we've reached the Gzip EOF, don't read more if *eof { @@ -134,17 +115,6 @@ where } QmdlReaderSource::Uncompressed { reader } => Pin::new(reader).poll_read(cx, buf), }; - - // if we read more bytes than is allowed, cap the buffer by - // our max bytes - let after = buf.filled().len(); - let read = after - before; - if let Some(max_bytes) = this.max_uncompressed_bytes - && this.uncompressed_bytes_read + read > max_bytes - { - let overread = this.uncompressed_bytes_read + read - max_bytes; - buf.set_filled(after - overread); - } res } } @@ -157,34 +127,59 @@ where buf_reader: BufReader>, } +async fn is_gzip_stream(mut reader: T) -> std::io::Result +where + T: AsyncRead + AsyncSeek + Unpin +{ + let magic_number = reader.read_u16().await?; + reader.rewind().await?; + // this is safe because 0x1f8b.... 
doesn't overlap with any known + // diag::DataType values + Ok(magic_number == GZIP_MAGIC_NUMBER) +} + impl QmdlReader where - T: AsyncRead + Unpin, + T: AsyncRead + AsyncSeek + Unpin, { - pub fn new(reader: T, compressed: bool, max_uncompressed_bytes: Option) -> Self { - QmdlReader { + pub async fn new(mut reader: T) -> std::io::Result { + let compressed = is_gzip_stream(&mut reader) + .await + .unwrap_or(false); + Ok(QmdlReader { buf_reader: BufReader::new(QmdlAsyncReader::new( reader, compressed, - max_uncompressed_bytes, )), - } + }) } - pub fn as_stream(self) -> impl TryStream { + pub fn as_qmdl_stream(self) -> impl TryStream, Error = std::io::Error> { futures::stream::try_unfold(self, |mut reader| async { - let maybe_container = reader.get_next_messages_container().await?; - match maybe_container { - Some(container) => Ok(Some((container, reader))), + let mut buf = vec![]; + match reader .buf_reader + .read_until(MESSAGE_TERMINATOR, &mut buf) + .await { + Err(err) => Err(err), + Ok(0) => Ok(None), + Ok(_) => Ok(Some((buf, reader))), + } + }) + } + + pub fn as_message_stream(self) -> impl TryStream, Error = std::io::Error> { + futures::stream::try_unfold(self, |mut reader| async { + match reader.get_next_message().await? { + Some(res) => Ok(Some((res, reader))), None => Ok(None), } }) } - pub async fn get_next_messages_container( + pub async fn get_next_message( &mut self, - ) -> Result, std::io::Error> { - let mut buf = Vec::new(); + ) -> Result>, std::io::Error> { + let mut buf = vec![]; if self .buf_reader .read_until(MESSAGE_TERMINATOR, &mut buf) @@ -194,19 +189,7 @@ where return Ok(None); } - // Since QMDL is just a flat list of messages, we can't actually - // reproduce the container structure they came from in the original - // read. So we'll just pretend that all containers had exactly one - // message. As far as I know, the number of messages per container - // doesn't actually affect anything, so this should be fine. 
- Ok(Some(MessagesContainer { - data_type: DataType::UserSpace, - num_messages: 1, - messages: vec![HdlcEncapsulatedMessage { - len: buf.len() as u32, - data: buf, - }], - })) + Ok(Some(Message::from_hdlc(&buf))) } } @@ -227,92 +210,148 @@ where mod test { use std::io::Cursor; - use crate::diag::CRC_CCITT; - use crate::hdlc::hdlc_encapsulate; + use crate::diag::{DataType, HdlcEncapsulatedMessage, test::get_test_message}; use super::*; - fn get_test_messages() -> Vec { - let messages: Vec = (10..20) - .map(|i| { - let data = hdlc_encapsulate(&vec![i as u8; i], &CRC_CCITT); - HdlcEncapsulatedMessage { - len: data.len() as u32, - data, - } - }) - .collect(); - messages + fn get_test_messages() -> (Vec, Vec) { + let mut hdlcs = Vec::new(); + let mut messages = Vec::new(); + for i in 10..20 { + let (hdlc, msg) = get_test_message(&[i]); + hdlcs.push(hdlc); + messages.push(msg); + } + (hdlcs, messages) } // returns a byte array consisting of concatenated HDLC encapsulated // test messages fn get_test_message_bytes() -> Vec { - get_test_messages() + let (hdlcs, _) = get_test_messages(); + hdlcs .iter() .flat_map(|msg| msg.data.clone()) .collect() } fn get_test_containers() -> Vec { - let messages = get_test_messages(); - let (messages1, messages2) = messages.split_at(5); + let (hdlcs, _) = get_test_messages(); + let (hdlcs1, hdlcs2) = hdlcs.split_at(5); vec![ MessagesContainer { data_type: DataType::UserSpace, - num_messages: messages1.len() as u32, - messages: messages1.to_vec(), + num_messages: hdlcs1.len() as u32, + messages: hdlcs1.to_vec(), }, MessagesContainer { data_type: DataType::UserSpace, - num_messages: messages2.len() as u32, - messages: messages2.to_vec(), + num_messages: hdlcs2.len() as u32, + messages: hdlcs2.to_vec(), }, ] } #[tokio::test] - async fn test_unbounded_qmdl_reader() { + async fn test_qmdl_reader() { let mut buf = Cursor::new(get_test_message_bytes()); - let mut reader = QmdlReader::new(&mut buf, false, None); - let expected_messages = 
get_test_messages(); - for message in expected_messages { - let expected_container = MessagesContainer { - data_type: DataType::UserSpace, - num_messages: 1, - messages: vec![message], - }; + let mut reader = QmdlReader::new(&mut buf).await.unwrap(); + let (_, expected_messages) = get_test_messages(); + for msg in expected_messages { assert_eq!( - expected_container, - reader.get_next_messages_container().await.unwrap().unwrap() + Ok(msg), + reader.get_next_message().await.unwrap().unwrap() ); } } #[tokio::test] - async fn test_bounded_qmdl_reader() { - let mut buf = Cursor::new(get_test_message_bytes()); + async fn test_truncation() { + run_truncation_tests(false).await; + } - // bound the reader to the first two messages - let mut expected_messages = get_test_messages(); - let limit = expected_messages[0].len + expected_messages[1].len; + #[tokio::test] + async fn test_compressed_truncation() { + run_truncation_tests(true).await; + } - let mut reader = QmdlReader::new(&mut buf, false, Some(limit as usize)); - for message in expected_messages.drain(0..2) { - let expected_container = MessagesContainer { - data_type: DataType::UserSpace, - num_messages: 1, - messages: vec![message], - }; - assert_eq!( - expected_container, - reader.get_next_messages_container().await.unwrap().unwrap() - ); + async fn run_truncation_tests(compressed: bool) { + let (hdlcs, expected_messages) = get_test_messages(); + let (bytes, message_lengths): (Vec, Vec) = if compressed { + let mut buf = Vec::new(); + let mut compressed_lengths = Vec::new(); + let mut writer = GzipEncoder::new(&mut buf); + for hdlc in &hdlcs { + let before = writer.get_ref().len(); + writer.write_all(&hdlc.data).await.unwrap(); + writer.flush().await.unwrap(); + let after = writer.get_ref().len(); + compressed_lengths.push(after - before); + } + (buf, compressed_lengths) + } else { + ( + get_test_message_bytes(), + hdlcs.iter() + .map(|hdlc| hdlc.data.len()) + .collect() + ) + }; + for truncated_hdlc_i in 
1..hdlcs.len() - 1 { + let whole_bytes: usize = message_lengths.iter().take(truncated_hdlc_i).sum(); + for truncated_byte in 1..message_lengths[truncated_hdlc_i] { + let mut truncated_bytes = Cursor::new(&bytes[0..whole_bytes + truncated_byte]); + let mut reader = QmdlReader::new(&mut truncated_bytes).await.unwrap(); + for msg in expected_messages.iter().take(truncated_hdlc_i) { + assert_eq!( + Ok(msg), + reader.get_next_message().await.unwrap().unwrap().as_ref() + ); + } + if compressed { + // for a compressed reader, we have a couple possible + // outcomes, depending on how far along the Gzip DEFLATE + // block was before it was truncated: + match reader.get_next_message().await.unwrap() { + // if the block was truncated early enough, the + // GzipDecoder will detect an unexpected EOF, and our + // QmdlReader will indicate the stream of messages is + // done + None => {}, + // if it's further along, the expanded result will be an + // invalid HDLC block. if that's the case, make sure the + // QmdlReader indicates the stream of messages is over + // with afterwards + Some(Err(DiagParsingError::HdlcDecapsulationError(_, _))) => { + assert!(matches!(reader.get_next_message().await, Ok(None))); + }, + // if it's further along still, we may get a complete + // Message, so make sure it matches the next expected + // one. 
then, make sure we've hit the end of the message + // stream + Some(Ok(msg)) => { + assert_eq!(&msg, &expected_messages[truncated_hdlc_i]); + assert!(matches!(reader.get_next_message().await, Ok(None))); + }, + // we should never be able to decapsulate the HDLC into + // an invalid Diag message + Some(Err(DiagParsingError::MessageParsingError(_, _))) + => { + panic!("unexpected MessageParsingError"); + } + } + } else { + // a truncated uncompressed reader should always end on an + // HdlcDecapsulationError, and then return Ok(None) to + // indicate the message stream is over + assert!(matches!( + reader.get_next_message().await, + Ok(Some(Err(DiagParsingError::HdlcDecapsulationError(_, _)))) + )); + assert!(matches!(reader.get_next_message().await, Ok(None))); + } + } } - assert!(matches!( - reader.get_next_messages_container().await, - Ok(None) - )); } /// Writes the test containers to a QmdlWriter, optionally finishing the @@ -330,21 +369,16 @@ mod test { writer.close().await.unwrap(); } } - let mut reader = QmdlReader::new(Cursor::new(buf), true, None); - let expected_messages = get_test_messages(); + let mut reader = QmdlReader::new(Cursor::new(buf)).await.unwrap(); + let (_, expected_messages) = get_test_messages(); for message in expected_messages { - let expected_container = MessagesContainer { - data_type: DataType::UserSpace, - num_messages: 1, - messages: vec![message], - }; assert_eq!( - expected_container, - reader.get_next_messages_container().await.unwrap().unwrap() + Ok(message), + reader.get_next_message().await.unwrap().unwrap() ); } assert!(matches!( - reader.get_next_messages_container().await, + reader.get_next_message().await, Ok(None) )); }