diff --git a/crates/brk_cohort/src/address.rs b/crates/brk_cohort/src/address.rs index b8d7e3f97..96fccd1cc 100644 --- a/crates/brk_cohort/src/address.rs +++ b/crates/brk_cohort/src/address.rs @@ -1,12 +1,11 @@ use brk_traversable::Traversable; use rayon::prelude::*; -use vecdb::{AnyExportableVec, ReadOnlyClone}; use crate::Filter; use super::{ByAmountRange, ByGreatEqualAmount, ByLowerThanAmount}; -#[derive(Default, Clone)] +#[derive(Default, Clone, Traversable)] pub struct AddressGroups { pub ge_amount: ByGreatEqualAmount, pub amount_range: ByAmountRange, @@ -80,49 +79,3 @@ impl AddressGroups { } } -impl ReadOnlyClone for AddressGroups { - type ReadOnly = AddressGroups; - - fn read_only_clone(&self) -> Self::ReadOnly { - AddressGroups { - ge_amount: self.ge_amount.read_only_clone(), - amount_range: self.amount_range.read_only_clone(), - lt_amount: self.lt_amount.read_only_clone(), - } - } -} - -impl Traversable for AddressGroups -where - ByGreatEqualAmount: brk_traversable::Traversable, - ByAmountRange: brk_traversable::Traversable, - ByLowerThanAmount: brk_traversable::Traversable, - T: Send + Sync, -{ - fn to_tree_node(&self) -> brk_traversable::TreeNode { - brk_traversable::TreeNode::Branch( - [ - (String::from("ge_amount"), self.ge_amount.to_tree_node()), - ( - String::from("amount_range"), - self.amount_range.to_tree_node(), - ), - (String::from("lt_amount"), self.lt_amount.to_tree_node()), - ] - .into(), - ) - } - - fn iter_any_exportable(&self) -> impl Iterator { - [ - Box::new(self.ge_amount.iter_any_exportable()) - as Box>, - Box::new(self.amount_range.iter_any_exportable()) - as Box>, - Box::new(self.lt_amount.iter_any_exportable()) - as Box>, - ] - .into_iter() - .flatten() - } -} diff --git a/crates/brk_computer/src/distribution/metrics/realized/basic.rs b/crates/brk_computer/src/distribution/metrics/realized/basic.rs new file mode 100644 index 000000000..f21040910 --- /dev/null +++ b/crates/brk_computer/src/distribution/metrics/realized/basic.rs @@ -0,0 +1,486 @@ +use brk_error::Result; +use brk_traversable::Traversable; +use brk_types::{ + BasisPoints32, BasisPointsSigned32, Bitcoin, Cents, CentsSats, CentsSigned, CentsSquaredSats, + Dollars, Height, Indexes, Version, +}; +use derive_more::{Deref, DerefMut}; +use vecdb::{ + AnyStoredVec, AnyVec, BytesVec, Exit, ReadableCloneableVec, ReadableVec, Rw, StorageMode, + WritableVec, +}; + +use crate::{ + blocks, + distribution::state::RealizedState, + internal::{ + CentsUnsignedToDollars, ComputedFromHeight, ComputedFromHeightCumulative, + ComputedFromHeightRatio, ComputedFromHeightRatioPercentiles, LazyFromHeight, + PercentFromHeight, PercentRollingEmas1w1m, PercentRollingWindows, Price, RatioCentsBp32, + RatioCentsSignedCentsBps32, RatioCentsSignedDollarsBps32, RollingEmas2w, RollingWindows, + ValueFromHeightCumulative, + }, + prices, +}; + +use crate::distribution::metrics::ImportConfig; + +use super::RealizedCore; + +#[derive(Deref, DerefMut, Traversable)] +pub struct RealizedBasic { + #[deref] + #[deref_mut] + #[traversable(flatten)] + pub core: RealizedCore, + + // --- Stateful fields --- + pub profit_value_created: ComputedFromHeight, + pub profit_value_destroyed: ComputedFromHeight, + pub loss_value_created: ComputedFromHeight, + pub loss_value_destroyed: ComputedFromHeight, + + pub capitulation_flow: LazyFromHeight, + pub profit_flow: LazyFromHeight, + + pub gross_pnl_sum: RollingWindows, + + pub net_pnl_change_1m: ComputedFromHeight, + pub net_pnl_change_1m_rel_to_realized_cap: PercentFromHeight, + pub net_pnl_change_1m_rel_to_market_cap: PercentFromHeight, + + pub sent_in_profit: ValueFromHeightCumulative, + pub sent_in_profit_ema: RollingEmas2w, + pub sent_in_loss: ValueFromHeightCumulative, + pub sent_in_loss_ema: RollingEmas2w, + + // --- Investor price & price bands --- + pub investor_price: Price>, + pub investor_price_ratio: ComputedFromHeightRatio, + + pub lower_price_band: Price>, + pub upper_price_band: Price>, + + pub cap_raw: M::Stored>, + pub investor_cap_raw: M::Stored>, + + pub sell_side_risk_ratio: PercentRollingWindows, + pub sell_side_risk_ratio_24h_ema: PercentRollingEmas1w1m, + + // --- Peak regret --- + pub peak_regret: ComputedFromHeightCumulative, + pub peak_regret_rel_to_realized_cap: PercentFromHeight, + + // --- Realized price ratio percentiles --- + pub realized_price_ratio_percentiles: ComputedFromHeightRatioPercentiles, +} + +impl RealizedBasic { + pub(crate) fn forced_import(cfg: &ImportConfig) -> Result { + let v0 = Version::ZERO; + let v1 = Version::ONE; + + let core = RealizedCore::forced_import(cfg)?; + + // Stateful fields + let profit_value_created = cfg.import_computed("profit_value_created", v0)?; + let profit_value_destroyed = cfg.import_computed("profit_value_destroyed", v0)?; + let loss_value_created = cfg.import_computed("loss_value_created", v0)?; + let loss_value_destroyed = cfg.import_computed("loss_value_destroyed", v0)?; + + let capitulation_flow = LazyFromHeight::from_computed::( + &cfg.name("capitulation_flow"), + cfg.version, + loss_value_destroyed.height.read_only_boxed_clone(), + &loss_value_destroyed, + ); + let profit_flow = LazyFromHeight::from_computed::( + &cfg.name("profit_flow"), + cfg.version, + profit_value_destroyed.height.read_only_boxed_clone(), + &profit_value_destroyed, + ); + + let gross_pnl_sum = cfg.import_rolling("gross_pnl_sum", v1)?; + + // Investor price & price bands + let investor_price = cfg.import_price("investor_price", v0)?; + let investor_price_ratio = cfg.import_ratio("investor_price", v0)?; + let lower_price_band = cfg.import_price("lower_price_band", v0)?; + let upper_price_band = cfg.import_price("upper_price_band", v0)?; + + let cap_raw = cfg.import_bytes("cap_raw", v0)?; + let investor_cap_raw = cfg.import_bytes("investor_cap_raw", v0)?; + + let sell_side_risk_ratio = + cfg.import_percent_rolling_bp32("sell_side_risk_ratio", Version::new(2))?; + let sell_side_risk_ratio_24h_ema = + cfg.import_percent_emas_1w_1m_bp32("sell_side_risk_ratio_24h", Version::new(2))?; + + // Peak regret + let peak_regret = cfg.import_cumulative("realized_peak_regret", Version::new(2))?; + let peak_regret_rel_to_realized_cap = + cfg.import_percent_bp32("realized_peak_regret_rel_to_realized_cap", Version::new(2))?; + + // Realized price ratio percentiles + let realized_price_ratio_percentiles = + ComputedFromHeightRatioPercentiles::forced_import( + cfg.db, + &cfg.name("realized_price"), + cfg.version + v1, + cfg.indexes, + )?; + + Ok(Self { + core, + profit_value_created, + profit_value_destroyed, + loss_value_created, + loss_value_destroyed, + capitulation_flow, + profit_flow, + gross_pnl_sum, + net_pnl_change_1m: cfg.import_computed("net_pnl_change_1m", Version::new(3))?, + net_pnl_change_1m_rel_to_realized_cap: cfg + .import_percent_bps32("net_pnl_change_1m_rel_to_realized_cap", Version::new(4))?, + net_pnl_change_1m_rel_to_market_cap: cfg + .import_percent_bps32("net_pnl_change_1m_rel_to_market_cap", Version::new(4))?, + sent_in_profit: cfg.import_value_cumulative("sent_in_profit", v0)?, + sent_in_profit_ema: cfg.import_emas_2w("sent_in_profit", v0)?, + sent_in_loss: cfg.import_value_cumulative("sent_in_loss", v0)?, + sent_in_loss_ema: cfg.import_emas_2w("sent_in_loss", v0)?, + investor_price, + investor_price_ratio, + lower_price_band, + upper_price_band, + cap_raw, + investor_cap_raw, + sell_side_risk_ratio, + sell_side_risk_ratio_24h_ema, + peak_regret, + peak_regret_rel_to_realized_cap, + realized_price_ratio_percentiles, + }) + } + + pub(crate) fn min_stateful_height_len(&self) -> usize { + self.core + .min_stateful_height_len() + .min(self.profit_value_created.height.len()) + .min(self.profit_value_destroyed.height.len()) + .min(self.loss_value_created.height.len()) + .min(self.loss_value_destroyed.height.len()) + .min(self.sent_in_profit.base.sats.height.len()) + .min(self.sent_in_loss.base.sats.height.len()) + .min(self.investor_price.cents.height.len()) + .min(self.cap_raw.len()) + .min(self.investor_cap_raw.len()) + .min(self.peak_regret.height.len()) + } + + pub(crate) fn truncate_push(&mut self, height: Height, state: &RealizedState) -> Result<()> { + self.core.truncate_push(height, state)?; + self.profit_value_created + .height + .truncate_push(height, state.profit_value_created())?; + self.profit_value_destroyed + .height + .truncate_push(height, state.profit_value_destroyed())?; + self.loss_value_created + .height + .truncate_push(height, state.loss_value_created())?; + self.loss_value_destroyed + .height + .truncate_push(height, state.loss_value_destroyed())?; + self.sent_in_profit + .base + .sats + .height + .truncate_push(height, state.sent_in_profit())?; + self.sent_in_loss + .base + .sats + .height + .truncate_push(height, state.sent_in_loss())?; + self.investor_price + .cents + .height + .truncate_push(height, state.investor_price())?; + self.cap_raw.truncate_push(height, state.cap_raw())?; + self.investor_cap_raw + .truncate_push(height, state.investor_cap_raw())?; + self.peak_regret + .height + .truncate_push(height, state.peak_regret())?; + + Ok(()) + } + + pub(crate) fn collect_vecs_mut(&mut self) -> Vec<&mut dyn AnyStoredVec> { + let mut vecs = self.core.collect_vecs_mut(); + vecs.push(&mut self.profit_value_created.height as &mut dyn AnyStoredVec); + vecs.push(&mut self.profit_value_destroyed.height); + vecs.push(&mut self.loss_value_created.height); + vecs.push(&mut self.loss_value_destroyed.height); + vecs.push(&mut self.sent_in_profit.base.sats.height); + vecs.push(&mut self.sent_in_loss.base.sats.height); + vecs.push(&mut self.investor_price.cents.height); + vecs.push(&mut self.cap_raw as &mut dyn AnyStoredVec); + vecs.push(&mut self.investor_cap_raw as &mut dyn AnyStoredVec); + vecs.push(&mut self.peak_regret.height); + vecs + } + + pub(crate) fn compute_from_stateful( + &mut self, + starting_indexes: &Indexes, + others: &[&Self], + exit: &Exit, + ) -> Result<()> { + // Core aggregation + let core_refs: Vec<&RealizedCore> = others.iter().map(|o| &o.core).collect(); + self.core + .compute_from_stateful(starting_indexes, &core_refs, exit)?; + + // Stateful field aggregation + sum_others!(self, starting_indexes, others, exit; profit_value_created.height); + sum_others!(self, starting_indexes, others, exit; profit_value_destroyed.height); + sum_others!(self, starting_indexes, others, exit; loss_value_created.height); + sum_others!(self, starting_indexes, others, exit; loss_value_destroyed.height); + sum_others!(self, starting_indexes, others, exit; sent_in_profit.base.sats.height); + sum_others!(self, starting_indexes, others, exit; sent_in_loss.base.sats.height); + + // Investor price aggregation from raw values + let investor_price_dep_version = others + .iter() + .map(|o| o.investor_price.cents.height.version()) + .fold(vecdb::Version::ZERO, |acc, v| acc + v); + self.investor_price + .cents + .height + .validate_computed_version_or_reset(investor_price_dep_version)?; + + let start = self + .cap_raw + .len() + .min(self.investor_cap_raw.len()) + .min(self.investor_price.cents.height.len()); + let end = others.iter().map(|o| o.cap_raw.len()).min().unwrap_or(0); + + let cap_ranges: Vec> = others + .iter() + .map(|o| o.cap_raw.collect_range_at(start, end)) + .collect(); + let investor_cap_ranges: Vec> = others + .iter() + .map(|o| o.investor_cap_raw.collect_range_at(start, end)) + .collect(); + + for i in start..end { + let height = Height::from(i); + let local_i = i - start; + + let mut sum_cap = CentsSats::ZERO; + let mut sum_investor_cap = CentsSquaredSats::ZERO; + + for idx in 0..others.len() { + sum_cap += cap_ranges[idx][local_i]; + sum_investor_cap += investor_cap_ranges[idx][local_i]; + } + + self.cap_raw.truncate_push(height, sum_cap)?; + self.investor_cap_raw + .truncate_push(height, sum_investor_cap)?; + + let investor_price = if sum_cap.inner() == 0 { + Cents::ZERO + } else { + Cents::new((sum_investor_cap / sum_cap.inner()) as u64) + }; + self.investor_price + .cents + .height + .truncate_push(height, investor_price)?; + } + + { + let _lock = exit.lock(); + self.investor_price.cents.height.write()?; + } + + // Peak regret aggregation + self.peak_regret.height.compute_sum_of_others( + starting_indexes.height, + &others + .iter() + .map(|v| &v.peak_regret.height) + .collect::>(), + exit, + )?; + + Ok(()) + } + + pub(crate) fn compute_rest_part1( + &mut self, + starting_indexes: &Indexes, + exit: &Exit, + ) -> Result<()> { + self.core.compute_rest_part1(starting_indexes, exit)?; + self.peak_regret + .compute_rest(starting_indexes.height, exit)?; + + Ok(()) + } + + pub(crate) fn compute_rest_part2( + &mut self, + blocks: &blocks::Vecs, + prices: &prices::Vecs, + starting_indexes: &Indexes, + height_to_supply: &impl ReadableVec, + height_to_market_cap: &impl ReadableVec, + exit: &Exit, + ) -> Result<()> { + // Core computation + self.core.compute_rest_part2( + blocks, + prices, + starting_indexes, + height_to_supply, + exit, + )?; + + // Gross PnL rolling sums + let window_starts = blocks.count.window_starts(); + self.gross_pnl_sum.compute_rolling_sum( + starting_indexes.height, + &window_starts, + &self.core.gross_pnl.cents.height, + exit, + )?; + + // Sent in profit/loss EMAs + self.sent_in_profit_ema.compute( + starting_indexes.height, + &blocks.count.height_2w_ago, + &self.sent_in_profit.base.sats.height, + &self.sent_in_profit.base.cents.height, + exit, + )?; + self.sent_in_loss_ema.compute( + starting_indexes.height, + &blocks.count.height_2w_ago, + &self.sent_in_loss.base.sats.height, + &self.sent_in_loss.base.cents.height, + exit, + )?; + + // Net PnL change 1m + self.net_pnl_change_1m.height.compute_rolling_change( + starting_indexes.height, + &blocks.count.height_1m_ago, + &self.core.net_realized_pnl.cumulative.height, + exit, + )?; + + self.net_pnl_change_1m_rel_to_realized_cap + .compute_binary::( + starting_indexes.height, + &self.net_pnl_change_1m.height, + &self.core.realized_cap_cents.height, + exit, + )?; + + self.net_pnl_change_1m_rel_to_market_cap + .compute_binary::( + starting_indexes.height, + &self.net_pnl_change_1m.height, + height_to_market_cap, + exit, + )?; + + // Investor price ratio and price bands + self.investor_price_ratio.compute_ratio( + starting_indexes, + &prices.price.cents.height, + &self.investor_price.cents.height, + exit, + )?; + + self.lower_price_band.cents.height.compute_transform2( + starting_indexes.height, + &self.core.realized_price.cents.height, + &self.investor_price.cents.height, + |(i, rp, ip, ..)| { + let rp = rp.as_u128(); + let ip = ip.as_u128(); + if ip == 0 { + (i, Cents::ZERO) + } else { + (i, Cents::from(rp * rp / ip)) + } + }, + exit, + )?; + + self.upper_price_band.cents.height.compute_transform2( + starting_indexes.height, + &self.investor_price.cents.height, + &self.core.realized_price.cents.height, + |(i, ip, rp, ..)| { + let ip = ip.as_u128(); + let rp = rp.as_u128(); + if rp == 0 { + (i, Cents::ZERO) + } else { + (i, Cents::from(ip * ip / rp)) + } + }, + exit, + )?; + + // Sell-side risk ratios + for (ssrr, rv) in self + .sell_side_risk_ratio + .as_mut_array() + .into_iter() + .zip(self.gross_pnl_sum.as_array()) + { + ssrr.compute_binary::( + starting_indexes.height, + &rv.height, + &self.core.realized_cap_cents.height, + exit, + )?; + } + + self.sell_side_risk_ratio_24h_ema.compute_from_24h( + starting_indexes.height, + &blocks.count.height_1w_ago, + &blocks.count.height_1m_ago, + &self.sell_side_risk_ratio._24h.bps.height, + exit, + )?; + + // Peak regret relative to realized cap + self.peak_regret_rel_to_realized_cap + .compute_binary::( + starting_indexes.height, + &self.peak_regret.height, + &self.core.realized_cap_cents.height, + exit, + )?; + + // Realized price ratio percentiles + self.realized_price_ratio_percentiles.compute( + blocks, + starting_indexes, + exit, + &self.core.realized_price_ratio.ratio.height, + &self.core.realized_price.cents.height, + )?; + + Ok(()) + } +} diff --git a/crates/brk_computer/src/distribution/metrics/realized/extended.rs b/crates/brk_computer/src/distribution/metrics/realized/extended.rs index 4d1792d75..6ba658eec 100644 --- a/crates/brk_computer/src/distribution/metrics/realized/extended.rs +++ b/crates/brk_computer/src/distribution/metrics/realized/extended.rs @@ -6,8 +6,8 @@ use vecdb::{Exit, ReadableVec, Rw, StorageMode}; use crate::{ blocks, internal::{ - ComputedFromHeightRatioFull, PercentFromHeight, RatioCents64, RatioDollarsBp32, - RollingWindows, + ComputedFromHeightRatioPercentiles, ComputedFromHeightRatioStdDevBands, + PercentFromHeight, RatioCents64, RatioDollarsBp32, RollingWindows, }, prices, }; @@ -25,12 +25,19 @@ pub struct RealizedExtended { pub realized_profit_to_loss_ratio: RollingWindows, - pub realized_price_ratio: ComputedFromHeightRatioFull, - pub investor_price_ratio: ComputedFromHeightRatioFull, + pub realized_price_ratio_percentiles: ComputedFromHeightRatioPercentiles, + pub realized_price_ratio_std_dev: ComputedFromHeightRatioStdDevBands, + pub investor_price_ratio_percentiles: ComputedFromHeightRatioPercentiles, + pub investor_price_ratio_std_dev: ComputedFromHeightRatioStdDevBands, } impl RealizedExtended { pub(crate) fn forced_import(cfg: &ImportConfig) -> Result { + let realized_price_name = cfg.name("realized_price"); + let realized_price_version = cfg.version + Version::ONE; + let investor_price_name = cfg.name("investor_price"); + let investor_price_version = cfg.version; + Ok(RealizedExtended { realized_cap_rel_to_own_market_cap: cfg .import_percent_bp32("realized_cap_rel_to_own_market_cap", Version::ONE)?, @@ -38,16 +45,28 @@ impl RealizedExtended { realized_loss_sum: cfg.import_rolling("realized_loss", Version::ONE)?, realized_profit_to_loss_ratio: cfg .import_rolling("realized_profit_to_loss_ratio", Version::ONE)?, - realized_price_ratio: ComputedFromHeightRatioFull::forced_import( + realized_price_ratio_percentiles: ComputedFromHeightRatioPercentiles::forced_import( cfg.db, - &cfg.name("realized_price"), - cfg.version + Version::ONE, + &realized_price_name, + realized_price_version, cfg.indexes, )?, - investor_price_ratio: ComputedFromHeightRatioFull::forced_import( + realized_price_ratio_std_dev: ComputedFromHeightRatioStdDevBands::forced_import( cfg.db, - &cfg.name("investor_price"), - cfg.version, + &realized_price_name, + realized_price_version, + cfg.indexes, + )?, + investor_price_ratio_percentiles: ComputedFromHeightRatioPercentiles::forced_import( + cfg.db, + &investor_price_name, + investor_price_version, + cfg.indexes, + )?, + investor_price_ratio_std_dev: ComputedFromHeightRatioStdDevBands::forced_import( + cfg.db, + &investor_price_name, + investor_price_version, cfg.indexes, )?, }) @@ -102,22 +121,38 @@ impl RealizedExtended { )?; } - // Realized price: ratio + percentiles + stddev bands - self.realized_price_ratio.compute_rest( + // Realized price: percentiles + stddev bands + let realized_price = &base.realized_price.cents.height; + self.realized_price_ratio_percentiles.compute( blocks, - prices, starting_indexes, exit, - &base.realized_price.cents.height, + &base.realized_price_ratio.ratio.height, + realized_price, + )?; + self.realized_price_ratio_std_dev.compute( + blocks, + starting_indexes, + exit, + &base.realized_price_ratio.ratio.height, + realized_price, )?; - // Investor price: ratio + percentiles + stddev bands - self.investor_price_ratio.compute_rest( + // Investor price: percentiles + stddev bands + let investor_price = &base.investor_price.cents.height; + self.investor_price_ratio_percentiles.compute( blocks, - prices, starting_indexes, exit, - &base.investor_price.cents.height, + &base.investor_price_ratio.ratio.height, + investor_price, + )?; + self.investor_price_ratio_std_dev.compute( + blocks, + starting_indexes, + exit, + &base.investor_price_ratio.ratio.height, + investor_price, )?; Ok(()) diff --git a/crates/brk_computer/src/internal/distribution_stats.rs b/crates/brk_computer/src/internal/containers/distribution_stats.rs similarity index 100% rename from crates/brk_computer/src/internal/distribution_stats.rs rename to crates/brk_computer/src/internal/containers/distribution_stats.rs diff --git a/crates/brk_computer/src/internal/emas.rs b/crates/brk_computer/src/internal/containers/emas.rs similarity index 100% rename from crates/brk_computer/src/internal/emas.rs rename to crates/brk_computer/src/internal/containers/emas.rs diff --git a/crates/brk_computer/src/internal/containers/mod.rs b/crates/brk_computer/src/internal/containers/mod.rs new file mode 100644 index 000000000..34c63d30f --- /dev/null +++ b/crates/brk_computer/src/internal/containers/mod.rs @@ -0,0 +1,9 @@ +mod distribution_stats; +mod emas; +mod per_period; +mod windows; + +pub use distribution_stats::*; +pub use emas::*; +pub use per_period::*; +pub use windows::*; diff --git a/crates/brk_computer/src/internal/per_period.rs b/crates/brk_computer/src/internal/containers/per_period.rs similarity index 100% rename from crates/brk_computer/src/internal/per_period.rs rename to crates/brk_computer/src/internal/containers/per_period.rs diff --git a/crates/brk_computer/src/internal/windows.rs b/crates/brk_computer/src/internal/containers/windows.rs similarity index 100% rename from crates/brk_computer/src/internal/windows.rs rename to crates/brk_computer/src/internal/containers/windows.rs diff --git a/crates/brk_computer/src/internal/from_height/aggregated.rs b/crates/brk_computer/src/internal/from_height/computed/aggregated.rs similarity index 100% rename from crates/brk_computer/src/internal/from_height/aggregated.rs rename to crates/brk_computer/src/internal/from_height/computed/aggregated.rs diff --git a/crates/brk_computer/src/internal/from_height/cumulative.rs b/crates/brk_computer/src/internal/from_height/computed/cumulative.rs similarity index 100% rename from crates/brk_computer/src/internal/from_height/cumulative.rs rename to crates/brk_computer/src/internal/from_height/computed/cumulative.rs diff --git a/crates/brk_computer/src/internal/from_height/cumulative_sum.rs b/crates/brk_computer/src/internal/from_height/computed/cumulative_sum.rs similarity index 100% rename from crates/brk_computer/src/internal/from_height/cumulative_sum.rs rename to crates/brk_computer/src/internal/from_height/computed/cumulative_sum.rs diff --git a/crates/brk_computer/src/internal/from_height/distribution.rs b/crates/brk_computer/src/internal/from_height/computed/distribution.rs similarity index 100% rename from crates/brk_computer/src/internal/from_height/distribution.rs rename to crates/brk_computer/src/internal/from_height/computed/distribution.rs diff --git a/crates/brk_computer/src/internal/from_height/full.rs b/crates/brk_computer/src/internal/from_height/computed/full.rs similarity index 100% rename from crates/brk_computer/src/internal/from_height/full.rs rename to crates/brk_computer/src/internal/from_height/computed/full.rs diff --git a/crates/brk_computer/src/internal/from_height/computed/mod.rs b/crates/brk_computer/src/internal/from_height/computed/mod.rs new file mode 100644 index 000000000..d6719ef57 --- /dev/null +++ b/crates/brk_computer/src/internal/from_height/computed/mod.rs @@ -0,0 +1,11 @@ +mod aggregated; +mod cumulative; +mod cumulative_sum; +mod distribution; +mod full; + +pub use aggregated::*; +pub use cumulative::*; +pub use cumulative_sum::*; +pub use distribution::*; +pub use full::*; diff --git a/crates/brk_computer/src/internal/from_height/mod.rs b/crates/brk_computer/src/internal/from_height/mod.rs index aae6c8447..62a416856 100644 --- a/crates/brk_computer/src/internal/from_height/mod.rs +++ b/crates/brk_computer/src/internal/from_height/mod.rs @@ -1,33 +1,23 @@ -mod aggregated; mod base; mod by_unit; +mod computed; mod constant; -mod cumulative; -mod cumulative_sum; -mod distribution; mod fiat; -mod full; mod lazy; mod percent; -mod percent_distribution; mod percentiles; mod price; mod ratio; mod stddev; mod value; -pub use aggregated::*; pub use base::*; pub use by_unit::*; +pub use computed::*; pub use constant::*; -pub use cumulative::*; -pub use cumulative_sum::*; -pub use distribution::*; pub use fiat::*; -pub use full::*; pub use lazy::*; pub use percent::*; -pub use percent_distribution::*; pub use percentiles::*; pub use price::*; pub use ratio::*; diff --git a/crates/brk_computer/src/internal/from_height/percent.rs b/crates/brk_computer/src/internal/from_height/percent/base.rs similarity index 97% rename from crates/brk_computer/src/internal/from_height/percent.rs rename to crates/brk_computer/src/internal/from_height/percent/base.rs index 627aad382..4d64bf25c 100644 --- a/crates/brk_computer/src/internal/from_height/percent.rs +++ b/crates/brk_computer/src/internal/from_height/percent/base.rs @@ -10,7 +10,7 @@ use crate::{ internal::{BpsType, ComputeDrawdown}, }; -use super::{ComputedFromHeight, LazyFromHeight}; +use crate::internal::{ComputedFromHeight, LazyFromHeight}; /// Basis-point storage with both ratio and percentage float views. /// diff --git a/crates/brk_computer/src/internal/from_height/percent_distribution.rs b/crates/brk_computer/src/internal/from_height/percent/distribution.rs similarity index 96% rename from crates/brk_computer/src/internal/from_height/percent_distribution.rs rename to crates/brk_computer/src/internal/from_height/percent/distribution.rs index 1c5f52608..e440fcfb4 100644 --- a/crates/brk_computer/src/internal/from_height/percent_distribution.rs +++ b/crates/brk_computer/src/internal/from_height/percent/distribution.rs @@ -8,7 +8,7 @@ use crate::{ internal::{BpsType, WindowStarts}, }; -use super::{ComputedFromHeightDistribution, LazyFromHeight}; +use crate::internal::{ComputedFromHeightDistribution, LazyFromHeight}; /// Like PercentFromHeight but with rolling distribution stats on the bps data. #[derive(Traversable)] diff --git a/crates/brk_computer/src/internal/from_height/percent/mod.rs b/crates/brk_computer/src/internal/from_height/percent/mod.rs new file mode 100644 index 000000000..2f366af34 --- /dev/null +++ b/crates/brk_computer/src/internal/from_height/percent/mod.rs @@ -0,0 +1,5 @@ +mod base; +mod distribution; + +pub use base::*; +pub use distribution::*; diff --git a/crates/brk_computer/src/internal/eager_indexes.rs b/crates/brk_computer/src/internal/indexes/eager.rs similarity index 100% rename from crates/brk_computer/src/internal/eager_indexes.rs rename to crates/brk_computer/src/internal/indexes/eager.rs diff --git a/crates/brk_computer/src/internal/lazy_eager_indexes.rs b/crates/brk_computer/src/internal/indexes/lazy.rs similarity index 100% rename from crates/brk_computer/src/internal/lazy_eager_indexes.rs rename to crates/brk_computer/src/internal/indexes/lazy.rs diff --git a/crates/brk_computer/src/internal/indexes/mod.rs b/crates/brk_computer/src/internal/indexes/mod.rs new file mode 100644 index 000000000..d7ec3a5a6 --- /dev/null +++ b/crates/brk_computer/src/internal/indexes/mod.rs @@ -0,0 +1,5 @@ +mod eager; +mod lazy; + +pub use eager::*; +pub use lazy::*; diff --git a/crates/brk_computer/src/internal/mod.rs b/crates/brk_computer/src/internal/mod.rs index fa7d9b6fe..ae606430e 100644 --- a/crates/brk_computer/src/internal/mod.rs +++ b/crates/brk_computer/src/internal/mod.rs @@ -1,35 +1,25 @@ mod aggregate; pub(crate) mod algo; +mod containers; mod db_utils; mod derived; -mod distribution_stats; -mod eager_indexes; -mod emas; mod from_height; mod from_tx; -mod lazy_eager_indexes; -mod lazy_value; -mod per_period; +mod indexes; mod rolling; mod traits; pub mod transform; mod value; -mod windows; pub(crate) use aggregate::*; pub(crate) use algo::*; +pub(crate) use containers::*; pub(crate) use db_utils::*; pub(crate) use derived::*; -pub(crate) use distribution_stats::*; -pub(crate) use eager_indexes::*; -pub(crate) use emas::*; pub(crate) use from_height::*; pub(crate) use from_tx::*; -pub(crate) use lazy_eager_indexes::*; -pub(crate) use lazy_value::*; -pub(crate) use per_period::*; +pub(crate) use indexes::*; pub(crate) use rolling::*; pub(crate) use traits::*; pub use transform::*; pub(crate) use value::*; -pub(crate) use windows::*; diff --git a/crates/brk_computer/src/internal/value.rs b/crates/brk_computer/src/internal/value/base.rs similarity index 100% rename from crates/brk_computer/src/internal/value.rs rename to crates/brk_computer/src/internal/value/base.rs diff --git a/crates/brk_computer/src/internal/lazy_value.rs b/crates/brk_computer/src/internal/value/lazy.rs similarity index 100% rename from crates/brk_computer/src/internal/lazy_value.rs rename to crates/brk_computer/src/internal/value/lazy.rs diff --git a/crates/brk_computer/src/internal/value/mod.rs b/crates/brk_computer/src/internal/value/mod.rs new file mode 100644 index 000000000..230d73479 --- /dev/null +++ b/crates/brk_computer/src/internal/value/mod.rs @@ -0,0 +1,5 @@ +mod base; +mod lazy; + +pub use base::*; +pub use lazy::*; diff --git a/crates/brk_traversable_derive/src/lib.rs b/crates/brk_traversable_derive/src/lib.rs index 583d9cdde..93d06c519 100644 --- a/crates/brk_traversable_derive/src/lib.rs +++ b/crates/brk_traversable_derive/src/lib.rs @@ -2,14 +2,14 @@ use proc_macro::TokenStream; use quote::quote; use syn::{Data, DeriveInput, Fields, Type, parse_macro_input}; -/// Struct-level attributes for Traversable derive +// =========================================================================== +// Struct & field attribute parsing +// =========================================================================== + #[derive(Default)] struct StructAttr { - /// If true, call .merge_branches().unwrap() on the final result merge: bool, - /// If true, delegate to the single field (transparent newtype pattern) transparent: bool, - /// If set, wrap the result in Branch { key: inner } wrap: Option, } @@ -20,7 +20,6 @@ fn get_struct_attr(attrs: &[syn::Attribute]) -> StructAttr { continue; } - // Try parsing as single ident (merge, transparent) if let Ok(ident) = attr.parse_args::() { match ident.to_string().as_str() { "merge" => result.merge = true, @@ -30,7 +29,6 @@ fn get_struct_attr(attrs: &[syn::Attribute]) -> StructAttr { continue; } - // Try parsing as name-value (wrap = "...") if let Ok(meta) = attr.parse_args::() && meta.path.is_ident("wrap") && let syn::Expr::Lit(syn::ExprLit { @@ -44,6 +42,155 @@ fn get_struct_attr(attrs: &[syn::Attribute]) -> StructAttr { result } +enum FieldAttr { + Normal, + Flatten, +} + +struct FieldInfo<'a> { + name: &'a syn::Ident, + is_option: bool, + attr: FieldAttr, + rename: Option, + wrap: Option, +} + +/// Returns None for skip, Some((attr, rename, wrap)) for normal/flatten. +fn get_field_attr(field: &syn::Field) -> Option<(FieldAttr, Option, Option)> { + let mut attr_type = FieldAttr::Normal; + let mut rename = None; + let mut wrap = None; + + for attr in &field.attrs { + if !attr.path().is_ident("traversable") { + continue; + } + + if let Ok(ident) = attr.parse_args::() { + match ident.to_string().as_str() { + "skip" => return None, + "flatten" => attr_type = FieldAttr::Flatten, + _ => {} + } + continue; + } + + if let Ok(metas) = attr.parse_args_with( + syn::punctuated::Punctuated::::parse_terminated, + ) { + for meta in metas { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = &meta.value + { + if meta.path.is_ident("rename") { + rename = Some(lit_str.value()); + } else if meta.path.is_ident("wrap") { + wrap = Some(lit_str.value()); + } + } + } + } + } + + Some((attr_type, rename, wrap)) +} + +fn is_field_skipped(field: &syn::Field) -> bool { + field.attrs.iter().any(|attr| { + attr.path().is_ident("traversable") + && attr.parse_args::().is_ok_and(|id| id == "skip") + }) +} + +// =========================================================================== +// Type helpers +// =========================================================================== + +fn is_option_type(ty: &Type) -> bool { + matches!( + ty, + Type::Path(type_path) + if type_path.path.segments.last() + .is_some_and(|seg| seg.ident == "Option") + ) +} + +fn is_box_type(ty: &Type) -> bool { + matches!( + ty, + Type::Path(type_path) + if type_path.path.segments.last() + .is_some_and(|seg| seg.ident == "Box") + ) +} + +/// Extract the inner type from `Option`, returning `Some(&T)`. +fn extract_option_inner(ty: &Type) -> Option<&Type> { + if let Type::Path(type_path) = ty + && let Some(seg) = type_path.path.segments.last() + && seg.ident == "Option" + && let syn::PathArguments::AngleBracketed(args) = &seg.arguments + && let Some(syn::GenericArgument::Type(inner)) = args.args.first() + { + Some(inner) + } else { + None + } +} + +/// Check if a type AST references the given identifier anywhere. +fn type_contains_ident(ty: &Type, ident: &syn::Ident) -> bool { + match ty { + Type::Path(type_path) => { + if let Some(qself) = &type_path.qself + && type_contains_ident(&qself.ty, ident) + { + return true; + } + type_path.path.segments.iter().any(|seg| { + if seg.ident == *ident { + return true; + } + match &seg.arguments { + syn::PathArguments::AngleBracketed(args) => args.args.iter().any(|arg| { + matches!(arg, syn::GenericArgument::Type(inner) if type_contains_ident(inner, ident)) + }), + syn::PathArguments::Parenthesized(args) => { + args.inputs.iter().any(|inner| type_contains_ident(inner, ident)) + || matches!(&args.output, syn::ReturnType::Type(_, inner) if type_contains_ident(inner, ident)) + } + syn::PathArguments::None => false, + } + }) + } + Type::Reference(r) => type_contains_ident(&r.elem, ident), + Type::Tuple(t) => t.elems.iter().any(|e| type_contains_ident(e, ident)), + Type::Array(a) => type_contains_ident(&a.elem, ident), + Type::Slice(s) => type_contains_ident(&s.elem, ident), + Type::Paren(p) => type_contains_ident(&p.elem, ident), + _ => false, + } +} + +/// Find the generic type parameter bounded by `StorageMode`, if any. +fn find_storage_mode_param(generics: &syn::Generics) -> Option<&syn::Ident> { + generics.type_params().find_map(|p| { + p.bounds + .iter() + .any(|b| { + matches!(b, syn::TypeParamBound::Trait(t) + if t.path.segments.last().is_some_and(|s| s.ident == "StorageMode")) + }) + .then_some(&p.ident) + }) +} + +// =========================================================================== +// Entry point +// =========================================================================== + #[proc_macro_derive(Traversable, attributes(traversable))] pub fn derive_traversable(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -53,6 +200,10 @@ pub fn derive_traversable(input: TokenStream) -> TokenStream { TokenStream::from(output) } +// =========================================================================== +// Traversable generation +// =========================================================================== + fn gen_traversable(input: &DeriveInput) -> proc_macro2::TokenStream { let name = &input.ident; let generics = &input.generics; @@ -68,20 +219,16 @@ fn gen_traversable(input: &DeriveInput) -> proc_macro2::TokenStream { .to_compile_error(); }; - // Handle single-field tuple struct delegation (automatic transparent) + // Single-field tuple struct: delegate (automatic transparent). if let Fields::Unnamed(fields) = &data.fields && fields.unnamed.len() == 1 { let field_ty = &fields.unnamed.first().unwrap().ty; let where_clause = build_where_clause(generics, &[], &[field_ty]); let to_tree_node_body = if let Some(wrap_key) = &struct_attr.wrap { - quote! { - brk_traversable::TreeNode::wrap(#wrap_key, self.0.to_tree_node()) - } + quote! { brk_traversable::TreeNode::wrap(#wrap_key, self.0.to_tree_node()) } } else { - quote! { - self.0.to_tree_node() - } + quote! { self.0.to_tree_node() } }; return quote! { impl #impl_generics Traversable for #name #ty_generics #where_clause { @@ -96,7 +243,7 @@ fn gen_traversable(input: &DeriveInput) -> proc_macro2::TokenStream { }; } - // Handle named fields + // Named fields required from here. let Fields::Named(named_fields) = &data.fields else { return quote! { impl #impl_generics Traversable for #name #ty_generics { @@ -111,7 +258,7 @@ fn gen_traversable(input: &DeriveInput) -> proc_macro2::TokenStream { }; }; - // Handle transparent delegation for named structs (delegates to first field) + // Transparent delegation: forward everything to the first field. if struct_attr.transparent { let first_field = named_fields .named @@ -160,19 +307,6 @@ fn gen_traversable(input: &DeriveInput) -> proc_macro2::TokenStream { } } -enum FieldAttr { - Normal, - Flatten, -} - -struct FieldInfo<'a> { - name: &'a syn::Ident, - is_option: bool, - attr: FieldAttr, - rename: Option, - wrap: Option, -} - fn analyze_fields<'a>( fields: &'a syn::FieldsNamed, generic_params: &[&'a syn::Ident], @@ -183,7 +317,6 @@ fn analyze_fields<'a>( for field in &fields.named { let Some((attr, rename, wrap)) = get_field_attr(field) else { - // Skip attribute means don't process at all continue; }; @@ -205,8 +338,6 @@ fn analyze_fields<'a>( { generics_set.insert(param); } else { - // For non-bare-generic field types, add a Traversable bound. - // For Option fields, unwrap to get the inner T. let ty = if is_option { extract_option_inner(&field.ty).unwrap_or(&field.ty) } else { @@ -231,86 +362,32 @@ fn analyze_fields<'a>( ) } -/// Extract the inner type from `Option`, returning `Some(&T)`. -fn extract_option_inner(ty: &Type) -> Option<&Type> { - if let Type::Path(type_path) = ty - && let Some(seg) = type_path.path.segments.last() - && seg.ident == "Option" - && let syn::PathArguments::AngleBracketed(args) = &seg.arguments - && let Some(syn::GenericArgument::Type(inner)) = args.args.first() +fn build_where_clause( + generics: &syn::Generics, + generics_needing_traversable: &[&syn::Ident], + extra_traversable_types: &[&syn::Type], +) -> proc_macro2::TokenStream { + let generic_params: Vec<_> = generics.type_params().map(|p| &p.ident).collect(); + let original_predicates = generics.where_clause.as_ref().map(|w| &w.predicates); + + if generics_needing_traversable.is_empty() + && extra_traversable_types.is_empty() + && generic_params.is_empty() + && original_predicates.is_none() { - Some(inner) - } else { - None - } -} - -/// Returns None for skip, Some((attr, rename, wrap)) for normal/flatten -fn get_field_attr(field: &syn::Field) -> Option<(FieldAttr, Option, Option)> { - let mut attr_type = FieldAttr::Normal; - let mut rename = None; - let mut wrap = None; - - for attr in &field.attrs { - if !attr.path().is_ident("traversable") { - continue; - } - - // Try parsing as a single ident (skip, flatten) - if let Ok(ident) = attr.parse_args::() { - match ident.to_string().as_str() { - "skip" => return None, - "flatten" => attr_type = FieldAttr::Flatten, - _ => {} - } - continue; - } - - // Try parsing as comma-separated name-value pairs (rename = "...", wrap = "...") - if let Ok(metas) = attr.parse_args_with( - syn::punctuated::Punctuated::::parse_terminated, - ) { - for meta in metas { - if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit_str), - .. - }) = &meta.value - { - if meta.path.is_ident("rename") { - rename = Some(lit_str.value()); - } else if meta.path.is_ident("wrap") { - wrap = Some(lit_str.value()); - } - } - } - } + return quote! {}; } - Some((attr_type, rename, wrap)) -} - -fn is_option_type(ty: &Type) -> bool { - matches!( - ty, - Type::Path(type_path) - if type_path.path.segments.last() - .is_some_and(|seg| seg.ident == "Option") - ) -} - -fn is_box_type(ty: &Type) -> bool { - matches!( - ty, - Type::Path(type_path) - if type_path.path.segments.last() - .is_some_and(|seg| seg.ident == "Box") - ) + quote! { + where + #(#generics_needing_traversable: brk_traversable::Traversable,)* + #(#extra_traversable_types: brk_traversable::Traversable,)* + #(#generic_params: Send + Sync,)* + #original_predicates + } } fn generate_field_traversals(infos: &[FieldInfo], merge: bool) -> proc_macro2::TokenStream { - let has_flatten = infos.iter().any(|i| matches!(i.attr, FieldAttr::Flatten)); - - // Generate normal field entries let normal_entries: Vec<_> = infos .iter() .filter(|i| matches!(i.attr, FieldAttr::Normal)) @@ -321,10 +398,6 @@ fn generate_field_traversals(infos: &[FieldInfo], merge: bool) -> proc_macro2::T s.strip_prefix('_').map(String::from).unwrap_or(s) }; - // Determine outer key and inner wrap key based on which attrs are present - // When both wrap and rename are present: wrap is outer container, rename is inner key - // When only wrap: wrap is outer container, field_name is inner key - // When only rename: rename is outer, no inner wrapping let (outer_key, inner_wrap): (&str, Option<&str>) = match (info.wrap.as_deref(), info.rename.as_deref()) { (Some(wrap), Some(rename)) => (wrap, Some(rename)), @@ -333,7 +406,6 @@ fn generate_field_traversals(infos: &[FieldInfo], merge: bool) -> proc_macro2::T (None, None) => (&field_name_str, None), }; - // Generate tree node expression, optionally wrapped let node_expr = if let Some(inner_key) = inner_wrap { quote! { brk_traversable::TreeNode::wrap(#inner_key, nested.to_tree_node()) } } else { @@ -357,46 +429,33 @@ fn generate_field_traversals(infos: &[FieldInfo], merge: bool) -> proc_macro2::T }) .collect(); - // Generate flatten field entries let flatten_entries: Vec<_> = infos .iter() .filter(|i| matches!(i.attr, FieldAttr::Flatten)) .map(|info| { let field_name = info.name; + let merge_branch = quote! { + brk_traversable::TreeNode::Branch(map) => { + for (key, node) in map { + brk_traversable::TreeNode::merge_node(&mut collected, key, node) + .expect("Conflicting values for same key during flatten"); + } + } + leaf @ brk_traversable::TreeNode::Leaf(_) => { + brk_traversable::TreeNode::merge_node(&mut collected, String::from(stringify!(#field_name)), leaf) + .expect("Conflicting values for same key during flatten"); + } + }; if info.is_option { quote! { if let Some(ref nested) = self.#field_name { - match nested.to_tree_node() { - brk_traversable::TreeNode::Branch(map) => { - for (key, node) in map { - brk_traversable::TreeNode::merge_node(&mut collected, key, node) - .expect("Conflicting values for same key during flatten"); - } - } - leaf @ brk_traversable::TreeNode::Leaf(_) => { - // Collapsed leaf from child - insert with field name as key - brk_traversable::TreeNode::merge_node(&mut collected, String::from(stringify!(#field_name)), leaf) - .expect("Conflicting values for same key during flatten"); - } - } + match nested.to_tree_node() { #merge_branch } } } } else { quote! { - match self.#field_name.to_tree_node() { - brk_traversable::TreeNode::Branch(map) => { - for (key, node) in map { - brk_traversable::TreeNode::merge_node(&mut collected, key, node) - .expect("Conflicting values for same key during flatten"); - } - } - leaf @ brk_traversable::TreeNode::Leaf(_) => { - // Collapsed leaf from child - insert with field name as key - brk_traversable::TreeNode::merge_node(&mut collected, String::from(stringify!(#field_name)), leaf) - .expect("Conflicting values for same key during flatten"); - } - } + match self.#field_name.to_tree_node() { #merge_branch } } } }) @@ -408,48 +467,28 @@ fn generate_field_traversals(infos: &[FieldInfo], merge: bool) -> proc_macro2::T quote! { brk_traversable::TreeNode::Branch(collected) } }; - // Build collected map initialization based on what we have - // Use merge_entry to handle duplicate keys (e.g., multiple fields renamed to same key) - let (init_collected, extend_flatten) = if !has_flatten { - // No flatten fields - use merge_entry for each to handle duplicates - ( - quote! { - let mut collected: brk_traversable::IndexMap = - brk_traversable::IndexMap::new(); - for entry in [#(#normal_entries,)*].into_iter().flatten() { - brk_traversable::TreeNode::merge_node(&mut collected, entry.0, entry.1) - .expect("Conflicting values for same key"); - } - }, - quote! {}, - ) - } else if normal_entries.is_empty() { - // Only flatten fields - explicit type annotation needed - ( - quote! { - let mut collected: brk_traversable::IndexMap = - brk_traversable::IndexMap::new(); - }, - quote! { #(#flatten_entries)* }, - ) - } else { - // Both normal and flatten fields - use merge_entry for normal fields - ( - quote! { - let mut collected: brk_traversable::IndexMap = - brk_traversable::IndexMap::new(); - for entry in [#(#normal_entries,)*].into_iter().flatten() { - brk_traversable::TreeNode::merge_node(&mut collected, entry.0, entry.1) - .expect("Conflicting values for same key"); - } - }, - quote! { #(#flatten_entries)* }, - ) + let init_collected = quote! { + let mut collected: brk_traversable::IndexMap = + brk_traversable::IndexMap::new(); }; + let normal_insert = if !normal_entries.is_empty() { + quote! { + for entry in [#(#normal_entries,)*].into_iter().flatten() { + brk_traversable::TreeNode::merge_node(&mut collected, entry.0, entry.1) + .expect("Conflicting values for same key"); + } + } + } else { + quote! {} + }; + + let flatten_insert = quote! { #(#flatten_entries)* }; + quote! { #init_collected - #extend_flatten + #normal_insert + #flatten_insert #final_expr } } @@ -518,93 +557,19 @@ fn generate_iterator_impl(infos: &[FieldInfo]) -> proc_macro2::TokenStream { } } -fn build_where_clause( - generics: &syn::Generics, - generics_needing_traversable: &[&syn::Ident], - extra_traversable_types: &[&syn::Type], -) -> proc_macro2::TokenStream { - let generic_params: Vec<_> = generics.type_params().map(|p| &p.ident).collect(); - let original_predicates = generics.where_clause.as_ref().map(|w| &w.predicates); - - if generics_needing_traversable.is_empty() - && extra_traversable_types.is_empty() - && generic_params.is_empty() - && original_predicates.is_none() - { - return quote! {}; - } - - quote! { - where - #(#generics_needing_traversable: brk_traversable::Traversable,)* - #(#extra_traversable_types: brk_traversable::Traversable,)* - #(#generic_params: Send + Sync,)* - #original_predicates - } -} - -// --------------------------------------------------------------------------- -// ReadOnlyClone + Clone generation -// --------------------------------------------------------------------------- - -/// Find the generic type parameter bounded by `StorageMode`, if any. -fn find_storage_mode_param(generics: &syn::Generics) -> Option<&syn::Ident> { - generics.type_params().find_map(|p| { - p.bounds - .iter() - .any(|b| { - matches!(b, syn::TypeParamBound::Trait(t) - if t.path.segments.last().is_some_and(|s| s.ident == "StorageMode")) - }) - .then_some(&p.ident) - }) -} - -/// Check if a type AST references the given identifier anywhere. -fn type_contains_ident(ty: &Type, ident: &syn::Ident) -> bool { - match ty { - Type::Path(type_path) => { - // Check qualified self (e.g. ::Stored) - if let Some(qself) = &type_path.qself - && type_contains_ident(&qself.ty, ident) - { - return true; - } - type_path.path.segments.iter().any(|seg| { - if seg.ident == *ident { - return true; - } - match &seg.arguments { - syn::PathArguments::AngleBracketed(args) => args.args.iter().any(|arg| { - matches!(arg, syn::GenericArgument::Type(inner) if type_contains_ident(inner, ident)) - }), - syn::PathArguments::Parenthesized(args) => { - args.inputs.iter().any(|inner| type_contains_ident(inner, ident)) - || matches!(&args.output, syn::ReturnType::Type(_, inner) if type_contains_ident(inner, ident)) - } - syn::PathArguments::None => false, - } - }) - } - Type::Reference(r) => type_contains_ident(&r.elem, ident), - Type::Tuple(t) => t.elems.iter().any(|e| type_contains_ident(e, ident)), - Type::Array(a) => type_contains_ident(&a.elem, ident), - Type::Slice(s) => type_contains_ident(&s.elem, ident), - Type::Paren(p) => type_contains_ident(&p.elem, ident), - _ => false, - } -} +// =========================================================================== +// ReadOnlyClone generation +// =========================================================================== /// Generate `ReadOnlyClone` for Traversable-derived types. /// -/// - Types with `M: StorageMode` → maps `Self` → `Self`. -/// - Types with other generic type params (no M) → propagates `ReadOnlyClone` through each param. -/// - Types with no generic type params → nothing generated (they should `#[derive(Clone)]`). +/// Three paths: +/// 1. `M: StorageMode` → concrete impl mapping `Self` → `Self`. +/// 2. Generic container params → propagates `ReadOnlyClone` through each param. +/// 3. No container params → nothing generated. /// -/// Container params (mapped through ReadOnlyClone) are identified by: -/// - Unbounded type params (no inline or where-clause bounds), OR -/// - Bounded params that appear as a bare field type in a non-skipped field -/// (e.g. `metrics: M` where M is the param itself). +/// Container params are: unbounded type params, OR bounded params that appear +/// as a bare field type (e.g. `field: M` where M is the param itself). fn gen_read_only_clone(input: &DeriveInput) -> proc_macro2::TokenStream { let generics = &input.generics; let name = &input.ident; @@ -613,11 +578,12 @@ fn gen_read_only_clone(input: &DeriveInput) -> proc_macro2::TokenStream { return quote! {}; }; + // Path 1: StorageMode param → Rw/Ro substitution. if let Some(mode_param) = find_storage_mode_param(generics) { - return gen_read_only_clone_for_m(name, generics, data, mode_param); + return gen_read_only_clone_storage_mode(name, generics, data, mode_param); } - // Collect all generic type params. + // Path 2/3: classify type params as containers or leaves. let type_params: Vec<&syn::TypeParam> = generics .params .iter() @@ -631,53 +597,36 @@ fn gen_read_only_clone(input: &DeriveInput) -> proc_macro2::TokenStream { return quote! {}; } - // Determine which type params have bounds (inline or via where clause). - let where_bounded: Vec<&syn::Ident> = if let Some(where_clause) = &generics.where_clause { - where_clause - .predicates - .iter() - .filter_map(|pred| { - if let syn::WherePredicate::Type(pt) = pred - && let Type::Path(tp) = &pt.bounded_ty - && let Some(seg) = tp.path.segments.first() - { - type_params - .iter() - .find(|p| p.ident == seg.ident) - .map(|p| &p.ident) - } else { - None - } - }) - .collect() - } else { - Vec::new() + let is_bounded = |tp: &syn::TypeParam| -> bool { + if !tp.bounds.is_empty() { + return true; + } + if let Some(wc) = &generics.where_clause { + return wc.predicates.iter().any(|pred| { + matches!(pred, syn::WherePredicate::Type(pt) + if matches!(&pt.bounded_ty, Type::Path(p) + if p.path.segments.first().is_some_and(|s| s.ident == tp.ident))) + }); + } + false }; - // Find params that appear as bare (direct) field types in non-skipped fields. let bare_field_params = find_bare_field_params(data, &type_params); - // Container params: unbounded OR bare-field params. let container_params: Vec<&syn::Ident> = type_params .iter() - .filter(|tp| { - let is_unbounded = tp.bounds.is_empty() && !where_bounded.contains(&&tp.ident); - let is_bare = bare_field_params.contains(&&tp.ident); - is_unbounded || is_bare - }) + .filter(|tp| !is_bounded(tp) || bare_field_params.contains(&&tp.ident)) .map(|tp| &tp.ident) .collect(); - // If no container params, this is a pure leaf type — skip. if container_params.is_empty() { return quote! {}; } - gen_read_only_clone_for_generics(name, generics, data, &type_params, &container_params) + gen_read_only_clone_generics(name, generics, data, &type_params, &container_params) } -/// Find type params that appear as bare (direct) field types in non-skipped fields. -/// E.g. `metrics: M` where M is a type param → M is a bare field param. +/// Find type params used as bare (direct) field types in non-skipped fields. fn find_bare_field_params<'a>( data: &syn::DataStruct, type_params: &[&'a syn::TypeParam], @@ -706,14 +655,107 @@ fn find_bare_field_params<'a>( bare } -/// Generate `ReadOnlyClone` for types with `M: StorageMode`. -fn gen_read_only_clone_for_m( +// --------------------------------------------------------------------------- +// Shared field-conversion helpers +// --------------------------------------------------------------------------- + +/// Generate the value expression for a single field in a ReadOnlyClone impl. +/// +/// - Skipped + Option → `None` +/// - Skipped + non-Option → `Default::default()` +/// - Contains relevant param + Box → `Box::new(read_only_clone(&*self.field))` +/// - Contains relevant param → `read_only_clone(&self.field)` +/// - Otherwise → `self.field.clone()` +fn gen_roc_field_value( + field: &syn::Field, + self_access: proc_macro2::TokenStream, + is_relevant: impl Fn(&Type) -> bool, +) -> proc_macro2::TokenStream { + if is_field_skipped(field) { + if is_option_type(&field.ty) { + return quote! { None }; + } + return quote! { #self_access.clone() }; + } + + if is_relevant(&field.ty) { + if is_box_type(&field.ty) { + quote! { Box::new(vecdb::ReadOnlyClone::read_only_clone(&*#self_access)) } + } else { + quote! { vecdb::ReadOnlyClone::read_only_clone(&#self_access) } + } + } else { + quote! { #self_access.clone() } + } +} + +/// Generate the struct body for a ReadOnlyClone impl. +fn gen_roc_body( + name: &syn::Ident, + data: &syn::DataStruct, + is_relevant: impl Fn(&Type) -> bool, +) -> proc_macro2::TokenStream { + match &data.fields { + Fields::Named(named) => { + let conversions: Vec<_> = named + .named + .iter() + .map(|f| { + let field_name = f.ident.as_ref().unwrap(); + let value = gen_roc_field_value(f, quote! { self.#field_name }, &is_relevant); + quote! { #field_name: #value } + }) + .collect(); + quote! { #name { #(#conversions,)* } } + } + Fields::Unnamed(unnamed) => { + let conversions: Vec<_> = unnamed + .unnamed + .iter() + .enumerate() + .map(|(i, f)| { + let idx = syn::Index::from(i); + gen_roc_field_value(f, quote! { self.#idx }, &is_relevant) + }) + .collect(); + quote! { #name(#(#conversions,)*) } + } + Fields::Unit => quote! { #name }, + } +} + +/// Collect type args from generics, applying a mapping function to each. +fn collect_ty_args( + generics: &syn::Generics, + map_type: impl Fn(&syn::TypeParam) -> proc_macro2::TokenStream, +) -> Vec { + generics + .params + .iter() + .map(|p| match p { + syn::GenericParam::Type(tp) => map_type(tp), + syn::GenericParam::Lifetime(lt) => { + let lt = <.lifetime; + quote! { #lt } + } + syn::GenericParam::Const(c) => { + let id = &c.ident; + quote! { #id } + } + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Path 1: StorageMode → Rw/Ro substitution +// --------------------------------------------------------------------------- + +fn gen_read_only_clone_storage_mode( name: &syn::Ident, generics: &syn::Generics, data: &syn::DataStruct, mode_param: &syn::Ident, ) -> proc_macro2::TokenStream { - // Impl generics: all params except M, with bounds but without defaults. let impl_params: Vec = generics .params .iter() @@ -737,80 +779,22 @@ fn gen_read_only_clone_for_m( }) .collect(); - // Type args with M replaced by Rw / Ro. - let make_ty_args = |replacement: proc_macro2::TokenStream| -> Vec { - generics - .params - .iter() - .map(|p| match p { - syn::GenericParam::Type(tp) if tp.ident == *mode_param => replacement.clone(), - syn::GenericParam::Type(tp) => { - let id = &tp.ident; - quote! { #id } - } - syn::GenericParam::Lifetime(lt) => { - let lt = <.lifetime; - quote! { #lt } - } - syn::GenericParam::Const(c) => { - let id = &c.ident; - quote! { #id } - } - }) - .collect() + let make_ty_args = |replacement: proc_macro2::TokenStream| { + collect_ty_args(generics, |tp| { + if tp.ident == *mode_param { + replacement.clone() + } else { + let id = &tp.ident; + quote! { #id } + } + }) }; let ty_args_rw = make_ty_args(quote! { vecdb::Rw }); let ty_args_ro = make_ty_args(quote! { vecdb::Ro }); - let where_clause = &generics.where_clause; - let body = match &data.fields { - Fields::Named(named) => { - let field_conversions: Vec<_> = named - .named - .iter() - .map(|f| { - let field_name = f.ident.as_ref().unwrap(); - if is_field_skipped(f) && is_option_type(&f.ty) { - quote! { #field_name: None } - } else if type_contains_ident(&f.ty, mode_param) { - if is_box_type(&f.ty) { - quote! { #field_name: Box::new(vecdb::ReadOnlyClone::read_only_clone(&*self.#field_name)) } - } else { - quote! { #field_name: vecdb::ReadOnlyClone::read_only_clone(&self.#field_name) } - } - } else { - quote! { #field_name: self.#field_name.clone() } - } - }) - .collect(); - quote! { #name { #(#field_conversions,)* } } - } - Fields::Unnamed(unnamed) => { - let field_conversions: Vec<_> = unnamed - .unnamed - .iter() - .enumerate() - .map(|(i, f)| { - let idx = syn::Index::from(i); - if is_field_skipped(f) && is_option_type(&f.ty) { - quote! { None } - } else if type_contains_ident(&f.ty, mode_param) { - if is_box_type(&f.ty) { - quote! { Box::new(vecdb::ReadOnlyClone::read_only_clone(&*self.#idx)) } - } else { - quote! { vecdb::ReadOnlyClone::read_only_clone(&self.#idx) } - } - } else { - quote! { self.#idx.clone() } - } - }) - .collect(); - quote! { #name(#(#field_conversions,)*) } - } - Fields::Unit => quote! { #name }, - }; + let body = gen_roc_body(name, data, |ty| type_contains_ident(ty, mode_param)); let impl_generics = if impl_params.is_empty() { quote! {} @@ -829,31 +813,18 @@ fn gen_read_only_clone_for_m( } } -/// Check if a field has `#[traversable(skip)]`. -fn is_field_skipped(field: &syn::Field) -> bool { - field.attrs.iter().any(|attr| { - attr.path().is_ident("traversable") - && attr.parse_args::().is_ok_and(|id| id == "skip") - }) -} +// --------------------------------------------------------------------------- +// Path 2: Generic container params → ReadOnlyClone propagation +// --------------------------------------------------------------------------- -/// Generate `ReadOnlyClone` for types with generic type params but no `M: StorageMode`. -/// -/// `container_params` are type params that get `ReadOnlyClone` bounds and are -/// mapped to `T::ReadOnly` in the target type. -/// Leaf type params are kept as-is — they don't change across storage modes. -/// Fields containing container params use `.read_only_clone()`, others use `.clone()`. -/// -/// For bounded container params, the original bounds are preserved and propagated -/// to the ReadOnly version via where clause (e.g. `M::ReadOnly: CohortMetricsState`). -fn gen_read_only_clone_for_generics( +fn gen_read_only_clone_generics( name: &syn::Ident, generics: &syn::Generics, data: &syn::DataStruct, type_params: &[&syn::TypeParam], container_params: &[&syn::Ident], ) -> proc_macro2::TokenStream { - // Check if any non-skipped field references a container param (otherwise skip). + // Check if any non-skipped field actually uses a container param. let has_container_field = match &data.fields { Fields::Named(named) => named.named.iter().any(|f| { !is_field_skipped(f) @@ -876,8 +847,7 @@ fn gen_read_only_clone_for_generics( let is_container = |ident: &syn::Ident| container_params.iter().any(|cp| *cp == ident); - // Impl generics: add ReadOnlyClone bound to container params, keep bounds for leaf params. - // For bounded container params, preserve original bounds alongside ReadOnlyClone. + // Impl params: containers get ReadOnlyClone (+ original bounds), others keep their bounds. let impl_params: Vec = generics .params .iter() @@ -906,55 +876,23 @@ fn gen_read_only_clone_for_generics( }) .collect(); - // Self type args (just the param names). - let self_ty_args: Vec = generics - .params - .iter() - .map(|p| match p { - syn::GenericParam::Type(tp) => { - let id = &tp.ident; - quote! { #id } - } - syn::GenericParam::Lifetime(lt) => { - let lt = <.lifetime; - quote! { #lt } - } - syn::GenericParam::Const(c) => { - let id = &c.ident; - quote! { #id } - } - }) - .collect(); + let self_ty_args = collect_ty_args(generics, |tp| { + let id = &tp.ident; + quote! { #id } + }); - // ReadOnly type args: map container params to ReadOnly, keep leaf params as-is. - let ro_ty_args: Vec = generics - .params - .iter() - .map(|p| match p { - syn::GenericParam::Type(tp) => { - let id = &tp.ident; - if is_container(id) { - quote! { <#id as vecdb::ReadOnlyClone>::ReadOnly } - } else { - quote! { #id } - } - } - syn::GenericParam::Lifetime(lt) => { - let lt = <.lifetime; - quote! { #lt } - } - syn::GenericParam::Const(c) => { - let id = &c.ident; - quote! { #id } - } - }) - .collect(); + let ro_ty_args = collect_ty_args(generics, |tp| { + let id = &tp.ident; + if is_container(id) { + quote! { <#id as vecdb::ReadOnlyClone>::ReadOnly } + } else { + quote! { #id } + } + }); - // Build where clause: propagate bounds from bounded container params to their ReadOnly. - // E.g. `M: Trait` → add `::ReadOnly: Trait`. + // Where clause: propagate bounds from bounded container params to their ReadOnly. let mut extra_where: Vec = Vec::new(); - // Propagate inline bounds. for tp in type_params { if is_container(&tp.ident) && !tp.bounds.is_empty() { let ident = &tp.ident; @@ -965,7 +903,6 @@ fn gen_read_only_clone_for_generics( } } - // Propagate where-clause bounds for container params. if let Some(wc) = &generics.where_clause { for pred in &wc.predicates { if let syn::WherePredicate::Type(pt) = pred @@ -982,67 +919,18 @@ fn gen_read_only_clone_for_generics( } } - let original_predicates = generics - .where_clause - .as_ref() - .map(|w| &w.predicates); - + let original_predicates = generics.where_clause.as_ref().map(|w| &w.predicates); let combined_where = if extra_where.is_empty() && original_predicates.is_none() { quote! {} } else { quote! { where #(#extra_where,)* #original_predicates } }; - // Field-level: if field type contains any container param → read_only_clone, else → clone. - let field_contains_container_param = - |ty: &Type| container_params.iter().any(|tp| type_contains_ident(ty, tp)); - - let body = match &data.fields { - Fields::Named(named) => { - let field_conversions: Vec<_> = named - .named - .iter() - .map(|f| { - let field_name = f.ident.as_ref().unwrap(); - if is_field_skipped(f) { - quote! { #field_name: Default::default() } - } else if field_contains_container_param(&f.ty) { - if is_box_type(&f.ty) { - quote! { #field_name: Box::new(vecdb::ReadOnlyClone::read_only_clone(&*self.#field_name)) } - } else { - quote! { #field_name: vecdb::ReadOnlyClone::read_only_clone(&self.#field_name) } - } - } else { - quote! { #field_name: self.#field_name.clone() } - } - }) - .collect(); - quote! { #name { #(#field_conversions,)* } } - } - Fields::Unnamed(unnamed) => { - let field_conversions: Vec<_> = unnamed - .unnamed - .iter() - .enumerate() - .map(|(i, f)| { - let idx = syn::Index::from(i); - if is_field_skipped(f) { - quote! { Default::default() } - } else if field_contains_container_param(&f.ty) { - if is_box_type(&f.ty) { - quote! { Box::new(vecdb::ReadOnlyClone::read_only_clone(&*self.#idx)) } - } else { - quote! { vecdb::ReadOnlyClone::read_only_clone(&self.#idx) } - } - } else { - quote! { self.#idx.clone() } - } - }) - .collect(); - quote! { #name(#(#field_conversions,)*) } - } - Fields::Unit => quote! { #name }, - }; + let body = gen_roc_body(name, data, |ty| { + container_params + .iter() + .any(|tp| type_contains_ident(ty, tp)) + }); quote! { impl<#(#impl_params),*> vecdb::ReadOnlyClone for #name<#(#self_ty_args),*> #combined_where {