Spaces:
Running
Running
| """Shared research memory layer for all orchestration modes. | |
| Design Pattern: Dependency Injection | |
| - Receives embedding service via constructor | |
| - Uses service_loader.get_embedding_service() as default (Strategy Pattern) | |
| - Allows testing with mock services | |
| SOLID Principles: | |
| - Dependency Inversion: Depends on EmbeddingServiceProtocol, not concrete class | |
| - Open/Closed: Works with any service implementing the protocol | |
| """ | |
| from typing import TYPE_CHECKING, Any, get_args | |
| import structlog | |
| from src.agents.graph.state import Conflict, Hypothesis | |
| from src.utils.models import Citation, Evidence, SourceName | |
| if TYPE_CHECKING: | |
| from src.services.embedding_protocol import EmbeddingServiceProtocol | |
| logger = structlog.get_logger() | |
| class ResearchMemory: | |
| """Shared cognitive state for research workflows. | |
| This is the memory layer that ALL modes use. | |
| It mimics the LangGraph state management but for manual orchestration. | |
| The embedding service is selected via get_embedding_service(), which returns: | |
| - LlamaIndexRAGService (premium tier) if OPENAI_API_KEY is available | |
| - EmbeddingService (free tier) as fallback | |
| """ | |
| def __init__(self, query: str, embedding_service: "EmbeddingServiceProtocol | None" = None): | |
| """Initialize ResearchMemory with a query and optional embedding service. | |
| Args: | |
| query: The research query to track evidence for. | |
| embedding_service: Service for semantic search and deduplication. | |
| Uses get_embedding_service() if not provided, | |
| which selects the best available service. | |
| """ | |
| self.query = query | |
| self.hypotheses: list[Hypothesis] = [] | |
| self.conflicts: list[Conflict] = [] | |
| self.evidence_ids: list[str] = [] | |
| self._evidence_cache: dict[str, Evidence] = {} | |
| self.iteration_count: int = 0 | |
| # Use service loader for tiered service selection (Strategy Pattern) | |
| if embedding_service is None: | |
| from src.utils.service_loader import get_embedding_service | |
| self._embedding_service: EmbeddingServiceProtocol = get_embedding_service() | |
| else: | |
| self._embedding_service = embedding_service | |
| async def store_evidence(self, evidence: list[Evidence]) -> list[str]: | |
| """Store evidence and return new IDs (deduped).""" | |
| if not self._embedding_service: | |
| return [] | |
| # Deduplicate and store (deduplicate() already calls add_evidence() internally) | |
| unique = await self._embedding_service.deduplicate(evidence) | |
| # Track IDs and cache (evidence already stored by deduplicate()) | |
| new_ids = [] | |
| for ev in unique: | |
| ev_id = ev.citation.url | |
| new_ids.append(ev_id) | |
| self._evidence_cache[ev_id] = ev | |
| self.evidence_ids.extend(new_ids) | |
| if new_ids: | |
| logger.info("Stored new evidence", count=len(new_ids)) | |
| return new_ids | |
| def get_all_evidence(self) -> list[Evidence]: | |
| """Get all accumulated evidence objects.""" | |
| return list(self._evidence_cache.values()) | |
| async def get_relevant_evidence(self, n: int = 20) -> list[Evidence]: | |
| """Retrieve relevant evidence for current query.""" | |
| if not self._embedding_service: | |
| return [] | |
| results = await self._embedding_service.search_similar(self.query, n_results=n) | |
| evidence_list = [] | |
| for r in results: | |
| meta = r.get("metadata", {}) | |
| authors_str = meta.get("authors", "") | |
| authors = [a.strip() for a in authors_str.split(",")] if authors_str else [] | |
| # Reconstruct Evidence object | |
| source_raw = meta.get("source", "web") | |
| # Validate source against canonical SourceName type (avoids drift) | |
| valid_sources = get_args(SourceName) | |
| source_name: Any = source_raw if source_raw in valid_sources else "web" | |
| citation = Citation( | |
| source=source_name, | |
| title=meta.get("title", "Unknown"), | |
| url=meta.get("url", r.get("id", "")), | |
| date=meta.get("date", "Unknown"), | |
| authors=authors, | |
| ) | |
| evidence_list.append( | |
| Evidence( | |
| content=r.get("content", ""), | |
| citation=citation, | |
| relevance=1.0 - r.get("distance", 0.5), # Approx conversion | |
| ) | |
| ) | |
| return evidence_list | |
| def add_hypothesis(self, hypothesis: Hypothesis) -> None: | |
| """Add a hypothesis to tracking.""" | |
| self.hypotheses.append(hypothesis) | |
| logger.info("Added hypothesis", id=hypothesis.id, confidence=hypothesis.confidence) | |
| def add_conflict(self, conflict: Conflict) -> None: | |
| """Add a detected conflict.""" | |
| self.conflicts.append(conflict) | |
| logger.info("Added conflict", id=conflict.id) | |
| def get_open_conflicts(self) -> list[Conflict]: | |
| """Get unresolved conflicts.""" | |
| return [c for c in self.conflicts if c.status == "open"] | |
| def get_confirmed_hypotheses(self) -> list[Hypothesis]: | |
| """Get high-confidence hypotheses.""" | |
| return [h for h in self.hypotheses if h.confidence > 0.8] | |