Transition to async I/O for most things

Mixing async and sync I/O leads to a multitude of complications, and
generally speaking it's much more convenient to stick to one paradigm
or the other. Since axum (and many other HTTP servers) use async,
and since async is a convenient model for performing operations like
"handle an MPSC message or file read, whichever happens first", let's
commit to an async interface.
This commit is contained in:
Will Greenberg
2024-02-13 20:23:02 -08:00
parent abd3b98cff
commit 775cbcda1e
14 changed files with 550 additions and 629 deletions

View File

@@ -1,87 +1,78 @@
use std::pin::pin;
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use rayhunter::diag::DataType;
use rayhunter::diag_device::DiagDevice;
use rayhunter::diag_reader::DiagReader;
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, self};
use tokio::sync::mpsc::Receiver;
use rayhunter::qmdl::QmdlWriter;
use log::{debug, info};
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
use log::{debug, error, info};
use tokio::fs::File;
use tokio_util::task::TaskTracker;
use futures::{StreamExt, TryStreamExt};
use crate::error::RayhunterError;
use crate::qmdl_store::QmdlStore;
use crate::server::ServerState;
pub enum DiagDeviceCtrlMessage {
StopRecording,
StartRecording(QmdlWriter<std::fs::File>),
StartRecording(QmdlWriter<File>),
Exit,
}
pub fn run_diag_read_thread(task_tracker: &TaskTracker, mut dev: DiagDevice, mut qmdl_file_rx: Receiver<DiagDeviceCtrlMessage>, qmdl_store_lock: Arc<RwLock<QmdlStore>>) -> JoinHandle<Result<(), RayhunterError>> {
// mpsc channel for updating QmdlStore entry filesizes. First usize is the
// index, second is the size in bytes
let (tx, mut rx) = mpsc::channel::<(usize, usize)>(1);
// Spawn a thread to monitor the (usize, usize) channel for updates,
// triggering QmdlStore updates
let qmdl_store_lock_clone = qmdl_store_lock.clone();
pub fn run_diag_read_thread(task_tracker: &TaskTracker, mut dev: DiagDevice, mut qmdl_file_rx: Receiver<DiagDeviceCtrlMessage>, qmdl_store_lock: Arc<RwLock<QmdlStore>>) {
task_tracker.spawn(async move {
while let Some((entry_idx, new_size)) = rx.recv().await {
let mut qmdl_store = qmdl_store_lock_clone.write().await;
qmdl_store.update_entry(entry_idx, new_size).await
.expect("failed to update qmdl file size");
}
info!("QMDL store size updater thread exiting...");
});
// Spawn a thread to drive the DiagDevice reading loop. Since DiagDevice
// works via synchronous I/O, we have to spawn a "blocking" thread to avoid
// gumming up tokio's event loop.
task_tracker.spawn_blocking(move || {
let initial_file = qmdl_store_lock.write().await.new_entry().await.expect("failed creating QMDL file entry");
let mut qmdl_writer: Option<QmdlWriter<File>> = Some(QmdlWriter::new(initial_file));
let mut diag_stream = pin!(dev.as_stream().into_stream());
loop {
// First check if we've gotten any control meesages
match qmdl_file_rx.try_recv() {
Ok(DiagDeviceCtrlMessage::StartRecording(qmdl_writer)) => {
dev.qmdl_writer = Some(qmdl_writer);
},
Ok(DiagDeviceCtrlMessage::StopRecording) => dev.qmdl_writer = None,
// Disconnected means all the Senders have been dropped, so it's
// time to go
Ok(DiagDeviceCtrlMessage::Exit) | Err(TryRecvError::Disconnected) => {
info!("Diag reader thread exiting...");
return Ok(())
},
// empty just means there's no message for us, so continue as normal
Err(TryRecvError::Empty) => {},
}
// remember the QmdlStore current entry index so we can update its size later
let qmdl_store_index = qmdl_store_lock.blocking_read().current_entry;
// TODO: once we're actually doing analysis, we'll wanna use the messages
// returned here. Until then, the DiagDevice has already written those messages
// to the QMDL file, so we can just ignore them.
debug!("reading response from diag device...");
let _messages = dev.read_response().map_err(RayhunterError::DiagReadError)?;
debug!("got diag response ({} messages)", _messages.len());
// keep track of how many bytes were written to the QMDL file so we can read
// a valid block of data from it in the HTTP server
if let Some(qmdl_writer) = dev.qmdl_writer.as_ref() {
debug!("total QMDL bytes written: {}, sending update...", qmdl_writer.total_written);
let index = qmdl_store_index.expect("DiagDevice had qmdl_writer, but QmdlStore didn't have current entry???");
tx.blocking_send((index, qmdl_writer.total_written)).unwrap();
debug!("done!");
} else {
debug!("no qmdl_writer set, continuing...");
tokio::select! {
msg = qmdl_file_rx.recv() => {
match msg {
Some(DiagDeviceCtrlMessage::StartRecording(new_writer)) => {
qmdl_writer = Some(new_writer);
},
Some(DiagDeviceCtrlMessage::StopRecording) => qmdl_writer = None,
// None means all the Senders have been dropped, so it's
// time to go
Some(DiagDeviceCtrlMessage::Exit) | None => {
info!("Diag reader thread exiting...");
return Ok(())
},
}
}
maybe_container = diag_stream.next() => {
match maybe_container.unwrap() {
Ok(container) => {
if container.data_type != DataType::UserSpace {
debug!("skipping non-userspace diag messages...");
continue;
}
// keep track of how many bytes were written to the QMDL file so we can read
// a valid block of data from it in the HTTP server
if let Some(writer) = qmdl_writer.as_mut() {
writer.write_container(&container).await.expect("failed to write to QMDL writer");
debug!("total QMDL bytes written: {}, updating manifest...", writer.total_written);
let mut qmdl_store = qmdl_store_lock.write().await;
let index = qmdl_store.current_entry.expect("DiagDevice had qmdl_writer, but QmdlStore didn't have current entry???");
qmdl_store.update_entry(index, writer.total_written).await
.expect("failed to update qmdl file size");
debug!("done!");
} else {
debug!("no qmdl_writer set, continuing...");
}
},
Err(err) => {
error!("error reading diag device: {}", err);
return Err(err);
}
}
}
}
}
})
});
}
pub async fn start_recording(State(state): State<Arc<ServerState>>) -> Result<(StatusCode, String), (StatusCode, String)> {
@@ -91,7 +82,7 @@ pub async fn start_recording(State(state): State<Arc<ServerState>>) -> Result<(S
let mut qmdl_store = state.qmdl_store_lock.write().await;
let qmdl_file = qmdl_store.new_entry().await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't create new qmdl entry: {}", e)))?;
let qmdl_writer = QmdlWriter::new(qmdl_file.into_std().await);
let qmdl_writer = QmdlWriter::new(qmdl_file);
state.diag_device_ctrl_sender.send(DiagDeviceCtrlMessage::StartRecording(qmdl_writer)).await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("couldn't send stop recording message: {}", e)))?;
Ok((StatusCode::ACCEPTED, format!("ok")))

