VibecoderMcSwaggins commited on
Commit
316dc7d
·
unverified ·
1 Parent(s): 9286db5

feat: implement async-safe rate limiting (Phase 17) (#40)

Browse files

- Add 'limits' library dependency
- Implement async RateLimiter with finer-grained polling (0.01s)
- Refactor PubMedTool to use shared singleton limiter
- Add comprehensive unit tests and demo script

examples/rate_limiting_demo.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Demo script to verify rate limiting works correctly."""
3
+
4
+ import asyncio
5
+ import time
6
+
7
+ from src.tools.pubmed import PubMedTool
8
+ from src.tools.rate_limiter import RateLimiter, get_pubmed_limiter, reset_pubmed_limiter
9
+
10
+
11
+ async def test_basic_limiter():
12
+ """Test basic rate limiter behavior."""
13
+ print("=" * 60)
14
+ print("Rate Limiting Demo")
15
+ print("=" * 60)
16
+
17
+ # Test 1: Basic limiter
18
+ print("\n[Test 1] Testing 3/second limiter...")
19
+ limiter = RateLimiter("3/second")
20
+
21
+ start = time.monotonic()
22
+ for i in range(6):
23
+ await limiter.acquire()
24
+ elapsed = time.monotonic() - start
25
+ print(f" Request {i+1} at {elapsed:.2f}s")
26
+
27
+ total = time.monotonic() - start
28
+ print(f" Total time for 6 requests: {total:.2f}s (expected ~2s)")
29
+
30
+
31
+ async def test_pubmed_limiter():
32
+ """Test PubMed-specific limiter."""
33
+ print("\n[Test 2] Testing PubMed limiter (shared)...")
34
+
35
+ reset_pubmed_limiter() # Clean state
36
+
37
+ # Without API key: 3/sec
38
+ limiter = get_pubmed_limiter(api_key=None)
39
+ print(f" Rate without key: {limiter.rate}")
40
+
41
+ # Multiple tools should share the same limiter
42
+ tool1 = PubMedTool()
43
+ tool2 = PubMedTool()
44
+
45
+ # Verify they share the limiter
46
+ print(f" Tools share limiter: {tool1._limiter is tool2._limiter}")
47
+
48
+
49
+ async def test_concurrent_requests():
50
+ """Test rate limiting under concurrent load."""
51
+ print("\n[Test 3] Testing concurrent request limiting...")
52
+
53
+ limiter = RateLimiter("5/second")
54
+
55
+ async def make_request(i: int):
56
+ await limiter.acquire()
57
+ return time.monotonic()
58
+
59
+ start = time.monotonic()
60
+ # Launch 10 concurrent requests
61
+ tasks = [make_request(i) for i in range(10)]
62
+ times = await asyncio.gather(*tasks)
63
+
64
+ # Calculate distribution
65
+ relative_times = [t - start for t in times]
66
+ print(f" Request times: {[f'{t:.2f}s' for t in sorted(relative_times)]}")
67
+
68
+ total = max(relative_times)
69
+ print(f" All 10 requests completed in {total:.2f}s (expected ~2s)")
70
+
71
+
72
+ async def main():
73
+ await test_basic_limiter()
74
+ await test_pubmed_limiter()
75
+ await test_concurrent_requests()
76
+
77
+ print("\n" + "=" * 60)
78
+ print("Demo complete!")
79
+
80
+
81
+ if __name__ == "__main__":
82
+ asyncio.run(main())
pyproject.toml CHANGED
@@ -24,6 +24,7 @@ dependencies = [
24
  "tenacity>=8.2", # Retry logic
25
  "structlog>=24.1", # Structured logging
26
  "requests>=2.32.5", # ClinicalTrials.gov (httpx blocked by WAF)
 
27
  ]
28
 
29
  [project.optional-dependencies]
 
24
  "tenacity>=8.2", # Retry logic
25
  "structlog>=24.1", # Structured logging
26
  "requests>=2.32.5", # ClinicalTrials.gov (httpx blocked by WAF)
27
+ "limits>=3.0", # Rate limiting
28
  ]
29
 
30
  [project.optional-dependencies]
src/tools/pubmed.py CHANGED
@@ -1,6 +1,5 @@
1
  """PubMed search tool using NCBI E-utilities."""
2
 
3
- import asyncio
4
  from typing import Any
5
 
6
  import httpx
@@ -8,6 +7,7 @@ import xmltodict
8
  from tenacity import retry, stop_after_attempt, wait_exponential
9
 
10
  from src.tools.query_utils import preprocess_query
 
11
  from src.utils.config import settings
12
  from src.utils.exceptions import RateLimitError, SearchError
13
  from src.utils.models import Citation, Evidence
@@ -17,7 +17,6 @@ class PubMedTool:
17
  """Search tool for PubMed/NCBI."""
18
 
19
  BASE_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
20
- RATE_LIMIT_DELAY = 0.34 # ~3 requests/sec without API key
21
  HTTP_TOO_MANY_REQUESTS = 429
22
 
23
  def __init__(self, api_key: str | None = None) -> None:
@@ -25,7 +24,9 @@ class PubMedTool:
25
  # Ignore placeholder values from .env.example
26
  if self.api_key == "your-ncbi-key-here":
27
  self.api_key = None
28
- self._last_request_time = 0.0
 
 
29
 
30
  @property
31
  def name(self) -> str:
@@ -33,12 +34,7 @@ class PubMedTool:
33
 
34
  async def _rate_limit(self) -> None:
35
  """Enforce NCBI rate limiting."""
36
- loop = asyncio.get_running_loop()
37
- now = loop.time()
38
- elapsed = now - self._last_request_time
39
- if elapsed < self.RATE_LIMIT_DELAY:
40
- await asyncio.sleep(self.RATE_LIMIT_DELAY - elapsed)
41
- self._last_request_time = loop.time()
42
 
43
  def _build_params(self, **kwargs: Any) -> dict[str, Any]:
44
  """Build request params with optional API key."""
 
1
  """PubMed search tool using NCBI E-utilities."""
2
 
 
3
  from typing import Any
4
 
5
  import httpx
 
7
  from tenacity import retry, stop_after_attempt, wait_exponential
8
 
9
  from src.tools.query_utils import preprocess_query
10
+ from src.tools.rate_limiter import get_pubmed_limiter
11
  from src.utils.config import settings
12
  from src.utils.exceptions import RateLimitError, SearchError
13
  from src.utils.models import Citation, Evidence
 
17
  """Search tool for PubMed/NCBI."""
18
 
19
  BASE_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
 
20
  HTTP_TOO_MANY_REQUESTS = 429
21
 
22
  def __init__(self, api_key: str | None = None) -> None:
 
24
  # Ignore placeholder values from .env.example
25
  if self.api_key == "your-ncbi-key-here":
26
  self.api_key = None
27
+
28
+ # Use shared rate limiter
29
+ self._limiter = get_pubmed_limiter(self.api_key)
30
 
31
  @property
32
  def name(self) -> str:
 
34
 
35
  async def _rate_limit(self) -> None:
36
  """Enforce NCBI rate limiting."""
37
+ await self._limiter.acquire()
 
 
 
 
 
38
 
39
  def _build_params(self, **kwargs: Any) -> dict[str, Any]:
40
  """Build request params with optional API key."""
src/tools/rate_limiter.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Rate limiting utilities using the limits library."""
2
+
3
+ import asyncio
4
+ from typing import ClassVar
5
+
6
+ from limits import RateLimitItem, parse
7
+ from limits.storage import MemoryStorage
8
+ from limits.strategies import MovingWindowRateLimiter
9
+
10
+
11
+ class RateLimiter:
12
+ """
13
+ Async-compatible rate limiter using limits library.
14
+
15
+ Uses moving window algorithm for smooth rate limiting.
16
+ """
17
+
18
+ def __init__(self, rate: str) -> None:
19
+ """
20
+ Initialize rate limiter.
21
+
22
+ Args:
23
+ rate: Rate string like "3/second" or "10/second"
24
+ """
25
+ self.rate = rate
26
+ self._storage = MemoryStorage()
27
+ self._limiter = MovingWindowRateLimiter(self._storage)
28
+ self._rate_limit: RateLimitItem = parse(rate)
29
+ self._identity = "default" # Single identity for shared limiting
30
+
31
+ async def acquire(self, wait: bool = True) -> bool:
32
+ """
33
+ Acquire permission to make a request.
34
+
35
+ ASYNC-SAFE: Uses asyncio.sleep(), never time.sleep().
36
+ The polling pattern allows other coroutines to run while waiting.
37
+
38
+ Args:
39
+ wait: If True, wait until allowed. If False, return immediately.
40
+
41
+ Returns:
42
+ True if allowed, False if not (only when wait=False)
43
+ """
44
+ while True:
45
+ # Check if we can proceed (synchronous, fast - ~microseconds)
46
+ if self._limiter.hit(self._rate_limit, self._identity):
47
+ return True
48
+
49
+ if not wait:
50
+ return False
51
+
52
+ # CRITICAL: Use asyncio.sleep(), NOT time.sleep()
53
+ # This yields control to the event loop, allowing other
54
+ # coroutines (UI, parallel searches) to run.
55
+ # Using 0.01s for fine-grained responsiveness.
56
+ await asyncio.sleep(0.01)
57
+
58
+ def reset(self) -> None:
59
+ """Reset the rate limiter (for testing)."""
60
+ self._storage.reset()
61
+
62
+
63
+ # Singleton limiter for PubMed/NCBI
64
+ _pubmed_limiter: RateLimiter | None = None
65
+
66
+
67
+ def get_pubmed_limiter(api_key: str | None = None) -> RateLimiter:
68
+ """
69
+ Get the shared PubMed rate limiter.
70
+
71
+ Rate depends on whether API key is provided:
72
+ - Without key: 3 requests/second
73
+ - With key: 10 requests/second
74
+
75
+ Args:
76
+ api_key: NCBI API key (optional)
77
+
78
+ Returns:
79
+ Shared RateLimiter instance
80
+ """
81
+ global _pubmed_limiter
82
+
83
+ if _pubmed_limiter is None:
84
+ rate = "10/second" if api_key else "3/second"
85
+ _pubmed_limiter = RateLimiter(rate)
86
+
87
+ return _pubmed_limiter
88
+
89
+
90
+ def reset_pubmed_limiter() -> None:
91
+ """Reset the PubMed limiter (for testing)."""
92
+ global _pubmed_limiter
93
+ _pubmed_limiter = None
94
+
95
+
96
+ # Factory for other APIs
97
+ class RateLimiterFactory:
98
+ """Factory for creating/getting rate limiters for different APIs."""
99
+
100
+ _limiters: ClassVar[dict[str, RateLimiter]] = {}
101
+
102
+ @classmethod
103
+ def get(cls, api_name: str, rate: str) -> RateLimiter:
104
+ """
105
+ Get or create a rate limiter for an API.
106
+
107
+ Args:
108
+ api_name: Unique identifier for the API
109
+ rate: Rate limit string (e.g., "10/second")
110
+
111
+ Returns:
112
+ RateLimiter instance (shared for same api_name)
113
+ """
114
+ if api_name not in cls._limiters:
115
+ cls._limiters[api_name] = RateLimiter(rate)
116
+ return cls._limiters[api_name]
117
+
118
+ @classmethod
119
+ def reset_all(cls) -> None:
120
+ """Reset all limiters (for testing)."""
121
+ cls._limiters.clear()
tests/unit/tools/test_rate_limiting.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for rate limiting functionality."""
2
+
3
+ import asyncio
4
+ import time
5
+
6
+ import pytest
7
+
8
+ from src.tools.rate_limiter import RateLimiter, get_pubmed_limiter, reset_pubmed_limiter
9
+
10
+
11
+ class TestRateLimiter:
12
+ """Test suite for rate limiter."""
13
+
14
+ def test_create_limiter_without_api_key(self) -> None:
15
+ """Should create 3/sec limiter without API key."""
16
+ limiter = RateLimiter(rate="3/second")
17
+ assert limiter.rate == "3/second"
18
+
19
+ def test_create_limiter_with_api_key(self) -> None:
20
+ """Should create 10/sec limiter with API key."""
21
+ limiter = RateLimiter(rate="10/second")
22
+ assert limiter.rate == "10/second"
23
+
24
+ @pytest.mark.asyncio
25
+ async def test_limiter_allows_requests_under_limit(self) -> None:
26
+ """Should allow requests under the rate limit."""
27
+ limiter = RateLimiter(rate="10/second")
28
+
29
+ # 3 requests should all succeed immediately
30
+ for _ in range(3):
31
+ allowed = await limiter.acquire()
32
+ assert allowed is True
33
+
34
+ @pytest.mark.asyncio
35
+ async def test_limiter_blocks_when_exceeded(self) -> None:
36
+ """Should wait when rate limit exceeded."""
37
+ limiter = RateLimiter(rate="2/second")
38
+
39
+ # First 2 should be instant
40
+ await limiter.acquire()
41
+ await limiter.acquire()
42
+
43
+ # Third should block briefly
44
+ start = time.monotonic()
45
+ await limiter.acquire()
46
+ elapsed = time.monotonic() - start
47
+
48
+ # Should have waited ~0.5 seconds (half second window for 2/sec)
49
+ assert elapsed >= 0.3
50
+
51
+ @pytest.mark.asyncio
52
+ async def test_limiter_resets_after_window(self) -> None:
53
+ """Rate limit should reset after time window."""
54
+ limiter = RateLimiter(rate="5/second")
55
+
56
+ # Use up the limit
57
+ for _ in range(5):
58
+ await limiter.acquire()
59
+
60
+ # Wait for window to pass
61
+ await asyncio.sleep(1.1)
62
+
63
+ # Should be allowed again
64
+ start = time.monotonic()
65
+ await limiter.acquire()
66
+ elapsed = time.monotonic() - start
67
+
68
+ assert elapsed < 0.1 # Should be nearly instant
69
+
70
+
71
+ class TestGetPubmedLimiter:
72
+ """Test PubMed-specific limiter factory."""
73
+
74
+ @pytest.fixture(autouse=True)
75
+ def setup_teardown(self):
76
+ """Reset limiter before and after each test."""
77
+ reset_pubmed_limiter()
78
+ yield
79
+ reset_pubmed_limiter()
80
+
81
+ def test_limiter_without_api_key(self) -> None:
82
+ """Should return 3/sec limiter without key."""
83
+ limiter = get_pubmed_limiter(api_key=None)
84
+ assert "3" in limiter.rate
85
+
86
+ def test_limiter_with_api_key(self) -> None:
87
+ """Should return 10/sec limiter with key."""
88
+ limiter = get_pubmed_limiter(api_key="my-api-key")
89
+ assert "10" in limiter.rate
90
+
91
+ def test_limiter_is_singleton(self) -> None:
92
+ """Same API key should return same limiter instance."""
93
+ limiter1 = get_pubmed_limiter(api_key="key1")
94
+ limiter2 = get_pubmed_limiter(api_key="key1")
95
+ assert limiter1 is limiter2
96
+
97
+ def test_different_keys_different_limiters(self) -> None:
98
+ """Different API keys should return different limiters."""
99
+ limiter1 = get_pubmed_limiter(api_key="key1")
100
+ limiter2 = get_pubmed_limiter(api_key="key2")
101
+ # Clear cache for clean test
102
+ # Actually, different keys SHOULD share the same limiter
103
+ # since we're limiting against the same API
104
+ assert limiter1 is limiter2 # Shared NCBI rate limit
uv.lock CHANGED
@@ -1066,6 +1066,7 @@ dependencies = [
1066
  { name = "gradio", extra = ["mcp"] },
1067
  { name = "httpx" },
1068
  { name = "huggingface-hub" },
 
1069
  { name = "openai" },
1070
  { name = "pydantic" },
1071
  { name = "pydantic-ai" },
@@ -1116,6 +1117,7 @@ requires-dist = [
1116
  { name = "gradio", extras = ["mcp"], specifier = ">=6.0.0" },
1117
  { name = "httpx", specifier = ">=0.27" },
1118
  { name = "huggingface-hub", specifier = ">=0.20.0" },
 
1119
  { name = "llama-index", marker = "extra == 'modal'", specifier = ">=0.11.0" },
1120
  { name = "llama-index-embeddings-openai", marker = "extra == 'modal'" },
1121
  { name = "llama-index-llms-openai", marker = "extra == 'modal'" },
@@ -2259,6 +2261,20 @@ wheels = [
2259
  { url = "https://files.pythonhosted.org/packages/ca/ec/65f7d563aa4a62dd58777e8f6aa882f15db53b14eb29aba0c28a20f7eb26/kubernetes-34.1.0-py2.py3-none-any.whl", hash = "sha256:bffba2272534e224e6a7a74d582deb0b545b7c9879d2cd9e4aae9481d1f2cc2a", size = 2008380 },
2260
  ]
2261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2262
  [[package]]
2263
  name = "llama-cloud"
2264
  version = "0.1.35"
 
1066
  { name = "gradio", extra = ["mcp"] },
1067
  { name = "httpx" },
1068
  { name = "huggingface-hub" },
1069
+ { name = "limits" },
1070
  { name = "openai" },
1071
  { name = "pydantic" },
1072
  { name = "pydantic-ai" },
 
1117
  { name = "gradio", extras = ["mcp"], specifier = ">=6.0.0" },
1118
  { name = "httpx", specifier = ">=0.27" },
1119
  { name = "huggingface-hub", specifier = ">=0.20.0" },
1120
+ { name = "limits", specifier = ">=3.0" },
1121
  { name = "llama-index", marker = "extra == 'modal'", specifier = ">=0.11.0" },
1122
  { name = "llama-index-embeddings-openai", marker = "extra == 'modal'" },
1123
  { name = "llama-index-llms-openai", marker = "extra == 'modal'" },
 
2261
  { url = "https://files.pythonhosted.org/packages/ca/ec/65f7d563aa4a62dd58777e8f6aa882f15db53b14eb29aba0c28a20f7eb26/kubernetes-34.1.0-py2.py3-none-any.whl", hash = "sha256:bffba2272534e224e6a7a74d582deb0b545b7c9879d2cd9e4aae9481d1f2cc2a", size = 2008380 },
2262
  ]
2263
 
2264
+ [[package]]
2265
+ name = "limits"
2266
+ version = "5.6.0"
2267
+ source = { registry = "https://pypi.org/simple" }
2268
+ dependencies = [
2269
+ { name = "deprecated" },
2270
+ { name = "packaging" },
2271
+ { name = "typing-extensions" },
2272
+ ]
2273
+ sdist = { url = "https://files.pythonhosted.org/packages/bb/e5/c968d43a65128cd54fb685f257aafb90cd5e4e1c67d084a58f0e4cbed557/limits-5.6.0.tar.gz", hash = "sha256:807fac75755e73912e894fdd61e2838de574c5721876a19f7ab454ae1fffb4b5", size = 182984 }
2274
+ wheels = [
2275
+ { url = "https://files.pythonhosted.org/packages/40/96/4fcd44aed47b8fcc457653b12915fcad192cd646510ef3f29fd216f4b0ab/limits-5.6.0-py3-none-any.whl", hash = "sha256:b585c2104274528536a5b68864ec3835602b3c4a802cd6aa0b07419798394021", size = 60604 },
2276
+ ]
2277
+
2278
  [[package]]
2279
  name = "llama-cloud"
2280
  version = "0.1.35"