# (Hugging Face Spaces page banner captured during extraction; not part of the module.)
| """Data loading and saving functionality""" | |
| import json | |
| import random | |
| import datetime | |
| import uuid | |
| from pathlib import Path | |
| from huggingface_hub import CommitScheduler | |
| from src.config.settings import HF_RESULTS_REPO, HF_PROMPTS_REPO | |
| from src.utils.hf_data_manager import HFDataManager | |
# Local staging directory for result files; created eagerly so the first
# write never fails on a missing path.
JSON_DATASET_DIR = Path("testing/data/results")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# One unique file per process (uuid4) so concurrent app instances never
# append to the same file and collide on Hub commits.
JSON_DATASET_PATH = JSON_DATASET_DIR / f"results_{uuid.uuid4()}.json"

# Background scheduler that periodically commits everything under
# JSON_DATASET_DIR into the "data/" folder of the results dataset repo.
# Per huggingface_hub's CommitScheduler API, `every=10` is a 10-minute
# commit interval.
scheduler = CommitScheduler(
    repo_id=HF_RESULTS_REPO,
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR.as_posix(),
    path_in_repo="data",
    every=10
)
class DataManager:
    """Manages loading and saving of prompts and results data.

    Prompts and results are loaded from Hugging Face dataset repos via
    ``HFDataManager``. Results written through this class are kept in an
    in-memory list and mirrored to a process-local JSONL file that the
    module-level ``CommitScheduler`` pushes back to the Hub.
    """

    def __init__(self):
        # Prompt dicts from HF_PROMPTS_REPO; empty until load_prompts_data().
        self.prompts_data = []
        # Lazily-loaded results cache; None means "not loaded yet".
        self.results = None

    def load_prompts_data(self):
        """Load prompts data from the prompts repo into memory.

        Raises:
            RuntimeError: If nothing could be loaded from the repo.
        """
        self.prompts_data = self.load_from_hf(HF_PROMPTS_REPO)
        if not self.prompts_data:
            raise RuntimeError("No prompts data loaded from Hugging Face.")

    def get_random_prompt(self):
        """Return one randomly chosen prompt dict from the loaded data.

        Raises:
            RuntimeError: If load_prompts_data() has not been called yet.
        """
        if not self.prompts_data:
            raise RuntimeError("No prompts data loaded. Call load_prompts_data() first.")
        return random.choice(self.prompts_data)

    def get_results(self):
        """Get all results data, loading and caching it on first access."""
        if self.results is None:
            self.results = self.load_from_hf(HF_RESULTS_REPO)
        return self.results

    def add_results(self, new_results):
        """Append new result dicts to the in-memory results list.

        Args:
            new_results: Iterable of result dicts to append.

        Raises:
            RuntimeError: If results were never loaded via get_results().
        """
        if self.results is None:
            raise RuntimeError("Results not loaded. Call get_results() first.")
        self.results.extend(new_results)

    def load_from_hf(self, hf_repo):
        """Load data from a Hugging Face dataset repository."""
        return HFDataManager.load_from_hf(hf_repo)

    def save_interaction_to_hf(self, prompt_data, user_continuation, generated_response,
                               cosine_distance, session_id, num_user_tokens):
        """Record one user interaction in memory and in the local JSONL file.

        Args:
            prompt_data: Prompt dict carrying at least "prompt", "model",
                "llm_partial_response" and "llm_full_response_original".
            user_continuation: Text the user typed.
            generated_response: Full response assembled from the user input.
            cosine_distance: Creativity score for this interaction.
            session_id: Identifier stored under "continuation_source".
            num_user_tokens: Token count of the user's continuation.
        """
        interaction = {
            "prompt": prompt_data["prompt"],
            "model": prompt_data["model"],
            "llm_partial_response": prompt_data["llm_partial_response"],
            "llm_full_response_original": prompt_data["llm_full_response_original"],
            "user_continuation": user_continuation,
            "full_response_from_user": generated_response,
            "cosine_distance": cosine_distance,
            # NOTE(review): naive local time; consider datetime.now(timezone.utc)
            # if consumers compare timestamps across deployments — confirm first,
            # since the format is persisted data.
            "timestamp": datetime.datetime.now().isoformat(),
            "continuation_source": session_id,
            "num_user_tokens": num_user_tokens,
            "continuation_prompt": "",
            "full_continuation_prompt": ""
        }
        self.add_results([interaction])
        # Hold the scheduler lock so a background commit never snapshots a
        # half-written line.
        with scheduler.lock:
            # Encoding pinned so file content is platform-independent.
            with open(JSON_DATASET_PATH, "a", encoding="utf-8") as f:
                f.write(json.dumps(interaction) + "\n")

    def filter_results_by_partial_response(self, results, prompt, partial_response):
        """Filter results to only include entries for the current prompt."""
        return [r for r in results
                if r["prompt"] == prompt and r["llm_partial_response"] == partial_response]

    def filter_results_by_session(self, results, session_id):
        """Filter results to only include entries from the specified session."""
        return [r for r in results if r.get("continuation_source") == session_id]

    def get_gallery_responses(self, min_score=0.3, limit=20):
        """Return up to `limit` results with score >= min_score, best first."""
        all_results = self.get_results()
        qualifying = [r for r in all_results if r["cosine_distance"] >= min_score]
        qualifying.sort(key=lambda x: x["cosine_distance"], reverse=True)
        return qualifying[:limit]

    def get_inspire_me_examples(self, prompt, partial_response, limit=5):
        """Return up to `limit` inspiring examples for the current prompt.

        Examples with score >= 0.2 are ranked best-first; when more than
        `limit` qualify, a few top entries are kept and the rest of the slots
        are filled from the remainder so the selection stays diverse.
        """
        all_results = self.get_results()
        # Reuse the shared prompt filter, then keep only "good" examples.
        matching = self.filter_results_by_partial_response(
            all_results, prompt, partial_response)
        examples = [r for r in matching if r["cosine_distance"] >= 0.2]
        examples.sort(key=lambda x: x["cosine_distance"], reverse=True)
        if len(examples) > limit:
            # Clamp the head slice to `limit` too: with the old `examples[:3]`
            # a limit < 3 made `limit - len(head)` negative and random.sample
            # raised ValueError.
            head = examples[:min(3, limit)]
            remaining = examples[len(head):]
            if remaining:
                extra = random.sample(remaining, min(limit - len(head), len(remaining)))
                examples = head + extra
            else:
                examples = head
        return random.sample(examples, min(limit, len(examples)))