Srikasi committed on
Commit
e910d0b
·
verified ·
1 Parent(s): cdf53d6

Create ratelimits.py

Browse files
Files changed (1) hide show
  1. ratelimits.py +287 -0
ratelimits.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # astro_diffusion/ratelimits.py
2
+ import time
3
+ import threading
4
+ import os
5
+ from typing import Tuple, Optional, Dict, Any
6
+
7
+
8
class RateLimiter:
    """
    In-memory, multi-scope limiter.

    Scopes:
    - per-session: request count and session age
    - per-IP: requests per hour / day, active seconds per day
    - global: requests per hour / day / month, active seconds per day
    - cost: global per day / month (cost = active seconds * cost_per_sec)

    Every limit is read once in ``__init__`` from an ``AD_*`` environment
    variable, falling back to the default written next to it.

    Windows are fixed windows that start when a bucket is created (or is
    first touched after it expired); they are not calendar-aligned, and a
    "month" is 30 days.

    Usage: call ``pre_check`` before doing the work and ``post_consume``
    with the measured duration afterwards.

    All counters reset on container restart.
    """

    # Window lengths, in seconds.
    _HOUR_SEC = 3600
    _DAY_SEC = 86400
    _MONTH_SEC = 30 * 86400

    # Minimum spacing between sweeps that drop expired per-IP buckets.
    _IP_PURGE_INTERVAL_SEC = 300

    def __init__(self):
        # ---------------- session limits ----------------
        self.per_session_max_req = int(os.getenv("AD_SESSION_MAX_REQ", "5"))
        self.per_session_max_age = int(
            os.getenv("AD_SESSION_MAX_AGE_SEC", str(15 * 60))
        )  # 15 min

        # ---------------- per-IP limits ----------------
        # defaults: 10 requests/hour, 100 requests/day
        self.per_ip_max_req_hour = int(os.getenv("AD_IP_MAX_REQ_HOUR", "10"))
        self.per_ip_max_req_day = int(os.getenv("AD_IP_MAX_REQ_DAY", "100"))
        self.per_ip_max_active_sec_day = int(
            os.getenv("AD_IP_MAX_ACTIVE_SEC_DAY", str(60 * 60))
        )  # 1h active time per day

        # ---------------- global limits ----------------
        # defaults: 50 requests/hour, 100 requests/day, 500 requests/month
        self.global_max_req_hour = int(os.getenv("AD_GLOBAL_MAX_REQ_HOUR", "50"))
        self.global_max_req_day = int(os.getenv("AD_GLOBAL_MAX_REQ_DAY", "100"))
        self.global_max_req_month = int(os.getenv("AD_GLOBAL_MAX_REQ_MONTH", "500"))
        self.global_max_active_sec_day = int(
            os.getenv("AD_GLOBAL_MAX_ACTIVE_SEC_DAY", str(6 * 60 * 60))
        )  # 6h

        # ---------------- cost limits ----------------
        self.cost_per_sec = float(os.getenv("AD_COST_PER_SEC", "0.0005"))
        self.daily_cost_limit = float(os.getenv("AD_DAILY_COST_LIMIT", "5.0"))
        self.monthly_cost_limit = float(os.getenv("AD_MONTHLY_COST_LIMIT", "10.0"))

        # Guards every bucket below; the public methods take it, the private
        # helpers expect it to be held already.
        self._lock = threading.Lock()

        # per-ip buckets, keyed by client IP
        self._ip_hour: Dict[str, Dict[str, float]] = {}
        self._ip_day: Dict[str, Dict[str, float]] = {}

        # global buckets
        now = time.time()
        self._global_hour = {"count": 0, "reset_at": now + self._HOUR_SEC}
        self._global_day = {
            "count": 0,
            "active_sec": 0.0,
            "cost": 0.0,
            "reset_at": now + self._DAY_SEC,
        }
        # requests per month
        self._global_month_req = {"count": 0, "reset_at": now + self._MONTH_SEC}
        # cost per month
        self._global_month_cost = {"cost": 0.0, "reset_at": now + self._MONTH_SEC}

        # earliest time the next sweep of expired per-IP buckets may run
        self._next_ip_purge_at = now + self._IP_PURGE_INTERVAL_SEC

    # -------------------------------------------------
    # helpers
    # -------------------------------------------------
    @staticmethod
    def _fmt_wait(seconds: float) -> str:
        """Render a wait time as hours, minutes or seconds; negatives become 0."""
        if seconds < 0:
            seconds = 0
        # choose largest sensible unit
        if seconds >= 3600:
            return f"{seconds / 3600.0:.1f} hours"
        if seconds >= 60:
            return f"{seconds / 60.0:.1f} minutes"
        return f"{seconds:.0f} seconds"

    def _deny(self, label: str, reset_at: float, now: float) -> Tuple[bool, str]:
        """Build the ``(False, message)`` result for a cap that resets at *reset_at*."""
        wait_str = self._fmt_wait(reset_at - now)
        return False, f"{label} reached. try again in {wait_str}"

    def _purge_expired_ip_buckets(self, now: float) -> None:
        """
        Drop per-IP buckets whose window has ended. Caller must hold the lock.

        Without this the per-IP dicts would keep one entry for every IP ever
        seen. Removing an expired bucket cannot change a decision: the bucket
        getters replace an expired bucket with a fresh one anyway. The sweep
        is throttled so that it does not scan both dicts on every request.
        """
        if now < self._next_ip_purge_at:
            return
        for table in (self._ip_hour, self._ip_day):
            expired = [ip for ip, bucket in table.items() if now >= bucket["reset_at"]]
            for ip in expired:
                del table[ip]
        self._next_ip_purge_at = now + self._IP_PURGE_INTERVAL_SEC

    def _get_ip_hour_bucket(self, ip: str, now: Optional[float] = None) -> Dict[str, float]:
        """Return the live hourly bucket for *ip*, creating a fresh one if absent or expired."""
        if now is None:
            now = time.time()
        bucket = self._ip_hour.get(ip)
        if bucket is None or now >= bucket["reset_at"]:
            bucket = {"count": 0, "reset_at": now + self._HOUR_SEC}
            self._ip_hour[ip] = bucket
        return bucket

    def _get_ip_day_bucket(self, ip: str, now: Optional[float] = None) -> Dict[str, float]:
        """Return the live daily bucket for *ip*, creating a fresh one if absent or expired."""
        if now is None:
            now = time.time()
        bucket = self._ip_day.get(ip)
        if bucket is None or now >= bucket["reset_at"]:
            bucket = {"count": 0, "active_sec": 0.0, "reset_at": now + self._DAY_SEC}
            self._ip_day[ip] = bucket
        return bucket

    def _get_global_hour(self, now: Optional[float] = None) -> Dict[str, float]:
        """Return the global hourly bucket, zeroing it in place if its window ended."""
        if now is None:
            now = time.time()
        bucket = self._global_hour
        if now >= bucket["reset_at"]:
            bucket["count"] = 0
            bucket["reset_at"] = now + self._HOUR_SEC
        return bucket

    def _get_global_day(self, now: Optional[float] = None) -> Dict[str, float]:
        """Return the global daily bucket, zeroing it in place if its window ended."""
        if now is None:
            now = time.time()
        bucket = self._global_day
        if now >= bucket["reset_at"]:
            bucket["count"] = 0
            bucket["active_sec"] = 0.0
            bucket["cost"] = 0.0
            bucket["reset_at"] = now + self._DAY_SEC
        return bucket

    def _get_global_month_req(self, now: Optional[float] = None) -> Dict[str, float]:
        """Return the global monthly request bucket, zeroing it if its window ended."""
        if now is None:
            now = time.time()
        bucket = self._global_month_req
        if now >= bucket["reset_at"]:
            bucket["count"] = 0
            bucket["reset_at"] = now + self._MONTH_SEC
        return bucket

    def _get_global_month_cost(self, now: Optional[float] = None) -> Dict[str, float]:
        """Return the global monthly cost bucket, zeroing it if its window ended."""
        if now is None:
            now = time.time()
        bucket = self._global_month_cost
        if now >= bucket["reset_at"]:
            bucket["cost"] = 0.0
            bucket["reset_at"] = now + self._MONTH_SEC
        return bucket

    # -------------------------------------------------
    # public API
    # -------------------------------------------------
    def pre_check(
        self,
        ip: str,
        session_state: Dict[str, Any],
    ) -> Tuple[bool, Optional[str]]:
        """
        Check count/time/cost limits BEFORE doing work.

        session_state is expected to have:
        {
            "count": int,
            "started_at": float
        }
        Missing keys default to 0 and "now" respectively.

        Returns ``(True, None)`` when the request is allowed; in that case the
        request counters (per-IP, global and ``session_state["count"]``) are
        incremented and ``session_state["started_at"]`` is set if missing.
        Returns ``(False, message)`` for the first limit that is hit, in the
        order session -> per-IP -> global; nothing is incremented then.
        """
        with self._lock:
            # One timestamp for the whole check so every bucket is judged
            # against the same instant.
            now = time.time()
            self._purge_expired_ip_buckets(now)

            # ---- session checks ----
            sess_count = int(session_state.get("count", 0))
            sess_started = float(session_state.get("started_at", now))

            # NOTE(review): a session older than per_session_max_age is refused
            # on every call; it only recovers if the caller passes a fresh
            # session_state. Presumably the caller resets it - confirm.
            session_age = now - sess_started
            if session_age > self.per_session_max_age:
                wait_str = self._fmt_wait(0)
                return (
                    False,
                    f"session time cap reached ({self.per_session_max_age} sec). try again in {wait_str}",
                )

            if sess_count >= self.per_session_max_req:
                # the reported wait is the time left in the session-age window
                wait_str = self._fmt_wait(self.per_session_max_age - session_age)
                return (
                    False,
                    f"session request cap {self.per_session_max_req} reached. try again in {wait_str}",
                )

            # ---- per-IP checks ----
            ip_h = self._get_ip_hour_bucket(ip, now)
            if ip_h["count"] >= self.per_ip_max_req_hour:
                return self._deny(
                    f"ip hourly cap {self.per_ip_max_req_hour}", ip_h["reset_at"], now
                )

            ip_d = self._get_ip_day_bucket(ip, now)
            if ip_d["count"] >= self.per_ip_max_req_day:
                return self._deny(
                    f"ip daily cap {self.per_ip_max_req_day}", ip_d["reset_at"], now
                )
            if ip_d["active_sec"] >= self.per_ip_max_active_sec_day:
                return self._deny(
                    f"ip daily active time cap {self.per_ip_max_active_sec_day} sec",
                    ip_d["reset_at"],
                    now,
                )

            # ---- global checks ----
            g_h = self._get_global_hour(now)
            if g_h["count"] >= self.global_max_req_hour:
                return self._deny(
                    f"global hourly cap {self.global_max_req_hour}", g_h["reset_at"], now
                )

            g_d = self._get_global_day(now)
            if g_d["count"] >= self.global_max_req_day:
                return self._deny(
                    f"global daily cap {self.global_max_req_day}", g_d["reset_at"], now
                )
            if g_d["active_sec"] >= self.global_max_active_sec_day:
                return self._deny(
                    f"global daily active time cap {self.global_max_active_sec_day} sec",
                    g_d["reset_at"],
                    now,
                )
            if g_d["cost"] >= self.daily_cost_limit:
                return self._deny(
                    f"global daily cost cap {self.daily_cost_limit}", g_d["reset_at"], now
                )

            g_m_req = self._get_global_month_req(now)
            if g_m_req["count"] >= self.global_max_req_month:
                return self._deny(
                    f"global monthly cap {self.global_max_req_month}",
                    g_m_req["reset_at"],
                    now,
                )

            g_m_cost = self._get_global_month_cost(now)
            if g_m_cost["cost"] >= self.monthly_cost_limit:
                return self._deny(
                    f"global monthly cost cap {self.monthly_cost_limit}",
                    g_m_cost["reset_at"],
                    now,
                )

            # ---- all clear -> increment counters that are request-based ----
            ip_h["count"] += 1
            ip_d["count"] += 1

            g_h["count"] += 1
            g_d["count"] += 1
            g_m_req["count"] += 1

            # session count increment
            session_state["count"] = sess_count + 1
            # ensure started_at exists
            session_state.setdefault("started_at", now)

            return True, None

    def post_consume(
        self,
        ip: str,
        duration_sec: float,
    ) -> None:
        """
        Update time-based and cost-based buckets AFTER doing work.

        Adds *duration_sec* to the per-IP and global daily active time, and
        ``duration_sec * cost_per_sec`` to the global daily and monthly cost.
        """
        # A negative (or NaN) duration, e.g. from a clock step during the
        # measured work, must not lower the counters and weaken the caps.
        duration_sec = max(0.0, duration_sec)
        cost = duration_sec * self.cost_per_sec
        with self._lock:
            now = time.time()

            ip_d = self._get_ip_day_bucket(ip, now)
            ip_d["active_sec"] += duration_sec

            g_d = self._get_global_day(now)
            g_d["active_sec"] += duration_sec
            g_d["cost"] += cost

            g_m_cost = self._get_global_month_cost(now)
            g_m_cost["cost"] += cost