View File

@@ -9,8 +9,6 @@ pub enum RayhunterError{
ConfigFileParsingError(#[from] toml::de::Error),
#[error("Diag intialization error: {0}")]
DiagInitError(DiagDeviceError),
#[error("Diag read error: {0}")]
DiagReadError(DiagDeviceError),
#[error("Tokio error: {0}")]
TokioError(#[from] tokio::io::Error),
#[error("QmdlStore error: {0}")]

View File

@@ -20,7 +20,6 @@ use log::{info, error};
use rayhunter::diag_device::DiagDevice;
use axum::routing::{get, post};
use axum::Router;
use rayhunter::qmdl::QmdlWriter;
use stats::get_qmdl_manifest;
use tokio::sync::mpsc::{self, Sender};
use tokio::task::JoinHandle;
@@ -127,11 +126,9 @@ async fn main() -> Result<(), RayhunterError> {
let qmdl_store_lock = Arc::new(RwLock::new(init_qmdl_store(&config).await?));
let (tx, rx) = mpsc::channel::<DiagDeviceCtrlMessage>(1);
if !config.readonly_mode {
let qmdl_file = qmdl_store_lock.write().await.new_entry().await?;
let qmdl_writer = QmdlWriter::new(qmdl_file.into_std().await);
let mut dev = DiagDevice::new(Some(qmdl_writer))
let mut dev = DiagDevice::new().await
.map_err(RayhunterError::DiagInitError)?;
dev.config_logs()
dev.config_logs().await
.map_err(RayhunterError::DiagInitError)?;
run_diag_read_thread(&task_tracker, dev, rx, qmdl_store_lock.clone());

View File

@@ -1,26 +1,24 @@
use crate::ServerState;
use rayhunter::diag::DataType;
use rayhunter::gsmtap_parser::GsmtapParser;
use rayhunter::pcap::GsmtapPcapWriter;
use rayhunter::qmdl::{QmdlReader, QmdlReaderError};
use rayhunter::diag_reader::DiagReader;
use rayhunter::qmdl::QmdlReader;
use axum::body::Body;
use axum::http::header::CONTENT_TYPE;
use axum::extract::{State, Path};
use axum::http::StatusCode;
use axum::response::{Response, IntoResponse};
use std::io::Write;
use std::pin::Pin;
use tokio::io::duplex;
use tokio_util::io::ReaderStream;
use std::{future, pin::pin};
use std::sync::Arc;
use std::task::{Poll, Context};
use futures_core::Stream;
use log::error;
use tokio::sync::mpsc;
use futures::TryStreamExt;
// Streams a pcap file chunk-by-chunk to the client by reading the QMDL data
// written so far. This is done by spawning a blocking thread (a tokio thread
// capable of handling blocking operations) which streams chunks of pcap data to
// a channel that's piped to the client.
// written so far. This is done by spawning a thread which streams chunks of
// pcap data to a channel that's piped to the client.
pub async fn get_pcap(State(state): State<Arc<ServerState>>, Path(qmdl_name): Path<String>) -> Result<Response, (StatusCode, String)> {
let qmdl_store = state.qmdl_store_lock.read().await;
let entry = qmdl_store.entry_for_name(&qmdl_name)
@@ -31,82 +29,39 @@ pub async fn get_pcap(State(state): State<Arc<ServerState>>, Path(qmdl_name): Pa
"QMDL file is empty, try again in a bit!".to_string()
));
}
let qmdl_file = qmdl_store.open_entry(&entry).await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?
.into_std().await;
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?;
// the QMDL reader should stop at the last successfully written data chunk
// (entry.size_bytes)
let mut gsmtap_parser = GsmtapParser::new();
let (reader, writer) = duplex(1024);
let mut pcap_writer = GsmtapPcapWriter::new(writer).await.unwrap();
pcap_writer.write_iface_header().await.unwrap();
let (tx, rx) = mpsc::channel(1);
let channel_reader = ChannelReader { rx };
let channel_writer = ChannelWriter { tx };
tokio::task::spawn_blocking(move || {
// the QMDL reader should stop at the last successfully written data
// chunk (qmdl_bytes_written)
let mut qmdl_reader = QmdlReader::new(qmdl_file, Some(entry.size_bytes));
let mut gsmtap_parser = GsmtapParser::new();
let mut pcap_writer = GsmtapPcapWriter::new(channel_writer).unwrap();
pcap_writer.write_iface_header().unwrap();
loop {
match qmdl_reader.read_response() {
Ok(messages) => {
for maybe_msg in messages {
match maybe_msg {
Ok(msg) => {
let maybe_gsmtap_msg = gsmtap_parser.recv_message(msg)
.expect("error parsing gsmtap message");
if let Some((timestamp, gsmtap_msg)) = maybe_gsmtap_msg {
pcap_writer.write_gsmtap_message(gsmtap_msg, timestamp)
.expect("error writing pcap packet");
}
},
Err(e) => error!("error parsing message: {:?}", e),
tokio::spawn(async move {
let mut reader = QmdlReader::new(qmdl_file, Some(entry.size_bytes));
let mut messages_stream = pin!(reader.as_stream()
.try_filter(|container| future::ready(container.data_type == DataType::UserSpace)));
while let Some(container) = messages_stream.try_next().await.expect("failed getting QMDL container") {
for maybe_msg in container.into_messages() {
match maybe_msg {
Ok(msg) => {
let maybe_gsmtap_msg = gsmtap_parser.recv_message(msg)
.expect("error parsing gsmtap message");
if let Some((timestamp, gsmtap_msg)) = maybe_gsmtap_msg {
pcap_writer.write_gsmtap_message(gsmtap_msg, timestamp).await
.expect("error writing pcap packet");
}
}
},
// this is expected, and just means we've reached the end of the
// safely written QMDL data
Err(QmdlReaderError::MaxBytesReached(_)) => break,
Err(e) => {
error!("error reading qmdl file: {:?}", e);
break;
},
},
Err(e) => error!("error parsing message: {:?}", e),
}
}
}
});
let headers = [(CONTENT_TYPE, "application/vnd.tcpdump.pcap")];
let body = Body::from_stream(channel_reader);
let body = Body::from_stream(ReaderStream::new(reader));
Ok((headers, body).into_response())
}
struct ChannelWriter {
tx: mpsc::Sender<Vec<u8>>,
}
impl Write for ChannelWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.tx.blocking_send(buf.to_vec())
.map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "channel closed"))?;
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
struct ChannelReader {
rx: mpsc::Receiver<Vec<u8>>,
}
impl Stream for ChannelReader {
type Item = Result<Vec<u8>, String>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.rx.poll_recv(cx) {
Poll::Ready(Some(msg)) => Poll::Ready(Some(Ok(msg))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}