mirror of
https://github.com/EFForg/rayhunter.git
synced 2026-04-28 00:20:00 -07:00
Transition to async I/O for most things
Mixing async and sync I/O leads to a multitude of complications, and generally speaking it's much more convenient to stick to one paradigm or the other. Since axum (and many other HTTP servers) use async, and since async is a convenient model for performing operations like "handle an MPSC message or file read, whichever happens first", let's commit to an async interface.
This commit is contained in:
@@ -3,19 +3,18 @@
|
||||
//! QmdlReader and QmdlWriter can read and write MessagesContainers to and from
|
||||
//! QMDL files.
|
||||
|
||||
use crate::diag_reader::DiagReader;
|
||||
use crate::diag::{MessagesContainer, MESSAGE_TERMINATOR, HdlcEncapsulatedMessage, DataType};
|
||||
|
||||
use std::io::{Write, BufReader, BufRead, Read};
|
||||
use thiserror::Error;
|
||||
use futures::TryStream;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, AsyncBufReadExt};
|
||||
use log::error;
|
||||
|
||||
pub struct QmdlWriter<T> where T: Write {
|
||||
pub struct QmdlWriter<T> where T: AsyncWrite + Unpin {
|
||||
writer: T,
|
||||
pub total_written: usize,
|
||||
}
|
||||
|
||||
impl<T> QmdlWriter<T> where T: Write {
|
||||
impl<T> QmdlWriter<T> where T: AsyncWrite + Unpin {
|
||||
pub fn new(writer: T) -> Self {
|
||||
QmdlWriter::new_with_existing_size(writer, 0)
|
||||
}
|
||||
@@ -27,30 +26,22 @@ impl<T> QmdlWriter<T> where T: Write {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_container(&mut self, container: &MessagesContainer) -> std::io::Result<()> {
|
||||
pub async fn write_container(&mut self, container: &MessagesContainer) -> std::io::Result<()> {
|
||||
for msg in &container.messages {
|
||||
self.writer.write_all(&msg.data)?;
|
||||
self.writer.write_all(&msg.data).await?;
|
||||
self.total_written += msg.data.len();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum QmdlReaderError {
|
||||
#[error("IO error: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
#[error("Reached max_bytes count {0}")]
|
||||
MaxBytesReached(usize),
|
||||
}
|
||||
|
||||
pub struct QmdlReader<T> where T: Read {
|
||||
pub struct QmdlReader<T> where T: AsyncRead {
|
||||
reader: BufReader<T>,
|
||||
bytes_read: usize,
|
||||
max_bytes: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T> QmdlReader<T> where T: Read {
|
||||
impl<T> QmdlReader<T> where T: AsyncRead + Unpin {
|
||||
pub fn new(reader: T, max_bytes: Option<usize>) -> Self {
|
||||
QmdlReader {
|
||||
reader: BufReader::new(reader),
|
||||
@@ -58,23 +49,29 @@ impl<T> QmdlReader<T> where T: Read {
|
||||
max_bytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DiagReader for QmdlReader<T> where T: Read {
|
||||
type Err = QmdlReaderError;
|
||||
pub fn as_stream(&mut self) -> impl TryStream<Ok = MessagesContainer, Error = std::io::Error> + '_ {
|
||||
futures::stream::try_unfold(self, |reader| async {
|
||||
let maybe_container = reader.get_next_messages_container().await?;
|
||||
match maybe_container {
|
||||
Some(container) => Ok(Some((container, reader))),
|
||||
None => Ok(None)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn get_next_messages_container(&mut self) -> Result<MessagesContainer, Self::Err> {
|
||||
async fn get_next_messages_container(&mut self) -> Result<Option<MessagesContainer>, std::io::Error> {
|
||||
if let Some(max_bytes) = self.max_bytes {
|
||||
if self.bytes_read >= max_bytes {
|
||||
if self.bytes_read > max_bytes {
|
||||
error!("warning: {} bytes read, but max_bytes was {}", self.bytes_read, max_bytes);
|
||||
}
|
||||
return Err(QmdlReaderError::MaxBytesReached(max_bytes));
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let bytes_read = self.reader.read_until(MESSAGE_TERMINATOR, &mut buf)?;
|
||||
let bytes_read = self.reader.read_until(MESSAGE_TERMINATOR, &mut buf).await?;
|
||||
self.bytes_read += bytes_read;
|
||||
|
||||
// Since QMDL is just a flat list of messages, we can't actually
|
||||
@@ -82,7 +79,7 @@ impl<T> DiagReader for QmdlReader<T> where T: Read {
|
||||
// read. So we'll just pretend that all containers had exactly one
|
||||
// message. As far as I know, the number of messages per container
|
||||
// doesn't actually affect anything, so this should be fine.
|
||||
Ok(MessagesContainer {
|
||||
Ok(Some(MessagesContainer {
|
||||
data_type: DataType::UserSpace,
|
||||
num_messages: 1,
|
||||
messages: vec![
|
||||
@@ -91,7 +88,7 @@ impl<T> DiagReader for QmdlReader<T> where T: Read {
|
||||
data: buf,
|
||||
},
|
||||
]
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +97,7 @@ mod test {
|
||||
use std::io::Cursor;
|
||||
|
||||
use crate::hdlc::hdlc_encapsulate;
|
||||
use crate::diag_reader::CRC_CCITT;
|
||||
use crate::diag::CRC_CCITT;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -140,8 +137,8 @@ mod test {
|
||||
]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unbounded_qmdl_reader() {
|
||||
#[tokio::test]
|
||||
async fn test_unbounded_qmdl_reader() {
|
||||
let mut buf = Cursor::new(get_test_message_bytes());
|
||||
let mut reader = QmdlReader::new(&mut buf, None);
|
||||
let expected_messages = get_test_messages();
|
||||
@@ -151,12 +148,12 @@ mod test {
|
||||
num_messages: 1,
|
||||
messages: vec![message],
|
||||
};
|
||||
assert_eq!(expected_container, reader.get_next_messages_container().unwrap());
|
||||
assert_eq!(expected_container, reader.get_next_messages_container().await.unwrap().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bounded_qmdl_reader() {
|
||||
#[tokio::test]
|
||||
async fn test_bounded_qmdl_reader() {
|
||||
let mut buf = Cursor::new(get_test_message_bytes());
|
||||
|
||||
// bound the reader to the first two messages
|
||||
@@ -170,30 +167,30 @@ mod test {
|
||||
num_messages: 1,
|
||||
messages: vec![message],
|
||||
};
|
||||
assert_eq!(expected_container, reader.get_next_messages_container().unwrap());
|
||||
assert_eq!(expected_container, reader.get_next_messages_container().await.unwrap().unwrap());
|
||||
}
|
||||
assert!(matches!(reader.get_next_messages_container(), Err(QmdlReaderError::MaxBytesReached(_))));
|
||||
assert!(matches!(reader.get_next_messages_container().await, Ok(None)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qmdl_writer() {
|
||||
#[tokio::test]
|
||||
async fn test_qmdl_writer() {
|
||||
let mut buf = Vec::new();
|
||||
let mut writer = QmdlWriter::new(&mut buf);
|
||||
let expected_containers = get_test_containers();
|
||||
for container in &expected_containers {
|
||||
writer.write_container(container).unwrap();
|
||||
writer.write_container(container).await.unwrap();
|
||||
}
|
||||
assert_eq!(writer.total_written, buf.len());
|
||||
assert_eq!(buf, get_test_message_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_writing_and_reading() {
|
||||
#[tokio::test]
|
||||
async fn test_writing_and_reading() {
|
||||
let mut buf = Vec::new();
|
||||
let mut writer = QmdlWriter::new(&mut buf);
|
||||
let expected_containers = get_test_containers();
|
||||
for container in &expected_containers {
|
||||
writer.write_container(container).unwrap();
|
||||
writer.write_container(container).await.unwrap();
|
||||
}
|
||||
|
||||
let limit = Some(buf.len());
|
||||
@@ -205,8 +202,8 @@ mod test {
|
||||
num_messages: 1,
|
||||
messages: vec![message],
|
||||
};
|
||||
assert_eq!(expected_container, reader.get_next_messages_container().unwrap());
|
||||
assert_eq!(expected_container, reader.get_next_messages_container().await.unwrap().unwrap());
|
||||
}
|
||||
assert!(matches!(reader.get_next_messages_container(), Err(QmdlReaderError::MaxBytesReached(_))));
|
||||
assert!(matches!(reader.get_next_messages_container().await, Ok(None)));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user