Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Smoke Test for Chronos 2 Zero-Shot Inference | |
| Tests: 1 border × 7 days (168 hours) | |
| """ | |
import sys
import time
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import polars as pl
import torch
from chronos import Chronos2Pipeline

from src.forecasting.dynamic_forecast import DynamicForecast
from src.forecasting.feature_availability import FeatureAvailability
# Console banner announcing the smoke-test run (three lines: rule, title, rule).
print("=" * 60, "CHRONOS 2 ZERO-SHOT INFERENCE - SMOKE TEST", "=" * 60, sep="\n")
# Step 1: Load the private feature dataset from the HuggingFace Hub into a
# polars DataFrame and make sure its `timestamp` column is a polars Datetime.
print("\n[1/6] Loading dataset from HuggingFace...")
start_time = time.time()  # also the reference point for the total-time report
from datasets import load_dataset
import os

# The dataset is private, so an access token must be present before loading.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN not found in environment. Please set HF_TOKEN.")

hf_split = load_dataset("evgueni-p/fbmc-features-24month", split="train", token=hf_token)
df = pl.from_pandas(hf_split.to_pandas())

# Normalise the timestamp column: parse strings, cast any other non-datetime
# dtype, and leave an existing Datetime column untouched.
timestamp_dtype = df["timestamp"].dtype
if timestamp_dtype == pl.String:
    timestamp_fix = pl.col("timestamp").str.to_datetime()
elif timestamp_dtype != pl.Datetime:
    timestamp_fix = pl.col("timestamp").cast(pl.Datetime)
else:
    timestamp_fix = None
if timestamp_fix is not None:
    df = df.with_columns(timestamp_fix)

print(f"[OK] Loaded {df.height} rows, {df.width} columns")
ts_first, ts_last = df["timestamp"].min(), df["timestamp"].max()
print(f" Date range: {ts_first} to {ts_last}")
print(f" Load time: {time.time() - start_time:.1f}s")
# Feature categorization using the FeatureAvailability module: group the
# dataset's columns into the availability categories full_horizon_d14,
# partial_d1 and historical, then report the counts.
print("\n[Feature Categorization]")
categories = FeatureAvailability.categorize_features(df.columns)
# Validate quietly (verbose=False); any issues are printed below instead.
categorization_ok, categorization_issues = FeatureAvailability.validate_categorization(
    categories, verbose=False
)

full_horizon = categories['full_horizon_d14']
partial_d1 = categories['partial_d1']
historical = categories['historical']

# Report categories
print(f" Full-horizon D+14: {len(full_horizon)} (temporal + weather + outages + LTA)")
print(f" Partial D+1: {len(partial_d1)} (load forecasts)")
print(f" Historical only: {len(historical)} (prices, generation, demand, lags, etc.)")
print(f" Total features: {sum(map(len, categories.values()))}")

if not categorization_ok:
    print("\n[!] WARNING: Feature categorization issues:")
    for issue in categorization_issues:
        print(f" - {issue}")

# For Chronos-2, full-horizon and partial (D+1) features are both treated as
# future covariates (Chronos-2 supports partial availability via masking);
# everything categorised as historical is past-only.
known_future_cols = full_horizon + partial_d1
past_only_cols = historical
# Step 2: Identify target borders.
# Every column named `target_border_<NAME>` is a forecast target; <NAME> is
# the border identifier. The first border found is used for the smoke test.
print("\n[2/6] Identifying target borders...")
TARGET_PREFIX = 'target_border_'
target_cols = [col for col in df.columns if col.startswith(TARGET_PREFIX)]
# removeprefix strips only the leading prefix; str.replace would also delete
# any later occurrence of the substring inside the border name.
borders = [col.removeprefix(TARGET_PREFIX) for col in target_cols]
print(f"[OK] Found {len(borders)} borders")

# Fail with an explicit message instead of a bare IndexError on borders[0].
if not borders:
    print(f"\n[ERROR] No '{TARGET_PREFIX}*' columns found in dataset; nothing to forecast.")
    sys.exit(1)

# Select first border for test
test_border = borders[0]
print(f"[*] Test border: {test_border}")
# Step 3: Prepare test data with DynamicForecast.
# The run date is derived from the data rather than hard-coded: the last
# forecast hour (run_date + prediction_hours) equals the dataset's final
# timestamp, so the whole horizon has future covariates available.
print("\n[3/6] Preparing test data...")
prediction_hours = 168  # 7 days
max_date = df['timestamp'].max()
run_date = max_date - timedelta(hours=prediction_hours)
context_hours = 512  # hours of history given to the model
print(f" Run date: {run_date}")
print(f" Context: {context_hours} hours (historical)")
print(f" Forecast: {prediction_hours} hours (7 days, D+1 to D+7)")
print(f" Forecast range: {run_date + timedelta(hours=1)} to {run_date + timedelta(hours=prediction_hours)}")

# Initialize DynamicForecast
forecaster = DynamicForecast(
    dataset=df,
    context_hours=context_hours,
    forecast_hours=prediction_hours,
)
# Split into historical context and future covariates as of run_date
# (time-aware extraction happens inside DynamicForecast).
context_data, future_data = forecaster.prepare_forecast_data(run_date, test_border)

# Abort before loading the model if the leakage check fails.
no_leakage, leakage_errors = forecaster.validate_no_leakage(context_data, future_data, run_date)
if not no_leakage:
    print("\n[ERROR] Data leakage detected:")
    for err in leakage_errors:
        print(f" - {err}")
    # sys.exit instead of the site-module `exit`, which is intended for the
    # interactive interpreter and is unavailable under `python -S`.
    sys.exit(1)
print("[OK] Data preparation complete (leakage validation passed)")
print(f" Context shape: {context_data.shape}")
print(f" Future shape: {future_data.shape}")
print(f" Context dates: {context_data['timestamp'].min()} to {context_data['timestamp'].max()}")
print(f" Future dates: {future_data['timestamp'].min()} to {future_data['timestamp'].max()}")
# Step 4: Load the pretrained Chronos-2 pipeline on the CUDA device in
# float32 and report how long loading took and where the weights landed.
print("\n[4/6] Loading Chronos 2 model on GPU...")
model_start = time.time()
pipeline = Chronos2Pipeline.from_pretrained(
    "amazon/chronos-2",
    dtype=torch.float32,
    device_map="cuda",
)
model_time = time.time() - model_start
print(f"[OK] Model loaded in {model_time:.1f}s")
# The device of the first model parameter is reported as the model's device.
first_param = next(iter(pipeline.model.parameters()))
print(f" Device: {first_param.device}")
# Step 5: Run zero-shot inference for the selected border, then validate the
# forecast, print a performance summary, and exit non-zero on any failure.
print("\n[5/6] Running zero-shot inference...")
print(f" Border: {test_border}")
print(f" Prediction: {prediction_hours} hours (7 days)")
# predict_df is called with its default output settings (no sample count is
# passed), so no sample count is reported here.
inference_start = time.time()
try:
    # Guard only the model call, so "Inference failed" is reported only for
    # failures that actually come from inference.
    forecasts = pipeline.predict_df(
        context_data,  # Historical data (positional parameter)
        future_df=future_data,  # Future covariates (named parameter)
        prediction_length=prediction_hours,
        id_column='border',
        timestamp_column='timestamp',
        target='target',
    )
except Exception as e:
    # Top-level boundary of the smoke test: report, dump the traceback, and
    # exit non-zero so callers see the failure.
    print(f"\n[ERROR] Inference failed: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
inference_time = time.time() - inference_start
print(f"[OK] Inference complete in {inference_time:.1f}s")
print(f" Forecast shape: {forecasts.shape}")

# Step 6: Validate results
print("\n[6/6] Validating results...")
# Upper sanity bound, in the same units as the "MW" labels printed below.
MAX_REASONABLE_FORECAST = 20000

# Count NaN values across every column of the forecast frame.
nan_count = forecasts.isna().sum().sum()
print(f" NaN values: {nan_count}")

# Point-forecast column to validate. 'mean' is what this script checked
# originally; 'predictions' / '0.5' are fallbacks in case predict_df names its
# point/median column differently.
# NOTE(review): confirm against the installed chronos version.
point_col = next(
    (col for col in ('mean', 'predictions', '0.5') if col in forecasts.columns),
    None,
)
validation_ok = False
if point_col is None:
    # Previously the statistics were silently skipped and the test still
    # passed; now the missing column is reported and counts as a failure.
    print(f" [!] WARNING: No point-forecast column found; columns: {list(forecasts.columns)}")
else:
    mean_forecast = forecasts[point_col]
    forecast_min = mean_forecast.min()
    forecast_max = mean_forecast.max()
    print(" Forecast statistics:")
    print(f" Mean: {mean_forecast.mean():.2f} MW")
    print(f" Min: {forecast_min:.2f} MW")
    print(f" Max: {forecast_max:.2f} MW")
    print(f" Std: {mean_forecast.std():.2f} MW")
    # Sanity checks. The warning and pass conditions share one bound, so a
    # value exactly at the bound is no longer neither warned about nor passed.
    if forecast_min < 0:
        print(" [!] WARNING: Negative forecasts detected")
    if forecast_max > MAX_REASONABLE_FORECAST:
        print(" [!] WARNING: Unreasonably high forecasts")
    validation_ok = bool(
        nan_count == 0
        and forecast_min >= 0
        and forecast_max <= MAX_REASONABLE_FORECAST
    )
    if validation_ok:
        print(" [OK] Validation passed!")

# Performance summary
print("\n" + "=" * 60)
print("SMOKE TEST SUMMARY")
print("=" * 60)
print(f"Border tested: {test_border}")
print(f"Forecast length: {prediction_hours} hours (7 days)")
print(f"Inference time: {inference_time:.1f}s")
print(f"Speed: {prediction_hours / inference_time:.1f} hours/second")

# Rough full-run estimate: scale this single-border timing linearly by
# forecast hours and by the number of borders.
total_borders = len(borders)
full_forecast_hours = 336  # 14 days
estimated_time = (inference_time / prediction_hours) * full_forecast_hours * total_borders
print("\nEstimated time for full run:")
print(f" {total_borders} borders × {full_forecast_hours} hours")
print(f" = {estimated_time / 60:.1f} minutes ({estimated_time / 3600:.1f} hours)")

# Target check (informational only; does not affect the exit status)
if inference_time < 300:  # 5 minutes
    print("\n[OK] Performance target met! (<5 min for 7-day forecast)")
else:
    print("\n[!] Performance slower than target (expected <5 min)")
print("=" * 60)
# Declare success only when forecast validation actually passed; otherwise
# exit non-zero so the smoke test cannot silently pass on bad output.
if not validation_ok:
    print("[ERROR] SMOKE TEST FAILED: forecast validation did not pass")
    print("=" * 60)
    sys.exit(1)
print("[OK] SMOKE TEST PASSED!")
print("=" * 60)
# Total time: wall-clock duration since start_time was taken at the start of
# the dataset-loading step, in seconds and in minutes.
elapsed_total = time.time() - start_time
print(f"\nTotal test time: {elapsed_total:.1f}s ({elapsed_total / 60:.1f} min)")