Commit 01c71d2
Parent(s): 628b5b4

feat(thunderbird): Add market intelligence module with new model and APIs

Files changed:
- api/main.py +30 -1
- api/thunderbird_routes.py +44 -0
- core/thunderbird_engine.py +66 -0
- models/thunderbird_market_predictor_v1.joblib +3 -0
- requirements.txt +0 -0 (binary)
- scripts/export_thunderbird_training_data.py +141 -0
- training/train_thunderbird_market_predictor.py +90 -0
api/main.py CHANGED

@@ -31,6 +31,7 @@ from core.document_parser import parse_pdf_from_url
 from core.creative_chat import CreativeDirector
 from core.matcher import load_embedding_model
 from core.community_brain import CommunityBrain
+from core.thunderbird_engine import get_external_trends, predict_niche_trends
 
 try:
     from core.rag.store import VectorStore
@@ -1713,4 +1714,32 @@ def summarize_community_thread(request: ThreadSummaryRequest):
         return ThreadSummaryResponse(summary="Summary unavailable.")
 
     summary = _community_brain.summarize_thread(request.comments)
-    return ThreadSummaryResponse(summary=summary)
+    return ThreadSummaryResponse(summary=summary)
+
+
+# =============================================================
+# === ⚡️ PROJECT THUNDERBIRD - MARKET INTELLIGENCE HUB ===
+# =============================================================
+
+@app.post("/thunderbird/get_pulse_data", summary="Get All Data for Market Intelligence 'Pulse' Page")
+def get_pulse_data_endpoint():
+    """
+    This is the main orchestrator endpoint for the /pulse page.
+    It calls all necessary Thunderbird engine functions and combines their data.
+    """
+    print("🚀 API HIT: /thunderbird/get_pulse_data")
+    try:
+        # Call core logic functions in sequence
+        live_trends = get_external_trends()
+        niche_predictions = predict_niche_trends()
+        # Add future AI briefing calls here
+
+        # Combine results into one object for the frontend
+        return {
+            **live_trends,
+            **niche_predictions,
+        }
+    except Exception as e:
+        print(f"❌ API ERROR in /thunderbird/get_pulse_data: {e}")
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=str(e))
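Because the endpoint spreads both engine dicts with **, the keys from get_external_trends() and predict_niche_trends() land flat at the top level of a single JSON object. Below is a minimal smoke-test sketch using FastAPI's TestClient; the api.main.app import path is assumed from the file layout, and "trend_predictions" only appears when the joblib model is actually present:

# Hypothetical smoke test; assumes the FastAPI app object is importable
# as api.main.app and that the joblib model file exists (without it,
# predict_niche_trends() contributes an "error" key instead).
from fastapi.testclient import TestClient
from api.main import app  # assumed import path

client = TestClient(app)

def test_pulse_endpoint_shape():
    resp = client.post("/thunderbird/get_pulse_data")
    assert resp.status_code == 200
    data = resp.json()
    # Keys merged flat by the ** spreads in the endpoint body.
    for key in ("news_headlines", "breakout_keyword", "trending_audio", "trend_predictions"):
        assert key in data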
api/thunderbird_routes.py ADDED

@@ -0,0 +1,44 @@
+from fastapi import APIRouter, HTTPException, Depends
+from typing import Dict, Any
+
+# Import the brain functions we just created
+from core.thunderbird_engine import get_external_trends, predict_niche_trends
+
+# FastAPI router for all Thunderbird-related endpoints
+router = APIRouter(
+    prefix="/thunderbird",  # All routes in this file will start with /thunderbird
+    tags=["Thunderbird - Market Intelligence"],  # For Swagger UI documentation
+)
+
+# --- ENDPOINTS ---
+
+@router.post("/get_pulse_data")
+async def get_pulse_data() -> Dict[str, Any]:
+    """
+    This is the main endpoint for the /pulse page.
+    It calls all necessary engine functions and combines their data into a single response.
+    """
+    print("🚀 API HIT: /thunderbird/get_pulse_data")
+    try:
+        # Call our core logic functions
+        live_trends = get_external_trends()
+        niche_predictions = predict_niche_trends()
+        # In the future, we'll add the AI briefing call here as well
+
+        # Combine all results into a single, clean JSON object for the frontend
+        combined_data = {
+            **live_trends,
+            **niche_predictions,
+            # "ai_briefing": ai_briefing_result (for later)
+        }
+
+        print("✅ API SUCCESS: /thunderbird/get_pulse_data")
+        return combined_data
+
+    except Exception as e:
+        print(f"❌ API ERROR in /get_pulse_data: {e}")
+        # In case of an error, send a structured error message to the frontend
+        raise HTTPException(
+            status_code=500,
+            detail=f"An internal error occurred in the Thunderbird engine: {e}"
+        )
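Note that this router declares the same POST /thunderbird/get_pulse_data route that main.py now defines inline, and a router has no effect until it is registered on the app. The diff does not show main.py including it, so the wiring below is an assumption; if both the router and the inline handler were registered, requests would go to whichever route was added first:

# Hypothetical wiring; this commit's diff does not show main.py doing this.
from fastapi import FastAPI
from api.thunderbird_routes import router as thunderbird_router  # assumed import path

app = FastAPI()
app.include_router(thunderbird_router)  # exposes POST /thunderbird/get_pulse_data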
core/thunderbird_engine.py ADDED

@@ -0,0 +1,66 @@
+import os
+import pandas as pd
+import joblib
+import random
+from datetime import datetime
+from newsapi import NewsApiClient
+import feedparser
+
+# --- CONFIGURATION ---
+MODEL_PATH = os.path.join(os.path.dirname(__file__), '..', 'models', 'thunderbird_market_predictor_v1.joblib')
+NEWS_API_KEY = os.getenv("NEWS_API_KEY")
+
+# --- CORE FUNCTIONS ---
+def get_external_trends() -> dict:
+    """Fetches real-time 'live' data from external news APIs and RSS feeds."""
+    print("📡 [Thunderbird Engine] Fetching live external trends...")
+    results = {
+        "news_headlines": [],
+        "breakout_keyword": None,
+        "trending_audio": None
+    }
+
+    if NEWS_API_KEY:
+        try:
+            newsapi = NewsApiClient(api_key=NEWS_API_KEY)
+            top_headlines = newsapi.get_everything(
+                q='("influencer marketing" OR "social media marketing" OR "creator economy")',
+                language='en', sort_by='relevancy', page_size=5
+            )
+            results["news_headlines"] = [{"title": article['title'], "url": article['url']} for article in top_headlines.get('articles', [])]
+            print(f"  - ✅ Found {len(results['news_headlines'])} news articles.")
+        except Exception as e:
+            print(f"  - ⚠️ NewsAPI Error: {e}")
+            results["news_headlines"] = [{"title": "News service currently unavailable.", "url": "#"}]
+
+    # Simulate other trends for now to allow frontend development
+    results["breakout_keyword"] = "AI in Marketing"
+    trending_audios = [{"name": "Espresso - Sabrina Carpenter", "cover_art_url": "https://via.placeholder.com/150"}]
+    results["trending_audio"] = random.choice(trending_audios)
+    print("  - ✅ (Simulated) Found trending keyword and audio.")
+    return results
+
+def predict_niche_trends() -> dict:
+    """Loads our trained ML model to predict future interest in market niches."""
+    print("\n🔮 [Thunderbird Engine] Loading model to predict niche trends...")
+    try:
+        model_pack = joblib.load(MODEL_PATH)
+        model = model_pack['model']
+        encoder = model_pack['encoder']
+        print(f"  - ✅ Model '{os.path.basename(MODEL_PATH)}' loaded successfully.")
+    except FileNotFoundError:
+        print(f"  - ❌ CRITICAL: Model file not found at '{MODEL_PATH}'.")
+        return {"error": "Prediction model not found."}
+
+    print("  - ⚠️ NOTE: Generating SIMULATED trend data as training set is small.")
+    niches = encoder.get_feature_names_out(['niche'])
+    dates = pd.date_range(end=datetime.now(), periods=12, freq='M').strftime('%Y-%m').tolist()
+    predictions = {}
+    for niche_col_name in niches:
+        niche_name = niche_col_name.split('_')[-1]
+        points = [random.randint(40, 60)]
+        for _ in range(11):
+            points.append(max(20, min(100, points[-1] + random.randint(-10, 10))))
+        predictions[niche_name] = [{"date": date, "value": value} for date, value in zip(dates, points)]
+    print(f"  - ✅ (Simulated) Generated trend predictions for niches: {list(predictions.keys())}")
+    return {"trend_predictions": predictions}
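The engine is plain Python, so it can be exercised without the web layer. A sketch of a standalone check, run from the repo root; the key list in the comments is read off the function bodies above, not from a captured run:

# Manual check of the engine outside FastAPI. If the model file is
# missing, predict_niche_trends() contributes {"error": ...} instead
# of "trend_predictions".
from core.thunderbird_engine import get_external_trends, predict_niche_trends

pulse = {**get_external_trends(), **predict_niche_trends()}

# Expected top-level keys per the code above:
#   news_headlines    -> list of {"title", "url"} dicts
#   breakout_keyword  -> str (currently hard-coded)
#   trending_audio    -> {"name", "cover_art_url"}
#   trend_predictions -> {niche: [{"date": "YYYY-MM", "value": int}, ...]}
print(sorted(pulse.keys()))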
models/thunderbird_market_predictor_v1.joblib ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed3d6a91acfe6d33d16ebbe8ef80c77b8df399af3205594ffba29391e9037dac
+size 64706
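This is a Git LFS pointer, not the model itself: the repository tracks only this three-line stub (spec version, SHA-256 oid, byte size), while the roughly 63 KB binary lives in LFS storage and is materialized by "git lfs pull".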
requirements.txt CHANGED

Binary files a/requirements.txt and b/requirements.txt differ
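Git reports a file as binary when it cannot decode its contents as text, which is also why the file stat above shows +0 -0. For a requirements.txt this is usually an encoding quirk (for example, a "pip freeze" redirect on Windows PowerShell writing UTF-16) rather than a genuinely binary file, though the diff itself gives no detail to confirm that here.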
scripts/export_thunderbird_training_data.py ADDED

@@ -0,0 +1,141 @@
+import os
+import pandas as pd
+from datetime import datetime, timedelta
+from dotenv import load_dotenv
+from supabase import create_client, Client
+from pytrends.request import TrendReq
+import time
+import random
+
+# --- CONFIGURATION (No changes) ---
+load_dotenv()
+SUPABASE_URL = os.getenv("SUPABASE_URL")
+SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
+if not SUPABASE_URL or not SUPABASE_KEY:
+    raise ValueError("Supabase URL and Service Key must be set.")
+supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
+NICHES_TO_TRACK = ["fashion", "gaming", "fitness", "skincare", "finance"]
+MONTHS_TO_FETCH = 12
+OUTPUT_FILE = os.path.join(os.path.dirname(__file__), '..', 'data', 'thunderbird_market_trends.csv')
+
+# --- get_successful_campaign_counts() --- (No changes needed, this function is correct)
+def get_successful_campaign_counts() -> pd.DataFrame:
+    print("📊 Fetching successful campaign data from Supabase...")
+    end_date = datetime.now()
+    start_date = end_date - timedelta(days=MONTHS_TO_FETCH * 30)
+    try:
+        response = supabase.table('campaigns').select('id, title, description, created_at') \
+            .eq('status', 'completed') \
+            .gte('created_at', start_date.isoformat()) \
+            .lte('created_at', end_date.isoformat()) \
+            .execute()
+        if not response.data:
+            print("⚠️ No campaign data found in the specified date range.")
+            return pd.DataFrame()
+        df = pd.DataFrame(response.data)
+        df['created_at'] = pd.to_datetime(df['created_at'])
+        df['month'] = df['created_at'].dt.to_period('M')
+        def assign_niche(row):
+            text_to_search = f"{row.get('title', '')} {row.get('description', '')}".lower()
+            for niche in NICHES_TO_TRACK:
+                if niche in text_to_search:
+                    return niche
+            return "general"
+        df['niche'] = df.apply(assign_niche, axis=1)
+        monthly_counts = df.groupby(['month', 'niche']).size().reset_index(name='successful_campaigns')
+        print(f"✅ Found and processed {len(df)} successful campaigns.")
+        return monthly_counts
+    except Exception as e:
+        print(f"❌ Error fetching data from Supabase: {e}")
+        return pd.DataFrame()
+
+# --- get_google_trends_data() --- (UPDATED)
+def get_google_trends_data() -> pd.DataFrame:
+    print("\n📈 Fetching historical market interest from Google Trends (Robust Mode)...")
+
+    # Increase retries and backoff for more resilience
+    pytrends = TrendReq(hl='en-US', tz=360, retries=5, backoff_factor=1)
+
+    end_date = datetime.now()
+    start_date = end_date - timedelta(days=MONTHS_TO_FETCH * 30)
+    timeframe = f"{start_date.strftime('%Y-%m-%d')} {end_date.strftime('%Y-%m-%d')}"
+
+    all_trends_df = pd.DataFrame()
+
+    for niche in NICHES_TO_TRACK:
+        print(f"  - Fetching trend data for '{niche}'...")
+        try:
+            pytrends.build_payload([niche], cat=0, timeframe=timeframe, geo='', gprop='')
+            interest_over_time_df = pytrends.interest_over_time()
+
+            if not interest_over_time_df.empty and niche in interest_over_time_df:
+                interest_over_time_df = interest_over_time_df.rename(columns={niche: 'trend_score'})
+                interest_over_time_df['niche'] = niche
+                all_trends_df = pd.concat([all_trends_df, interest_over_time_df[['trend_score', 'niche']]])
+            else:
+                print(f"  - ℹ️ No trend data returned for '{niche}'.")
+
+            # --- THE FIX: LONGER, MORE RANDOM DELAY ---
+            sleep_time = random.uniform(5, 12)  # Wait for 5 to 12 seconds
+            print(f"  - 😴 Sleeping for {sleep_time:.2f} seconds...")
+            time.sleep(sleep_time)
+            # ----------------------------------------
+
+        except Exception as e:
+            # We specifically catch the 429 error text
+            if "response with code 429" in str(e) or "too many 429 error responses" in str(e):
+                print(f"  - 🛑 Hit rate limit hard for '{niche}'. Taking a long 2-minute break...")
+                time.sleep(120)  # Take a long break if we still get blocked
+            else:
+                print(f"  - ⚠️ A non-rate-limit error occurred for '{niche}'. Error: {e}")
+            continue
+
+    if all_trends_df.empty:
+        print("⚠️ Warning: Could not fetch any data from Google Trends. Proceeding without trend scores.")
+        return pd.DataFrame()
+
+    all_trends_df['month'] = all_trends_df.index.to_period('M')
+    monthly_trends = all_trends_df.groupby(['month', 'niche'])['trend_score'].mean().reset_index()
+
+    print(f"✅ Successfully fetched and processed Google Trends data.")
+    return monthly_trends
+
+
+# --- main() function --- (UPDATED)
+def main():
+    """Main function to run the script."""
+    print("--- Starting Project Thunderbird Data Export ---")
+
+    campaign_df = get_successful_campaign_counts()
+
+    if campaign_df.empty:
+        print("\n❌ No campaign data found. Aborting training file creation.")
+        return
+
+    trends_df = get_google_trends_data()
+
+    # --- THE FIX: USE A 'LEFT' MERGE ---
+    if not trends_df.empty:
+        print("\n🔗 Merging campaign success data with market trend data...")
+        training_df = pd.merge(campaign_df, trends_df, on=['month', 'niche'], how='left')
+        # Fill any missing trend scores with 0
+        training_df['trend_score'].fillna(0, inplace=True)
+    else:
+        print("\n⚠️ No trends data was fetched. Creating training file with only campaign data.")
+        training_df = campaign_df
+        training_df['trend_score'] = 0  # Add the column so our model doesn't break
+    # ----------------------------------
+
+    # Convert Period to string for CSV
+    training_df['month'] = training_df['month'].astype(str)
+
+    # Save the final dataframe to a CSV file
+    try:
+        training_df.to_csv(OUTPUT_FILE, index=False)
+        print(f"\n✅ Success! Training data has been saved to:")
+        print(f"  {OUTPUT_FILE}")
+    except Exception as e:
+        print(f"\n❌ Error saving training data to CSV: {e}")
+
+if __name__ == "__main__":
+    main()
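The export produces one row per (month, niche) pair, and the training script below consumes exactly the niche, trend_score, and successful_campaigns columns. A small sketch for sanity-checking the file before training; the example row in the comment is illustrative, not real data:

# Validate the exported CSV against what the trainer expects. The
# required column set follows from the groupby/merge logic in main().
import pandas as pd

df = pd.read_csv("data/thunderbird_market_trends.csv")

required = {"month", "niche", "successful_campaigns", "trend_score"}
missing = required - set(df.columns)
assert not missing, f"CSV is missing columns: {missing}"

# e.g. month=2024-05, niche=fitness, successful_campaigns=3, trend_score=71.25
print(df.head())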
training/train_thunderbird_market_predictor.py ADDED

@@ -0,0 +1,90 @@
+import os
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from sklearn.ensemble import GradientBoostingRegressor
+from sklearn.metrics import mean_squared_error
+from sklearn.preprocessing import OneHotEncoder
+import joblib
+
+# --- CONFIGURATION ---
+
+# Path to the training data we created in the previous step
+INPUT_FILE = os.path.join(os.path.dirname(__file__), '..', 'data', 'thunderbird_market_trends.csv')
+
+# Path to save the trained model
+MODEL_OUTPUT_FILE = os.path.join(os.path.dirname(__file__), '..', 'models', 'thunderbird_market_predictor_v1.joblib')
+
+# --- MAIN SCRIPT ---
+
+def train_model():
+    """
+    Loads the training data, prepares it for the model, trains the model,
+    and saves the final version to a .joblib file.
+    """
+    print("--- Starting Project Thunderbird Model Training ---")
+
+    # 1. Load Data
+    try:
+        df = pd.read_csv(INPUT_FILE)
+        print(f"✅ Successfully loaded training data from '{INPUT_FILE}'")
+    except FileNotFoundError:
+        print(f"❌ Error: Training data file not found at '{INPUT_FILE}'.")
+        print("  Please run `scripts/export_thunderbird_training_data.py` first.")
+        return
+
+    # 2. Prepare Data (Feature Engineering)
+    print("\n🔧 Preparing data for the model...")
+    # 'month' and 'niche' are categorical. The model needs numbers.
+    # We will use one-hot encoding for the 'niche'.
+    encoder = OneHotEncoder(handle_unknown='ignore')
+    niche_encoded = encoder.fit_transform(df[['niche']]).toarray()
+
+    # Create a new DataFrame with the encoded columns
+    niche_df = pd.DataFrame(niche_encoded, columns=encoder.get_feature_names_out(['niche']))
+
+    # We won't use 'month' directly, as the trend score already has the time component.
+    # Our features are the market trend and the niche type.
+    X = pd.concat([df[['trend_score']], niche_df], axis=1)
+
+    # Our target is to predict how many successful campaigns there will be.
+    y = df['successful_campaigns']
+
+    print(f"✅ Data prepared. Features: {X.columns.tolist()}")
+
+    # 3. Split data for training and testing
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+    print(f"\n📊 Splitting data: {len(X_train)} rows for training, {len(X_test)} rows for testing.")
+
+    # 4. Train the Model
+    print("\n🧠 Training the Gradient Boosting Regressor model...")
+
+    # Gradient Boosting is a good choice for this kind of tabular data.
+    model = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
+    model.fit(X_train, y_train)
+
+    print("✅ Model training complete.")
+
+    # 5. Evaluate the Model (optional, but good practice)
+    predictions = model.predict(X_test)
+    mse = mean_squared_error(y_test, predictions)
+    print(f"\n📊 Model evaluation (Mean Squared Error): {mse:.2f}")
+    print("  (Lower is better. A small number means the model's predictions are close to the real values).")
+
+
+    # 6. Save the Trained Model and the Encoder
+    print(f"\n💾 Saving the trained model and encoder...")
+    try:
+        # We need to save BOTH the model AND the encoder, so we can use it for predictions later.
+        model_and_encoder = {
+            'model': model,
+            'encoder': encoder
+        }
+        joblib.dump(model_and_encoder, MODEL_OUTPUT_FILE)
+        print(f"✅ Success! Model has been saved to:")
+        print(f"  {MODEL_OUTPUT_FILE}")
+    except Exception as e:
+        print(f"\n❌ Error saving model file: {e}")
+
+
+if __name__ == "__main__":
+    train_model()
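Bundling the encoder with the model is what makes inference reproducible later: a new niche must be one-hot encoded into exactly the columns the model was trained on. A minimal sketch of consuming the saved bundle; the niche and trend_score values are arbitrary examples:

# Load the bundle written by train_model() and score one (niche, trend) pair.
# The feature layout must match training: trend_score first, then the
# one-hot niche columns produced by the saved encoder.
import joblib
import pandas as pd

pack = joblib.load("models/thunderbird_market_predictor_v1.joblib")
model, encoder = pack["model"], pack["encoder"]

niche_cols = encoder.get_feature_names_out(["niche"])
one_hot = encoder.transform(pd.DataFrame({"niche": ["fitness"]})).toarray()

X = pd.concat(
    [pd.DataFrame({"trend_score": [72.0]}),  # illustrative value
     pd.DataFrame(one_hot, columns=niche_cols)],
    axis=1,
)
print(f"Predicted successful campaigns: {model.predict(X)[0]:.1f}")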