//! Python base client and pattern factory generation. use std::fmt::Write; use crate::{ ClientConstants, ClientMetadata, CohortConstants, IndexSetPattern, PythonSyntax, StructuralPattern, format_json, generate_parameterized_field, index_to_field_name, }; /// Generate class-level constants for the BrkClient class. pub fn generate_class_constants(output: &mut String) { let constants = ClientConstants::collect(); // VERSION writeln!(output, " VERSION = \"{}\"\n", constants.version).unwrap(); // INDEXES, POOL_ID_TO_POOL_NAME write_class_const(output, "INDEXES", &format_json(&constants.indexes)); // Python needs string keys for pool map let pool_map: std::collections::BTreeMap = constants .pool_map .iter() .map(|(k, v)| (k.to_string(), *v)) .collect(); write_class_const(output, "POOL_ID_TO_POOL_NAME", &format_json(&pool_map)); // Cohort constants (no camelCase conversion for Python) for (name, value) in CohortConstants::all() { write_class_const(output, name, &format_json(&value)); } } fn write_class_const(output: &mut String, name: &str, json: &str) { let indented = json .lines() .enumerate() .map(|(i, line)| { if i == 0 { format!(" {} = {}", name, line) } else { format!(" {}", line) } }) .collect::>() .join("\n"); writeln!(output, "{}\n", indented).unwrap(); } /// Generate the base BrkClient class with HTTP functionality pub fn generate_base_client(output: &mut String) { writeln!( output, r#"class BrkError(Exception): """Custom error class for BRK client errors.""" def __init__(self, message: str, status: Optional[int] = None): super().__init__(message) self.status = status class BrkClientBase: """Base HTTP client for making requests.""" def __init__(self, base_url: str, timeout: float = 30.0): parsed = urlparse(base_url) self._host = parsed.netloc self._secure = parsed.scheme == 'https' self._timeout = timeout self._conn: Optional[Union[HTTPSConnection, HTTPConnection]] = None def _connect(self) -> Union[HTTPSConnection, HTTPConnection]: """Get or create HTTP connection.""" if self._conn is None: if self._secure: self._conn = HTTPSConnection(self._host, timeout=self._timeout) else: self._conn = HTTPConnection(self._host, timeout=self._timeout) return self._conn def get(self, path: str) -> bytes: """Make a GET request and return raw bytes.""" try: conn = self._connect() conn.request("GET", path) res = conn.getresponse() data = res.read() if res.status >= 400: raise BrkError(f"HTTP error: {{res.status}}", res.status) return data except (ConnectionError, OSError, TimeoutError) as e: self._conn = None raise BrkError(str(e)) def get_json(self, path: str) -> Any: """Make a GET request and return JSON.""" return json.loads(self.get(path)) def get_text(self, path: str) -> str: """Make a GET request and return text.""" return self.get(path).decode() def close(self) -> None: """Close the HTTP client.""" if self._conn: self._conn.close() self._conn = None def __enter__(self) -> BrkClientBase: return self def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None: self.close() def _m(acc: str, s: str) -> str: """Build series name with suffix.""" if not s: return acc return f"{{acc}}_{{s}}" if acc else s def _p(prefix: str, acc: str) -> str: """Build series name with prefix.""" return f"{{prefix}}_{{acc}}" if acc else prefix "# ) .unwrap(); } /// Generate the SeriesData and SeriesEndpoint classes pub fn generate_endpoint_class(output: &mut String) { writeln!( output, r#"# Date conversion constants _GENESIS = date(2009, 1, 3) # day1 0, week1 0 _DAY_ONE = date(2009, 1, 9) # day1 1 (6 day gap after genesis) _EPOCH = datetime(2009, 1, 1, tzinfo=timezone.utc) _DATE_INDEXES = frozenset([ 'minute10', 'minute30', 'hour1', 'hour4', 'hour12', 'day1', 'day3', 'week1', 'month1', 'month3', 'month6', 'year1', 'year10', ]) def _index_to_date(index: str, i: int) -> Union[date, datetime]: """Convert an index value to a date/datetime for date-based indexes.""" if index == 'minute10': return _EPOCH + timedelta(minutes=i * 10) elif index == 'minute30': return _EPOCH + timedelta(minutes=i * 30) elif index == 'hour1': return _EPOCH + timedelta(hours=i) elif index == 'hour4': return _EPOCH + timedelta(hours=i * 4) elif index == 'hour12': return _EPOCH + timedelta(hours=i * 12) elif index == 'day1': return _GENESIS if i == 0 else _DAY_ONE + timedelta(days=i - 1) elif index == 'day3': return _EPOCH.date() - timedelta(days=1) + timedelta(days=i * 3) elif index == 'week1': return _GENESIS + timedelta(weeks=i) elif index == 'month1': return date(2009 + i // 12, i % 12 + 1, 1) elif index == 'month3': m = i * 3 return date(2009 + m // 12, m % 12 + 1, 1) elif index == 'month6': m = i * 6 return date(2009 + m // 12, m % 12 + 1, 1) elif index == 'year1': return date(2009 + i, 1, 1) elif index == 'year10': return date(2009 + i * 10, 1, 1) else: raise ValueError(f"{{index}} is not a date-based index") def _date_to_index(index: str, d: Union[date, datetime]) -> int: """Convert a date/datetime to an index value for date-based indexes. Returns the floor index (latest index whose date is <= the given date). For sub-day indexes (minute*, hour*), a plain date is treated as midnight UTC. """ if index in ('minute10', 'minute30', 'hour1', 'hour4', 'hour12'): if isinstance(d, datetime): dt = d if d.tzinfo else d.replace(tzinfo=timezone.utc) else: dt = datetime(d.year, d.month, d.day, tzinfo=timezone.utc) secs = int((dt - _EPOCH).total_seconds()) div = {{'minute10': 600, 'minute30': 1800, 'hour1': 3600, 'hour4': 14400, 'hour12': 43200}} return secs // div[index] dd = d.date() if isinstance(d, datetime) else d if index == 'day1': if dd < _DAY_ONE: return 0 return 1 + (dd - _DAY_ONE).days elif index == 'day3': return (dd - date(2008, 12, 31)).days // 3 elif index == 'week1': return (dd - _GENESIS).days // 7 elif index == 'month1': return (dd.year - 2009) * 12 + (dd.month - 1) elif index == 'month3': return (dd.year - 2009) * 4 + (dd.month - 1) // 3 elif index == 'month6': return (dd.year - 2009) * 2 + (dd.month - 1) // 6 elif index == 'year1': return dd.year - 2009 elif index == 'year10': return (dd.year - 2009) // 10 else: raise ValueError(f"{{index}} is not a date-based index") @dataclass class SeriesData(Generic[T]): """Series data with range information. Always int-indexed.""" version: int index: Index type: str total: int start: int end: int stamp: str data: List[T] @property def is_date_based(self) -> bool: """Whether this series uses a date-based index.""" return self.index in _DATE_INDEXES def indexes(self) -> List[int]: """Get raw index numbers.""" return list(range(self.start, self.end)) def keys(self) -> List[int]: """Get keys as index numbers.""" return self.indexes() def items(self) -> List[Tuple[int, T]]: """Get (index, value) pairs.""" return list(zip(self.indexes(), self.data)) def to_dict(self) -> Dict[int, T]: """Return {{index: value}} dict.""" return dict(zip(self.indexes(), self.data)) def __iter__(self) -> Iterator[Tuple[int, T]]: """Iterate over (index, value) pairs.""" return iter(zip(self.indexes(), self.data)) def __len__(self) -> int: return len(self.data) def to_polars(self) -> pl.DataFrame: """Convert to Polars DataFrame with 'index' and 'value' columns.""" try: import polars as pl # type: ignore[import-not-found] except ImportError: raise ImportError("polars is required: pip install polars") return pl.DataFrame({{"index": self.indexes(), "value": self.data}}) def to_pandas(self) -> pd.DataFrame: """Convert to Pandas DataFrame with 'index' and 'value' columns.""" try: import pandas as pd # type: ignore[import-not-found] except ImportError: raise ImportError("pandas is required: pip install pandas") return pd.DataFrame({{"index": self.indexes(), "value": self.data}}) @dataclass class DateSeriesData(SeriesData[T]): """Series data with date-based index. Extends SeriesData with date methods.""" def dates(self) -> List[Union[date, datetime]]: """Get dates for the index range. Returns datetime for sub-daily indexes, date for daily+.""" return [_index_to_date(self.index, i) for i in range(self.start, self.end)] def date_items(self) -> List[Tuple[Union[date, datetime], T]]: """Get (date, value) pairs.""" return list(zip(self.dates(), self.data)) def to_date_dict(self) -> Dict[Union[date, datetime], T]: """Return {{date: value}} dict.""" return dict(zip(self.dates(), self.data)) def to_polars(self, with_dates: bool = True) -> pl.DataFrame: """Convert to Polars DataFrame. Returns a DataFrame with columns: - 'date' and 'value' if with_dates=True (default) - 'index' and 'value' otherwise """ try: import polars as pl # type: ignore[import-not-found] except ImportError: raise ImportError("polars is required: pip install polars") if with_dates: return pl.DataFrame({{"date": self.dates(), "value": self.data}}) return pl.DataFrame({{"index": self.indexes(), "value": self.data}}) def to_pandas(self, with_dates: bool = True) -> pd.DataFrame: """Convert to Pandas DataFrame. Returns a DataFrame with columns: - 'date' and 'value' if with_dates=True (default) - 'index' and 'value' otherwise """ try: import pandas as pd # type: ignore[import-not-found] except ImportError: raise ImportError("pandas is required: pip install pandas") if with_dates: return pd.DataFrame({{"date": self.dates(), "value": self.data}}) return pd.DataFrame({{"index": self.indexes(), "value": self.data}}) # Type aliases for non-generic usage AnySeriesData = SeriesData[Any] AnyDateSeriesData = DateSeriesData[Any] class _EndpointConfig: """Shared endpoint configuration.""" client: BrkClientBase name: str index: Index start: Optional[int] end: Optional[int] def __init__(self, client: BrkClientBase, name: str, index: Index, start: Optional[int] = None, end: Optional[int] = None): self.client = client self.name = name self.index = index self.start = start self.end = end def path(self) -> str: return f"/api/series/{{self.name}}/{{self.index}}" def _build_path(self, format: Optional[str] = None) -> str: params = [] if self.start is not None: params.append(f"start={{self.start}}") if self.end is not None: params.append(f"end={{self.end}}") if format is not None: params.append(f"format={{format}}") query = "&".join(params) p = self.path() return f"{{p}}?{{query}}" if query else p def _new(self, start: Optional[int] = None, end: Optional[int] = None) -> _EndpointConfig: return _EndpointConfig(self.client, self.name, self.index, start, end) def get_series(self) -> SeriesData[Any]: return SeriesData(**self.client.get_json(self._build_path())) def get_date_series(self) -> DateSeriesData[Any]: return DateSeriesData(**self.client.get_json(self._build_path())) def get_csv(self) -> str: return self.client.get_text(self._build_path(format='csv')) class RangeBuilder(Generic[T]): """Builder with range specified.""" def __init__(self, config: _EndpointConfig): self._config = config def fetch(self) -> SeriesData[T]: """Fetch the range as parsed JSON.""" return self._config.get_series() def fetch_csv(self) -> str: """Fetch the range as CSV string.""" return self._config.get_csv() class SingleItemBuilder(Generic[T]): """Builder for single item access.""" def __init__(self, config: _EndpointConfig): self._config = config def fetch(self) -> SeriesData[T]: """Fetch the single item.""" return self._config.get_series() def fetch_csv(self) -> str: """Fetch as CSV.""" return self._config.get_csv() class SkippedBuilder(Generic[T]): """Builder after calling skip(n). Chain with take() to specify count.""" def __init__(self, config: _EndpointConfig): self._config = config def take(self, n: int) -> RangeBuilder[T]: """Take n items after the skipped position.""" start = self._config.start or 0 return RangeBuilder(self._config._new(start, start + n)) def fetch(self) -> SeriesData[T]: """Fetch from skipped position to end.""" return self._config.get_series() def fetch_csv(self) -> str: """Fetch as CSV.""" return self._config.get_csv() class DateRangeBuilder(RangeBuilder[T]): """Range builder that returns DateSeriesData.""" def fetch(self) -> DateSeriesData[T]: return self._config.get_date_series() class DateSingleItemBuilder(SingleItemBuilder[T]): """Single item builder that returns DateSeriesData.""" def fetch(self) -> DateSeriesData[T]: return self._config.get_date_series() class DateSkippedBuilder(SkippedBuilder[T]): """Skipped builder that returns DateSeriesData.""" def take(self, n: int) -> DateRangeBuilder[T]: start = self._config.start or 0 return DateRangeBuilder(self._config._new(start, start + n)) def fetch(self) -> DateSeriesData[T]: return self._config.get_date_series() class SeriesEndpoint(Generic[T]): """Builder for series endpoint queries with int-based indexing. Examples: data = endpoint.fetch() data = endpoint[5].fetch() data = endpoint[:10].fetch() data = endpoint.head(20).fetch() data = endpoint.skip(100).take(10).fetch() """ def __init__(self, client: BrkClientBase, name: str, index: Index): self._config = _EndpointConfig(client, name, index) @overload def __getitem__(self, key: int) -> SingleItemBuilder[T]: ... @overload def __getitem__(self, key: slice) -> RangeBuilder[T]: ... def __getitem__(self, key: Union[int, slice]) -> Union[SingleItemBuilder[T], RangeBuilder[T]]: """Access single item or slice by integer index.""" if isinstance(key, int): return SingleItemBuilder(self._config._new(key, key + 1)) return RangeBuilder(self._config._new(key.start, key.stop)) def head(self, n: int = 10) -> RangeBuilder[T]: """Get the first n items.""" return RangeBuilder(self._config._new(end=n)) def tail(self, n: int = 10) -> RangeBuilder[T]: """Get the last n items.""" return RangeBuilder(self._config._new(end=0) if n == 0 else self._config._new(start=-n)) def skip(self, n: int) -> SkippedBuilder[T]: """Skip the first n items.""" return SkippedBuilder(self._config._new(start=n)) def fetch(self) -> SeriesData[T]: """Fetch all data.""" return self._config.get_series() def fetch_csv(self) -> str: """Fetch all data as CSV.""" return self._config.get_csv() def path(self) -> str: """Get the base endpoint path.""" return self._config.path() class DateSeriesEndpoint(Generic[T]): """Builder for series endpoint queries with date-based indexing. Accepts dates in __getitem__ and returns DateSeriesData from fetch(). Examples: data = endpoint.fetch() data = endpoint[date(2020, 1, 1)].fetch() data = endpoint[date(2020, 1, 1):date(2023, 1, 1)].fetch() data = endpoint[:10].fetch() """ def __init__(self, client: BrkClientBase, name: str, index: Index): self._config = _EndpointConfig(client, name, index) @overload def __getitem__(self, key: int) -> DateSingleItemBuilder[T]: ... @overload def __getitem__(self, key: datetime) -> DateSingleItemBuilder[T]: ... @overload def __getitem__(self, key: date) -> DateSingleItemBuilder[T]: ... @overload def __getitem__(self, key: slice) -> DateRangeBuilder[T]: ... def __getitem__(self, key: Union[int, slice, date, datetime]) -> Union[DateSingleItemBuilder[T], DateRangeBuilder[T]]: """Access single item or slice. Accepts int, date, or datetime.""" if isinstance(key, (date, datetime)): idx = _date_to_index(self._config.index, key) return DateSingleItemBuilder(self._config._new(idx, idx + 1)) if isinstance(key, int): return DateSingleItemBuilder(self._config._new(key, key + 1)) start, stop = key.start, key.stop if isinstance(start, (date, datetime)): start = _date_to_index(self._config.index, start) if isinstance(stop, (date, datetime)): stop = _date_to_index(self._config.index, stop) return DateRangeBuilder(self._config._new(start, stop)) def head(self, n: int = 10) -> DateRangeBuilder[T]: """Get the first n items.""" return DateRangeBuilder(self._config._new(end=n)) def tail(self, n: int = 10) -> DateRangeBuilder[T]: """Get the last n items.""" return DateRangeBuilder(self._config._new(end=0) if n == 0 else self._config._new(start=-n)) def skip(self, n: int) -> DateSkippedBuilder[T]: """Skip the first n items.""" return DateSkippedBuilder(self._config._new(start=n)) def fetch(self) -> DateSeriesData[T]: """Fetch all data.""" return self._config.get_date_series() def fetch_csv(self) -> str: """Fetch all data as CSV.""" return self._config.get_csv() def path(self) -> str: """Get the base endpoint path.""" return self._config.path() # Type aliases for non-generic usage AnySeriesEndpoint = SeriesEndpoint[Any] AnyDateSeriesEndpoint = DateSeriesEndpoint[Any] class SeriesPattern(Protocol[T]): """Protocol for series patterns with different index sets.""" @property def name(self) -> str: """Get the series name.""" ... def indexes(self) -> List[str]: """Get the list of available indexes for this series.""" ... def get(self, index: Index) -> Optional[SeriesEndpoint[T]]: """Get an endpoint builder for a specific index, if supported.""" ... "# ) .unwrap(); } /// Generate index accessor classes pub fn generate_index_accessors(output: &mut String, patterns: &[IndexSetPattern]) { if patterns.is_empty() { return; } // Generate static index tuples writeln!(output, "# Static index tuples").unwrap(); for (i, pattern) in patterns.iter().enumerate() { write!(output, "_i{} = (", i + 1).unwrap(); for (j, index) in pattern.indexes.iter().enumerate() { if j > 0 { write!(output, ", ").unwrap(); } write!(output, "'{}'", index.name()).unwrap(); } // Single-element tuple needs trailing comma if pattern.indexes.len() == 1 { write!(output, ",").unwrap(); } writeln!(output, ")").unwrap(); } writeln!(output).unwrap(); // Generate helper functions writeln!( output, r#"def _ep(c: BrkClientBase, n: str, i: Index) -> SeriesEndpoint[Any]: return SeriesEndpoint(c, n, i) def _dep(c: BrkClientBase, n: str, i: Index) -> DateSeriesEndpoint[Any]: return DateSeriesEndpoint(c, n, i) "# ) .unwrap(); writeln!(output, "# Index accessor classes\n").unwrap(); for (i, pattern) in patterns.iter().enumerate() { let by_class_name = format!("_{}By", pattern.name); let idx_var = format!("_i{}", i + 1); // Generate the By class with compact methods writeln!(output, "class {}(Generic[T]):", by_class_name).unwrap(); writeln!( output, " def __init__(self, c: BrkClientBase, n: str): self._c, self._n = c, n" ) .unwrap(); for index in &pattern.indexes { let method_name = index_to_field_name(index); let index_name = index.name(); let (builder_type, helper) = if index.is_date_based() { ("DateSeriesEndpoint", "_dep") } else { ("SeriesEndpoint", "_ep") }; writeln!( output, " def {}(self) -> {}[T]: return {}(self._c, self._n, '{}')", method_name, builder_type, helper, index_name ) .unwrap(); } writeln!(output).unwrap(); // Generate the main accessor class writeln!(output, "class {}(Generic[T]):", pattern.name).unwrap(); writeln!(output, " by: {}[T]", by_class_name).unwrap(); writeln!( output, " def __init__(self, c: BrkClientBase, n: str): self._n, self.by = n, {}(c, n)", by_class_name ) .unwrap(); writeln!(output, " @property").unwrap(); writeln!(output, " def name(self) -> str: return self._n").unwrap(); writeln!( output, " def indexes(self) -> List[str]: return list({})", idx_var ) .unwrap(); writeln!( output, " def get(self, index: Index) -> Optional[SeriesEndpoint[T]]: return _ep(self.by._c, self._n, index) if index in {} else None", idx_var ) .unwrap(); writeln!(output).unwrap(); } } /// Generate structural pattern classes pub fn generate_structural_patterns( output: &mut String, patterns: &[StructuralPattern], metadata: &ClientMetadata, ) { if patterns.is_empty() { return; } writeln!(output, "# Reusable structural pattern classes\n").unwrap(); for pattern in patterns { // Generate class if pattern.is_generic { writeln!(output, "class {}(Generic[T]):", pattern.name).unwrap(); } else { writeln!(output, "class {}:", pattern.name).unwrap(); } writeln!( output, " \"\"\"Pattern struct for repeated tree structure.\"\"\"" ) .unwrap(); // Skip constructor for non-parameterizable patterns (inlined at tree level) if !metadata.is_parameterizable(&pattern.name) { writeln!(output, " pass\n").unwrap(); continue; } writeln!(output, " ").unwrap(); if pattern.is_templated() { writeln!( output, " def __init__(self, client: BrkClientBase, acc: str, disc: str):" ) .unwrap(); } else { writeln!( output, " def __init__(self, client: BrkClientBase, acc: str):" ) .unwrap(); } writeln!( output, " \"\"\"Create pattern node with accumulated series name.\"\"\"" ) .unwrap(); let syntax = PythonSyntax; for field in &pattern.fields { generate_parameterized_field(output, &syntax, field, pattern, metadata, " "); } writeln!(output).unwrap(); } }