File size: 3,223 Bytes
1922dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15459e9
 
1922dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Judge agent wrapper for Magentic integration."""

from collections.abc import AsyncIterable
from typing import Any

from agent_framework import (
    AgentRunResponse,
    AgentRunResponseUpdate,
    AgentThread,
    BaseAgent,
    ChatMessage,
    Role,
)

from src.orchestrator import JudgeHandlerProtocol
from src.utils.models import Evidence, JudgeAssessment


class JudgeAgent(BaseAgent):  # type: ignore[misc]
    """Expose a ``JudgeHandlerProtocol`` to Magentic as an agent.

    Each run pulls the question out of the incoming messages, reads the
    evidence list stored under the ``"current"`` key of a shared mapping,
    asks the wrapped handler for an assessment, and replies with a single
    assistant message summarising that assessment.
    """

    def __init__(
        self,
        judge_handler: JudgeHandlerProtocol,
        evidence_store: dict[str, list[Evidence]],
    ) -> None:
        """Create the agent.

        Args:
            judge_handler: Object whose ``assess(question, evidence)`` coroutine
                produces a ``JudgeAssessment``.
            evidence_store: Mapping shared with other components; this agent
                only reads its ``"current"`` entry and never writes to it.
        """
        super().__init__(
            description="Evaluates evidence quality and determines if sufficient for synthesis",
            name="JudgeAgent",
        )
        self._handler = judge_handler
        # Held by reference (not copied) so evidence added by others is visible.
        self._evidence_store = evidence_store

    @staticmethod
    def _extract_question(
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None,
    ) -> str:
        """Return the question text carried by ``messages`` ("" if none found).

        For a list, the newest entry wins: scanning from the end, the first
        plain string (even an empty one) or the first USER ``ChatMessage``
        with non-empty text is returned. A lone string is returned as-is, and
        a lone ``ChatMessage`` contributes its text regardless of role.
        """
        if isinstance(messages, list):
            for entry in reversed(messages):
                if isinstance(entry, str):
                    return entry
                if isinstance(entry, ChatMessage) and entry.role == Role.USER and entry.text:
                    return entry.text
            return ""
        if isinstance(messages, str):
            return messages
        if isinstance(messages, ChatMessage):
            # Falsy text (empty/None) collapses to the empty-string default.
            return messages.text or ""
        return ""

    @staticmethod
    def _render_assessment(assessment: JudgeAssessment) -> str:
        """Render ``assessment`` as the Markdown-style text of the agent's reply.

        A "Next Queries" bullet list is appended only when the assessment
        proposes follow-up search queries.
        """
        report = f"""## Assessment

**Sufficient**: {assessment.sufficient}
**Confidence**: {assessment.confidence:.0%}
**Recommendation**: {assessment.recommendation}

### Scores
- Mechanism: {assessment.details.mechanism_score}/10
- Clinical: {assessment.details.clinical_evidence_score}/10

### Reasoning
{assessment.reasoning}
"""
        follow_ups = assessment.next_search_queries
        if not follow_ups:
            return report
        bullets = "\n".join(f"- {query}" for query in follow_ups)
        return report + "\n### Next Queries\n" + bullets

    async def run(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
        *,
        thread: AgentThread | None = None,
        **kwargs: Any,
    ) -> AgentRunResponse:
        """Judge the current evidence against the question in ``messages``.

        ``thread`` and any extra keyword arguments are accepted for interface
        compatibility but are not used here.

        Returns:
            An ``AgentRunResponse`` holding one assistant message with the
            rendered report, a ``response_id`` of ``"judge-<recommendation>"``,
            and the full ``model_dump()`` of the assessment under
            ``additional_properties["assessment"]``.
        """
        question = self._extract_question(messages)
        current_evidence = self._evidence_store.get("current", [])

        verdict: JudgeAssessment = await self._handler.assess(question, current_evidence)

        reply_text = self._render_assessment(verdict)
        reply = ChatMessage(role=Role.ASSISTANT, text=reply_text)
        return AgentRunResponse(
            messages=[reply],
            response_id=f"judge-{verdict.recommendation}",
            additional_properties={"assessment": verdict.model_dump()},
        )

    async def run_stream(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
        *,
        thread: AgentThread | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[AgentRunResponseUpdate]:
        """Stream the judgement as a single update.

        This is not incremental: it awaits the complete ``run`` and then
        yields exactly one update carrying its messages and response id.
        The assessment dict in ``additional_properties`` is not forwarded.
        """
        # NOTE(review): confirm AgentRunResponseUpdate accepts a `messages`
        # keyword; if it only takes `contents`/`text`, the reply text may not
        # reach streaming consumers.
        full_response = await self.run(messages, thread=thread, **kwargs)
        yield AgentRunResponseUpdate(
            messages=full_response.messages,
            response_id=full_response.response_id,
        )