Spaces:
Runtime error
Runtime error
| # astro_diffusion/ratelimits.py | |
| import time | |
| import threading | |
| import os | |
| from typing import Tuple, Optional, Dict, Any | |
class RateLimiter:
    """In-memory, multi-scope request limiter.

    Scopes enforced by :meth:`pre_check`:

    - per-session: request count and session age
    - per-IP: requests per hour, requests per day, active seconds per day
    - global: requests per hour/day/month, active seconds per day,
      cost per day/month

    Every window is a fixed window that starts when its bucket is created
    (or last reset); a "month" is a fixed 30-day window.  All state lives in
    process memory, so every counter is lost when the process restarts.

    Limits are read once, at construction time, from ``AD_*`` environment
    variables, falling back to the defaults written in ``__init__``.
    """

    # Minimum number of seconds between two sweeps of expired per-IP buckets.
    _PRUNE_INTERVAL_SEC = 3600

    def __init__(self):
        """Read limits from the environment and create empty buckets."""
        # ---------------- session limits ----------------
        self.per_session_max_req = int(os.getenv("AD_SESSION_MAX_REQ", "5"))
        self.per_session_max_age = int(
            os.getenv("AD_SESSION_MAX_AGE_SEC", str(15 * 60))
        )  # 15 min

        # ---------------- per-IP limits ----------------
        # defaults: 10/hr, 100/day
        self.per_ip_max_req_hour = int(os.getenv("AD_IP_MAX_REQ_HOUR", "10"))
        self.per_ip_max_req_day = int(os.getenv("AD_IP_MAX_REQ_DAY", "100"))
        self.per_ip_max_active_sec_day = int(
            os.getenv("AD_IP_MAX_ACTIVE_SEC_DAY", str(60 * 60))
        )  # 1h active time per day

        # ---------------- global limits ----------------
        # defaults: 50/hr, 100/day, 500/month
        self.global_max_req_hour = int(os.getenv("AD_GLOBAL_MAX_REQ_HOUR", "50"))
        self.global_max_req_day = int(os.getenv("AD_GLOBAL_MAX_REQ_DAY", "100"))
        self.global_max_req_month = int(os.getenv("AD_GLOBAL_MAX_REQ_MONTH", "500"))
        self.global_max_active_sec_day = int(
            os.getenv("AD_GLOBAL_MAX_ACTIVE_SEC_DAY", str(6 * 60 * 60))
        )  # 6h

        # ---------------- cost limits ----------------
        self.cost_per_sec = float(os.getenv("AD_COST_PER_SEC", "0.0005"))
        self.daily_cost_limit = float(os.getenv("AD_DAILY_COST_LIMIT", "5.0"))
        self.monthly_cost_limit = float(os.getenv("AD_MONTHLY_COST_LIMIT", "10.0"))

        # One lock guards every bucket below.
        self._lock = threading.Lock()

        # per-ip buckets, keyed by IP string
        self._ip_hour: Dict[str, Dict[str, float]] = {}
        self._ip_day: Dict[str, Dict[str, float]] = {}

        # global buckets
        now = time.time()
        self._global_hour = {"count": 0, "reset_at": now + 3600}
        self._global_day = {
            "count": 0,
            "active_sec": 0.0,
            "cost": 0.0,
            "reset_at": now + 86400,
        }
        # requests per month
        self._global_month_req = {"count": 0, "reset_at": now + 30 * 86400}
        # cost per month
        self._global_month_cost = {"cost": 0.0, "reset_at": now + 30 * 86400}

        # Earliest time at which expired per-IP buckets are swept again.
        self._next_prune_at = now + self._PRUNE_INTERVAL_SEC

    # -------------------------------------------------
    # helpers
    # -------------------------------------------------
    @staticmethod
    def _fmt_wait(seconds: float) -> str:
        """Format a wait time using the largest sensible unit.

        Negative values are treated as zero.  This must be a static method:
        it takes no ``self`` but is always invoked as ``self._fmt_wait(...)``.
        """
        if seconds < 0:
            seconds = 0
        if seconds >= 3600:
            return f"{seconds / 3600.0:.1f} hours"
        if seconds >= 60:
            return f"{seconds / 60.0:.1f} minutes"
        return f"{seconds:.0f} seconds"

    def _deny(self, reason: str, remaining: float) -> Tuple[bool, Optional[str]]:
        """Build the ``(False, message)`` result for a rejected request."""
        return False, f"{reason}. try again in {self._fmt_wait(remaining)}"

    def _prune_expired_ip_buckets(self, now: float) -> None:
        """Drop expired per-IP buckets so the dicts cannot grow without bound.

        An expired bucket would be replaced by a fresh one on its next access
        anyway, so removing it early does not change any limiter decision.
        The sweep runs at most once per ``_PRUNE_INTERVAL_SEC``.  The caller
        must hold ``self._lock``.
        """
        if now < self._next_prune_at:
            return
        for buckets in (self._ip_hour, self._ip_day):
            expired = [ip for ip, b in buckets.items() if now >= b["reset_at"]]
            for ip in expired:
                del buckets[ip]
        self._next_prune_at = now + self._PRUNE_INTERVAL_SEC

    def _get_ip_hour_bucket(self, ip: str):
        """Return the hourly bucket for ``ip``, creating or renewing it."""
        now = time.time()
        b = self._ip_hour.get(ip)
        if b is None or now >= b["reset_at"]:
            b = {"count": 0, "reset_at": now + 3600}
            self._ip_hour[ip] = b
        return b

    def _get_ip_day_bucket(self, ip: str):
        """Return the daily bucket for ``ip``, creating or renewing it."""
        now = time.time()
        b = self._ip_day.get(ip)
        if b is None or now >= b["reset_at"]:
            b = {"count": 0, "active_sec": 0.0, "reset_at": now + 86400}
            self._ip_day[ip] = b
        return b

    def _get_global_hour(self):
        """Return the global hourly bucket, resetting it if its window ended."""
        now = time.time()
        g = self._global_hour
        if now >= g["reset_at"]:
            g["count"] = 0
            g["reset_at"] = now + 3600
        return g

    def _get_global_day(self):
        """Return the global daily bucket, resetting it if its window ended."""
        now = time.time()
        g = self._global_day
        if now >= g["reset_at"]:
            g["count"] = 0
            g["active_sec"] = 0.0
            g["cost"] = 0.0
            g["reset_at"] = now + 86400
        return g

    def _get_global_month_req(self):
        """Return the global 30-day request bucket, resetting it if expired."""
        now = time.time()
        g = self._global_month_req
        if now >= g["reset_at"]:
            g["count"] = 0
            g["reset_at"] = now + 30 * 86400
        return g

    def _get_global_month_cost(self):
        """Return the global 30-day cost bucket, resetting it if expired."""
        now = time.time()
        g = self._global_month_cost
        if now >= g["reset_at"]:
            g["cost"] = 0.0
            g["reset_at"] = now + 30 * 86400
        return g

    # -------------------------------------------------
    # public API
    # -------------------------------------------------
    def pre_check(
        self,
        ip: str,
        session_state: Dict[str, Any],
    ) -> Tuple[bool, Optional[str]]:
        """Check count/time/cost limits BEFORE doing work.

        ``session_state`` is expected to look like
        ``{"count": int, "started_at": float}``; missing keys default to
        ``0`` and the current time respectively.

        Returns ``(True, None)`` when the request is allowed.  In that case
        the request counters of every scope are incremented and
        ``session_state`` is updated in place (``count`` is incremented and
        ``started_at`` is set if it was absent).  Returns ``(False, message)``
        when a limit is hit; no counter is incremented in that case.
        """
        now = time.time()
        with self._lock:
            self._prune_expired_ip_buckets(now)

            # ---- session checks ----
            sess_count = int(session_state.get("count", 0))
            sess_started = float(session_state.get("started_at", now))
            session_age = now - sess_started
            if session_age > self.per_session_max_age:
                # session too old
                return self._deny(
                    f"session time cap reached ({self.per_session_max_age} sec)",
                    0,
                )
            if sess_count >= self.per_session_max_req:
                # they must wait until the session-age window ends
                return self._deny(
                    f"session request cap {self.per_session_max_req} reached",
                    self.per_session_max_age - session_age,
                )

            # ---- per-IP checks ----
            ip_h = self._get_ip_hour_bucket(ip)
            if ip_h["count"] >= self.per_ip_max_req_hour:
                return self._deny(
                    f"ip hourly cap {self.per_ip_max_req_hour} reached",
                    ip_h["reset_at"] - now,
                )
            ip_d = self._get_ip_day_bucket(ip)
            if ip_d["count"] >= self.per_ip_max_req_day:
                return self._deny(
                    f"ip daily cap {self.per_ip_max_req_day} reached",
                    ip_d["reset_at"] - now,
                )
            if ip_d["active_sec"] >= self.per_ip_max_active_sec_day:
                return self._deny(
                    f"ip daily active time cap {self.per_ip_max_active_sec_day} sec reached",
                    ip_d["reset_at"] - now,
                )

            # ---- global checks ----
            g_h = self._get_global_hour()
            if g_h["count"] >= self.global_max_req_hour:
                return self._deny(
                    f"global hourly cap {self.global_max_req_hour} reached",
                    g_h["reset_at"] - now,
                )
            g_d = self._get_global_day()
            if g_d["count"] >= self.global_max_req_day:
                return self._deny(
                    f"global daily cap {self.global_max_req_day} reached",
                    g_d["reset_at"] - now,
                )
            if g_d["active_sec"] >= self.global_max_active_sec_day:
                return self._deny(
                    f"global daily active time cap {self.global_max_active_sec_day} sec reached",
                    g_d["reset_at"] - now,
                )
            if g_d["cost"] >= self.daily_cost_limit:
                return self._deny(
                    f"global daily cost cap {self.daily_cost_limit} reached",
                    g_d["reset_at"] - now,
                )
            g_m_req = self._get_global_month_req()
            if g_m_req["count"] >= self.global_max_req_month:
                return self._deny(
                    f"global monthly cap {self.global_max_req_month} reached",
                    g_m_req["reset_at"] - now,
                )
            g_m_cost = self._get_global_month_cost()
            if g_m_cost["cost"] >= self.monthly_cost_limit:
                return self._deny(
                    f"global monthly cost cap {self.monthly_cost_limit} reached",
                    g_m_cost["reset_at"] - now,
                )

            # ---- all clear -> increment counters that are request-based ----
            for bucket in (ip_h, ip_d, g_h, g_d, g_m_req):
                bucket["count"] += 1
            session_state["count"] = sess_count + 1
            # ensure started_at exists
            session_state.setdefault("started_at", now)
            return True, None

    def post_consume(
        self,
        ip: str,
        duration_sec: float,
    ) -> None:
        """Update time-based and cost-based buckets AFTER doing work.

        ``duration_sec`` is added to the IP's and the global daily active
        time; ``duration_sec * cost_per_sec`` is added to the global daily
        and monthly cost.  A negative duration is treated as zero so that it
        can never lower previously accumulated usage.
        """
        duration_sec = max(0.0, duration_sec)
        cost = duration_sec * self.cost_per_sec
        with self._lock:
            ip_d = self._get_ip_day_bucket(ip)
            ip_d["active_sec"] += duration_sec
            g_d = self._get_global_day()
            g_d["active_sec"] += duration_sec
            g_d["cost"] += cost
            g_m_cost = self._get_global_month_cost()
            g_m_cost["cost"] += cost