Alon Albalak
major update: all data saved on HF (prompts, results), unified utilities
57be184
"""Data loading and saving functionality"""
import json
import random
import datetime
import uuid
from pathlib import Path
from huggingface_hub import CommitScheduler
from src.config.settings import HF_RESULTS_REPO, HF_PROMPTS_REPO
from src.utils.hf_data_manager import HFDataManager
# Local staging directory for result files. It is created at import time so
# that both the scheduler below and the append in save_interaction_to_hf can
# rely on it existing.
JSON_DATASET_DIR = Path("testing/data/results")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
# The file name embeds a fresh UUID generated once per import of this module.
# NOTE(review): presumably this keeps concurrent app instances from writing to
# the same file in the shared dataset repo -- confirm.
JSON_DATASET_PATH = JSON_DATASET_DIR / f"results_{uuid.uuid4()}.json"
# Scheduler that pushes the contents of JSON_DATASET_DIR to the "data" folder
# of the HF_RESULTS_REPO dataset repository. Writers must hold `scheduler.lock`
# while modifying files in the folder (see save_interaction_to_hf).
# NOTE(review): `every` is expressed in minutes according to the
# huggingface_hub CommitScheduler documentation -- confirm against the
# installed version.
scheduler = CommitScheduler(
    repo_id=HF_RESULTS_REPO,
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR.as_posix(),
    path_in_repo="data",
    every=10
)
class DataManager:
    """Manages loading and saving of prompts and results data.

    Prompts and previously collected results are fetched through
    ``load_from_hf``; new interactions are appended both to the in-memory
    results cache and to the local JSON-lines file at ``JSON_DATASET_PATH``.
    """

    def __init__(self):
        # Prompt records; stays empty until load_prompts_data() is called.
        self.prompts_data = []
        # In-memory cache of result records. None means "not loaded yet";
        # get_results() populates it on first use.
        self.results = None

    def load_prompts_data(self):
        """Load prompts data from the ``HF_PROMPTS_REPO`` repository.

        Raises:
            RuntimeError: If the loader returns nothing (or an empty value).
        """
        self.prompts_data = self.load_from_hf(HF_PROMPTS_REPO)
        if not self.prompts_data:
            raise RuntimeError("No prompts data loaded from Hugging Face.")

    def get_random_prompt(self):
        """Return one randomly chosen entry from the loaded prompts.

        Raises:
            RuntimeError: If no prompts have been loaded yet.
        """
        if not self.prompts_data:
            raise RuntimeError("No prompts data loaded. Call load_prompts_data() first.")
        return random.choice(self.prompts_data)

    def get_results(self):
        """Return all results data, loading it from ``HF_RESULTS_REPO`` on first use.

        The loaded value is cached on ``self.results`` and the same object is
        returned on later calls, so callers see results added afterwards via
        ``add_results``.
        """
        if self.results is None:
            self.results = self.load_from_hf(HF_RESULTS_REPO)
        return self.results

    def add_results(self, new_results):
        """Append an iterable of result records to the cached results list.

        Raises:
            RuntimeError: If the results cache has not been loaded yet.
        """
        if self.results is None:
            raise RuntimeError("Results not loaded. Call get_results() first.")
        self.results.extend(new_results)

    def load_from_hf(self, hf_repo):
        """Load data from the given Hugging Face dataset repository.

        Thin wrapper around ``HFDataManager.load_from_hf``; the return value
        is passed through unchanged.
        """
        return HFDataManager.load_from_hf(hf_repo)

    def save_interaction_to_hf(self, prompt_data, user_continuation, generated_response,
                               cosine_distance, session_id, num_user_tokens):
        """Record one user interaction in memory and in the local results file.

        Args:
            prompt_data: Mapping providing the "prompt", "model",
                "llm_partial_response" and "llm_full_response_original" keys.
            user_continuation: Stored under "user_continuation".
            generated_response: Stored under "full_response_from_user".
            cosine_distance: Stored under "cosine_distance"; the gallery and
                inspire-me queries filter and sort on this value.
            session_id: Stored under "continuation_source".
            num_user_tokens: Stored under "num_user_tokens".

        Raises:
            KeyError: If ``prompt_data`` lacks one of the required keys.
            RuntimeError: If the results cache has not been loaded yet
                (raised by ``add_results`` before anything is written).
        """
        interaction = {
            "prompt": prompt_data["prompt"],
            "model": prompt_data["model"],
            "llm_partial_response": prompt_data["llm_partial_response"],
            "llm_full_response_original": prompt_data["llm_full_response_original"],
            "user_continuation": user_continuation,
            "full_response_from_user": generated_response,
            "cosine_distance": cosine_distance,
            "timestamp": datetime.datetime.now().isoformat(),
            "continuation_source": session_id,
            "num_user_tokens": num_user_tokens,
            # These two fields are always written as empty strings here.
            "continuation_prompt": "",
            "full_continuation_prompt": "",
        }
        self.add_results([interaction])

        # Hold the scheduler's lock while appending so the write is serialized
        # with whatever else acquires that lock. The file is JSON lines: one
        # record per line. json.dumps emits ASCII by default, so the explicit
        # encoding only pins what the platform default would otherwise decide.
        with scheduler.lock, open(JSON_DATASET_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(interaction) + "\n")

    def filter_results_by_partial_response(self, results, prompt, partial_response):
        """Return the entries of ``results`` matching both prompt and partial response."""
        return [
            r for r in results
            if r["prompt"] == prompt and r["llm_partial_response"] == partial_response
        ]

    def filter_results_by_session(self, results, session_id):
        """Return the entries of ``results`` whose "continuation_source" is ``session_id``.

        Entries without a "continuation_source" key are skipped rather than
        raising.
        """
        return [r for r in results if r.get("continuation_source") == session_id]

    def get_gallery_responses(self, min_score=0.3, limit=20):
        """Return up to ``limit`` results with "cosine_distance" >= ``min_score``.

        The returned list is ordered by "cosine_distance", highest first.
        """
        all_results = self.get_results()
        qualifying = [r for r in all_results if r["cosine_distance"] >= min_score]
        qualifying.sort(key=lambda x: x["cosine_distance"], reverse=True)
        return qualifying[:limit]

    def get_inspire_me_examples(self, prompt, partial_response, limit=5):
        """Return up to ``limit`` example results for the current prompt.

        Only results matching both ``prompt`` and ``partial_response`` with a
        "cosine_distance" of at least 0.2 are considered. When more than
        ``limit`` candidates exist, the pool is the three highest-scoring ones
        plus a random draw from the rest; the final selection is returned in
        random order.

        Returns:
            A list of at most ``limit`` result records; empty when ``limit``
            is not positive or nothing matches.
        """
        # A non-positive limit can never yield examples; returning early also
        # avoids handing random.sample a negative sample size.
        if limit <= 0:
            return []

        all_results = self.get_results()
        examples = [
            r for r in all_results
            if r["prompt"] == prompt
            and r["llm_partial_response"] == partial_response
            and r["cosine_distance"] >= 0.2
        ]
        examples.sort(key=lambda x: x["cosine_distance"], reverse=True)

        if len(examples) > limit:
            # Keep the best three, then mix in random lower-ranked entries so
            # the suggestions are not always the same top scorers.
            top_examples = examples[:3]
            remaining = examples[3:]
            if remaining:
                # Clamp at zero: for limit < 3 the top slice alone already
                # exceeds the limit, and a negative count would make
                # random.sample raise ValueError.
                extra_count = max(0, min(limit - len(top_examples), len(remaining)))
                additional = random.sample(remaining, extra_count)
                examples = top_examples + additional
            else:
                examples = top_examples

        # Final draw trims the pool to the limit and randomizes the order.
        return random.sample(examples, min(limit, len(examples)))