"""Unit tests for ClinicalTrials.gov tool.""" from collections.abc import Generator from typing import Any from unittest.mock import MagicMock, patch import pytest import requests from src.tools.clinicaltrials import ClinicalTrialsTool from src.utils.exceptions import SearchError from src.utils.models import Evidence @pytest.fixture def mock_clinicaltrials_response() -> dict[str, Any]: """Mock ClinicalTrials.gov API response.""" return { "studies": [ { "protocolSection": { "identificationModule": { "nctId": "NCT04098666", "briefTitle": "Metformin in Alzheimer's Dementia Prevention", }, "statusModule": { "overallStatus": "Recruiting", "startDateStruct": {"date": "2020-01-15"}, }, "descriptionModule": { "briefSummary": "This study evaluates metformin for Alzheimer's prevention." }, "designModule": {"phases": ["PHASE2"]}, "conditionsModule": {"conditions": ["Alzheimer Disease", "Dementia"]}, "armsInterventionsModule": { "interventions": [{"name": "Metformin", "type": "Drug"}] }, } } ] } @pytest.fixture def mock_requests_get( mock_clinicaltrials_response: dict[str, Any], ) -> Generator[MagicMock, None, None]: """Fixture to mock requests.get with a successful response.""" with patch("src.tools.clinicaltrials.requests.get") as mock_get: mock_response = MagicMock() mock_response.json.return_value = mock_clinicaltrials_response mock_response.raise_for_status = MagicMock() mock_get.return_value = mock_response yield mock_get class TestClinicalTrialsTool: """Tests for ClinicalTrialsTool.""" def test_tool_name(self) -> None: """Tool should have correct name.""" tool = ClinicalTrialsTool() assert tool.name == "clinicaltrials" @pytest.mark.asyncio async def test_search_returns_evidence(self, mock_requests_get: MagicMock) -> None: """Search should return Evidence objects.""" tool = ClinicalTrialsTool() results = await tool.search("metformin alzheimer", max_results=5) assert len(results) == 1 assert isinstance(results[0], Evidence) assert results[0].citation.source == "clinicaltrials" assert "NCT04098666" in results[0].citation.url assert "Metformin" in results[0].citation.title @pytest.mark.asyncio async def test_search_extracts_phase(self, mock_requests_get: MagicMock) -> None: """Search should extract trial phase.""" tool = ClinicalTrialsTool() results = await tool.search("metformin alzheimer") assert "PHASE2" in results[0].content @pytest.mark.asyncio async def test_search_extracts_status(self, mock_requests_get: MagicMock) -> None: """Search should extract trial status.""" tool = ClinicalTrialsTool() results = await tool.search("metformin alzheimer") assert "Recruiting" in results[0].content @pytest.mark.asyncio async def test_search_empty_results(self) -> None: """Search should handle empty results gracefully.""" with patch("src.tools.clinicaltrials.requests.get") as mock_get: mock_response = MagicMock() mock_response.json.return_value = {"studies": []} mock_response.raise_for_status = MagicMock() mock_get.return_value = mock_response tool = ClinicalTrialsTool() results = await tool.search("nonexistent query xyz") assert results == [] @pytest.mark.asyncio async def test_search_api_error(self) -> None: """Search should raise SearchError on API failure. Note: We patch the retry decorator to avoid 3x backoff delay in tests. """ with patch("src.tools.clinicaltrials.requests.get") as mock_get: mock_response = MagicMock() mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error") mock_get.return_value = mock_response tool = ClinicalTrialsTool() # Patch the retry decorator's stop condition to fail immediately tool.search.retry.stop = lambda _: True # type: ignore[attr-defined] with pytest.raises(SearchError): await tool.search("metformin alzheimer") class TestClinicalTrialsIntegration: """Integration tests (marked for separate run).""" @pytest.mark.integration @pytest.mark.asyncio async def test_real_api_call(self) -> None: """Test actual API call (requires network).""" tool = ClinicalTrialsTool() results = await tool.search("metformin diabetes", max_results=3) assert len(results) > 0 assert all(isinstance(r, Evidence) for r in results) assert all(r.citation.source == "clinicaltrials" for r in results)