| """ | |
| Chronos 2 Zero-Shot Inference Pipeline | |
| Handles: | |
| 1. Loading Chronos 2 Large model (710M params) | |
| 2. Running zero-shot inference using predict_df() API | |
| 3. GPU/CPU device mapping | |
| 4. Saving predictions to parquet | |
| """ | |
| from pathlib import Path | |
| from typing import Optional, Dict, List | |
| import pandas as pd | |
| import torch | |
| from datetime import datetime | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
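
# Note: load_model() imports Chronos2Pipeline lazily; that import is assumed to
# be provided by the `chronos-forecasting` package (an assumption, not verified
# here), so make sure the package is installed before running this module.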


class ChronosForecaster:
    """
    Zero-shot forecaster using the Chronos 2 Large model.

    Features:
    - Multivariate forecasting (multiple borders simultaneously)
    - Covariate support (615 future covariates)
    - Large context window (up to 8,192 hours)
    - DataFrame API for easy data handling
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2-large",
        device: str = "auto",
        torch_dtype: str = "float32",
    ):
        """
        Initialize the Chronos 2 forecaster.

        Args:
            model_name: HuggingFace model name (default: chronos-2-large)
            device: Device to run on ('auto', 'cuda', 'cpu')
            torch_dtype: Torch dtype ('float32', 'float16', 'bfloat16')
        """
        self.model_name = model_name
        self.device = self._resolve_device(device)
        self.torch_dtype = self._resolve_dtype(torch_dtype)
        self.pipeline = None

        logger.info("ChronosForecaster initialized:")
        logger.info(f"  Model: {model_name}")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Dtype: {self.torch_dtype}")

    def _resolve_device(self, device: str) -> str:
        """Resolve a device string to an actual device."""
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def _resolve_dtype(self, dtype_str: str) -> torch.dtype:
        """Resolve a dtype string to a torch dtype."""
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        return dtype_map.get(dtype_str, torch.float32)

    def load_model(self):
        """Load the Chronos 2 model from HuggingFace."""
        if self.pipeline is not None:
            logger.info("Model already loaded")
            return

        logger.info(f"Loading {self.model_name}...")
        logger.info("This may take a few minutes on first load...")

        try:
            from chronos import Chronos2Pipeline

            # Use device_map for GPU placement; CPU is handled below
            self.pipeline = Chronos2Pipeline.from_pretrained(
                self.model_name,
                device_map=self.device if self.device == "cuda" else None,
                torch_dtype=self.torch_dtype,
            )

            # Move to device explicitly when not using device_map
            if self.device == "cpu":
                self.pipeline = self.pipeline.to(self.device)

            logger.info(f"Model loaded successfully on {self.device}")

            # Log GPU info if available
            if self.device == "cuda":
                gpu_name = torch.cuda.get_device_name(0)
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
                logger.info(f"GPU: {gpu_name} ({gpu_memory:.1f} GB VRAM)")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
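
    # Back-of-envelope weight memory for the 710M-parameter model (an estimate,
    # excluding activations and framework overhead):
    #   float32:          710e6 params * 4 bytes ≈ 2.8 GB
    #   float16/bfloat16: 710e6 params * 2 bytes ≈ 1.4 GB
    # so halving the dtype roughly halves the VRAM needed for the weights.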

    def predict(
        self,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336,
        id_column: str = "border",
        timestamp_column: str = "timestamp",
        num_samples: int = 100,
    ) -> pd.DataFrame:
        """
        Run zero-shot inference using Chronos 2.

        Args:
            context_df: Historical data (timestamp, border, target, features)
            future_df: Future covariates (timestamp, border, future covariates)
            prediction_length: Number of hours to forecast
            id_column: Column name for the border ID
            timestamp_column: Column name for the timestamp
            num_samples: Number of samples for the probabilistic forecast

        Returns:
            forecasts_df: DataFrame with predictions (timestamp, border, mean, median, q10, q90)
        """
        if self.pipeline is None:
            self.load_model()

        logger.info("Running zero-shot inference...")
        logger.info(f"Context shape: {context_df.shape}")
        logger.info(f"Future shape: {future_df.shape}")
        logger.info(f"Prediction length: {prediction_length} hours")
        logger.info(f"Borders: {context_df[id_column].nunique()}")

        try:
            # Run inference
            forecasts = self.pipeline.predict_df(
                context_df=context_df,
                future_df=future_df,
                prediction_length=prediction_length,
                id_column=id_column,
                timestamp_column=timestamp_column,
                num_samples=num_samples,
            )
            logger.info(f"Inference complete! Forecast shape: {forecasts.shape}")

            # Add metadata
            forecasts['forecast_date'] = context_df[timestamp_column].max()
            forecasts['model'] = self.model_name
            return forecasts
        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise
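
    # Illustrative long-format input layout (inferred from the docstring above,
    # not a schema enforced by this module): one row per (timestamp, border).
    #   context_df: timestamp | border | target | feature columns
    #   future_df:  timestamp | border | future covariate columns (no target)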

    def predict_single_border(
        self,
        border: str,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336,
        num_samples: int = 100,
    ) -> pd.DataFrame:
        """
        Run inference for a single border (useful for testing).

        Args:
            border: Border name (e.g., 'AT_CZ')
            context_df: Historical data
            future_df: Future covariates
            prediction_length: Hours to forecast
            num_samples: Samples for the probabilistic forecast

        Returns:
            forecasts_df: Predictions for the single border
        """
        logger.info(f"Running inference for border: {border}")

        # Filter for the single border
        context_border = context_df[context_df['border'] == border].copy()
        future_border = future_df[future_df['border'] == border].copy()

        return self.predict(
            context_df=context_border,
            future_df=future_border,
            prediction_length=prediction_length,
            num_samples=num_samples,
        )
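
    # Example call (border name taken from the docstring above):
    #   forecaster.predict_single_border("AT_CZ", context_df, future_df)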

    def save_forecasts(
        self,
        forecasts: pd.DataFrame,
        output_path: str,
        include_metadata: bool = True,
    ):
        """
        Save forecasts to a parquet file.

        Args:
            forecasts: Forecast DataFrame
            output_path: Path to the output parquet file
            include_metadata: Include model metadata
        """
        logger.info(f"Saving forecasts to: {output_path}")

        # Create the output directory if needed
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Add metadata without mutating the caller's DataFrame
        if include_metadata:
            forecasts = forecasts.copy()
            forecasts['saved_at'] = datetime.now()

        # Save to parquet
        forecasts.to_parquet(output_path, index=False)
        logger.info(f"Saved {len(forecasts)} rows to {output_path}")

    def benchmark_inference(
        self,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336,
    ) -> Dict[str, float]:
        """
        Benchmark inference speed and memory usage.

        Args:
            context_df: Historical data
            future_df: Future covariates
            prediction_length: Hours to forecast

        Returns:
            metrics: Dict with inference_time_sec, gpu_memory_mb
        """
        import time

        logger.info("Benchmarking inference performance...")

        # Record the start time (perf_counter is monotonic, so it is better
        # suited to interval timing than time.time) and reset GPU memory stats
        start_time = time.perf_counter()
        if self.device == "cuda":
            torch.cuda.reset_peak_memory_stats()

        # Run inference
        _ = self.predict(
            context_df=context_df,
            future_df=future_df,
            prediction_length=prediction_length,
        )

        # Record the elapsed time and peak memory
        inference_time = time.perf_counter() - start_time

        metrics = {
            'inference_time_sec': inference_time,
            'borders': context_df['border'].nunique(),
            'prediction_length': prediction_length,
        }
        if self.device == "cuda":
            peak_memory = torch.cuda.max_memory_allocated() / 1e6  # MB
            metrics['gpu_memory_mb'] = peak_memory

        logger.info(f"Inference time: {inference_time:.2f}s")
        if 'gpu_memory_mb' in metrics:
            logger.info(f"Peak GPU memory: {metrics['gpu_memory_mb']:.1f} MB")
        return metrics
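

# Minimal usage sketch (illustrative only): the parquet paths below are
# hypothetical placeholders, not files shipped with this module.
if __name__ == "__main__":
    forecaster = ChronosForecaster(device="auto", torch_dtype="bfloat16")

    # Hypothetical inputs; replace with your own prepared frames.
    context_df = pd.read_parquet("data/context.parquet")  # timestamp, border, target, features
    future_df = pd.read_parquet("data/future.parquet")    # timestamp, border, future covariates

    forecasts = forecaster.predict(
        context_df=context_df,
        future_df=future_df,
        prediction_length=336,  # 14 days of hourly forecasts
    )
    forecaster.save_forecasts(forecasts, "outputs/forecasts.parquet")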