astro_diffusion / ratelimits.py
Srikasi's picture
Create ratelimits.py
e910d0b verified
raw
history blame
10.4 kB
# astro_diffusion/ratelimits.py
import time
import threading
import os
from typing import Tuple, Optional, Dict, Any
class RateLimiter:
    """
    In-memory, multi-scope request limiter.

    Scopes:
      - per-session: request count and session age
      - per-IP: requests per hour and per day, active seconds per day
      - global: requests per hour, day and month
      - time/cost: global active seconds per day, cost per day and per month

    Every limit is read once, at construction time, from an ``AD_*``
    environment variable (falling back to a built-in default). Windows are
    fixed, not sliding: each bucket carries a ``reset_at`` timestamp and is
    zeroed the first time it is touched after that timestamp. A "month" is
    approximated as 30 days.

    Intended call pattern: ``pre_check()`` before doing work and, if it
    allowed the request, ``post_consume()`` afterwards with the measured
    duration. Both take the internal lock; the ``_get_*`` / ``_prune_*``
    helpers do not and must be called with the lock held.

    All counters reset on container restart.
    """

    # Minimum spacing (seconds) between sweeps that drop expired per-IP
    # buckets. Without the sweep the per-IP dicts grow by one entry per
    # distinct client IP for the lifetime of the process.
    _PRUNE_INTERVAL_SEC = 600

    def __init__(self):
        """Read all limits from the environment and create empty buckets."""
        # ---------------- session limits ----------------
        self.per_session_max_req = int(os.getenv("AD_SESSION_MAX_REQ", "5"))
        self.per_session_max_age = int(
            os.getenv("AD_SESSION_MAX_AGE_SEC", str(15 * 60))
        )  # 15 min

        # ---------------- per-IP limits ----------------
        # defaults: 10/hr, 100/day
        self.per_ip_max_req_hour = int(os.getenv("AD_IP_MAX_REQ_HOUR", "10"))
        self.per_ip_max_req_day = int(os.getenv("AD_IP_MAX_REQ_DAY", "100"))
        self.per_ip_max_active_sec_day = int(
            os.getenv("AD_IP_MAX_ACTIVE_SEC_DAY", str(60 * 60))
        )  # 1h active time per day

        # ---------------- global limits ----------------
        # defaults: 50/hr, 100/day, 500/month
        self.global_max_req_hour = int(os.getenv("AD_GLOBAL_MAX_REQ_HOUR", "50"))
        self.global_max_req_day = int(os.getenv("AD_GLOBAL_MAX_REQ_DAY", "100"))
        self.global_max_req_month = int(os.getenv("AD_GLOBAL_MAX_REQ_MONTH", "500"))
        self.global_max_active_sec_day = int(
            os.getenv("AD_GLOBAL_MAX_ACTIVE_SEC_DAY", str(6 * 60 * 60))
        )  # 6h

        # ---------------- cost limits ----------------
        self.cost_per_sec = float(os.getenv("AD_COST_PER_SEC", "0.0005"))
        self.daily_cost_limit = float(os.getenv("AD_DAILY_COST_LIMIT", "5.0"))
        self.monthly_cost_limit = float(os.getenv("AD_MONTHLY_COST_LIMIT", "10.0"))

        self._lock = threading.Lock()

        # per-ip buckets, keyed by the ip string passed to pre_check/post_consume
        self._ip_hour: Dict[str, Dict[str, float]] = {}
        self._ip_day: Dict[str, Dict[str, float]] = {}

        # global buckets
        now = time.time()
        self._global_hour = {"count": 0, "reset_at": now + 3600}
        self._global_day = {
            "count": 0,
            "active_sec": 0.0,
            "cost": 0.0,
            "reset_at": now + 86400,
        }
        # requests per month
        self._global_month_req = {"count": 0, "reset_at": now + 30 * 86400}
        # cost per month
        self._global_month_cost = {"cost": 0.0, "reset_at": now + 30 * 86400}

        # earliest time at which the next expired-bucket sweep may run
        self._next_prune_at = now + self._PRUNE_INTERVAL_SEC

    # -------------------------------------------------
    # helpers
    # -------------------------------------------------
    @staticmethod
    def _fmt_wait(seconds: float) -> str:
        """Render a wait time as text in the largest fitting unit.

        Negative inputs are treated as zero. Values of an hour or more are
        shown as hours and values of a minute or more as minutes (both with
        one decimal); anything shorter is shown as whole seconds.
        """
        if seconds < 0:
            seconds = 0
        # choose largest sensible unit
        if seconds >= 3600:
            return f"{seconds / 3600.0:.1f} hours"
        if seconds >= 60:
            return f"{seconds / 60.0:.1f} minutes"
        return f"{seconds:.0f} seconds"

    def _reject(self, reason: str, wait_sec: float) -> Tuple[bool, str]:
        """Build the ``(False, message)`` result shared by every denial."""
        return False, f"{reason}. try again in {self._fmt_wait(wait_sec)}"

    def _prune_ip_buckets(self, now: float) -> None:
        """Drop expired per-IP buckets, at most once per prune interval.

        Removing an expired bucket does not change any limit decision: the
        bucket getters replace an expired bucket with a zeroed one anyway.
        Caller must hold ``self._lock``.
        """
        if now < self._next_prune_at:
            return
        for table in (self._ip_hour, self._ip_day):
            # collect first: a dict must not change size while iterated
            stale = [key for key, bucket in table.items() if now >= bucket["reset_at"]]
            for key in stale:
                del table[key]
        self._next_prune_at = now + self._PRUNE_INTERVAL_SEC

    def _get_ip_hour_bucket(self, ip: str) -> Dict[str, float]:
        """Return the hourly bucket for ``ip``, starting a new one if absent or expired."""
        now = time.time()
        bucket = self._ip_hour.get(ip)
        if bucket is None or now >= bucket["reset_at"]:
            bucket = {"count": 0, "reset_at": now + 3600}
            self._ip_hour[ip] = bucket
        return bucket

    def _get_ip_day_bucket(self, ip: str) -> Dict[str, float]:
        """Return the daily bucket for ``ip``, starting a new one if absent or expired."""
        now = time.time()
        bucket = self._ip_day.get(ip)
        if bucket is None or now >= bucket["reset_at"]:
            bucket = {"count": 0, "active_sec": 0.0, "reset_at": now + 86400}
            self._ip_day[ip] = bucket
        return bucket

    def _get_global_hour(self) -> Dict[str, float]:
        """Return the global hourly bucket, zeroing it in place if its window ended."""
        now = time.time()
        g = self._global_hour
        if now >= g["reset_at"]:
            g["count"] = 0
            g["reset_at"] = now + 3600
        return g

    def _get_global_day(self) -> Dict[str, float]:
        """Return the global daily bucket, zeroing it in place if its window ended."""
        now = time.time()
        g = self._global_day
        if now >= g["reset_at"]:
            g["count"] = 0
            g["active_sec"] = 0.0
            g["cost"] = 0.0
            g["reset_at"] = now + 86400
        return g

    def _get_global_month_req(self) -> Dict[str, float]:
        """Return the global 30-day request bucket, zeroing it if its window ended."""
        now = time.time()
        g = self._global_month_req
        if now >= g["reset_at"]:
            g["count"] = 0
            g["reset_at"] = now + 30 * 86400
        return g

    def _get_global_month_cost(self) -> Dict[str, float]:
        """Return the global 30-day cost bucket, zeroing it if its window ended."""
        now = time.time()
        g = self._global_month_cost
        if now >= g["reset_at"]:
            g["cost"] = 0.0
            g["reset_at"] = now + 30 * 86400
        return g

    # -------------------------------------------------
    # public API
    # -------------------------------------------------
    def pre_check(
        self,
        ip: str,
        session_state: Dict[str, Any],
    ) -> Tuple[bool, Optional[str]]:
        """
        Check count/time/cost limits BEFORE doing work.

        session_state is expected to have:
            {
                "count": int,
                "started_at": float
            }
        Missing keys default to a count of 0 and a start time of "now".
        The dict is mutated in place: on success "count" is incremented and
        "started_at" is filled in if absent; when the session window has
        expired both are reset so that a retry starts a fresh window.

        Checks run in this order and the first failing one wins: session
        age, session request count, per-IP hour/day/active-time, global
        hour/day/active-time/daily-cost, global month requests, global
        month cost.

        Returns:
            (True, None) if the request is allowed; the request-based
            counters have then already been incremented.
            (False, message) if a limit was hit; message names the limit
            and how long to wait. No request counter is incremented.
        """
        now = time.time()
        with self._lock:
            self._prune_ip_buckets(now)

            # ---- session checks ----
            sess_count = int(session_state.get("count", 0))
            sess_started = float(session_state.get("started_at", now))
            session_age = now - sess_started

            if session_age > self.per_session_max_age:
                # Session window is over. Start a new window so the
                # "try again in 0 seconds" below is actually true; without
                # the reset every later call would fail here again.
                session_state["count"] = 0
                session_state["started_at"] = now
                return self._reject(
                    f"session time cap reached ({self.per_session_max_age} sec)",
                    0,
                )

            if sess_count >= self.per_session_max_req:
                # they must wait until session-age window ends
                return self._reject(
                    f"session request cap {self.per_session_max_req} reached",
                    self.per_session_max_age - session_age,
                )

            # ---- per-IP checks ----
            ip_h = self._get_ip_hour_bucket(ip)
            if ip_h["count"] >= self.per_ip_max_req_hour:
                return self._reject(
                    f"ip hourly cap {self.per_ip_max_req_hour} reached",
                    ip_h["reset_at"] - now,
                )

            ip_d = self._get_ip_day_bucket(ip)
            if ip_d["count"] >= self.per_ip_max_req_day:
                return self._reject(
                    f"ip daily cap {self.per_ip_max_req_day} reached",
                    ip_d["reset_at"] - now,
                )
            if ip_d["active_sec"] >= self.per_ip_max_active_sec_day:
                return self._reject(
                    f"ip daily active time cap {self.per_ip_max_active_sec_day} sec reached",
                    ip_d["reset_at"] - now,
                )

            # ---- global checks ----
            g_h = self._get_global_hour()
            if g_h["count"] >= self.global_max_req_hour:
                return self._reject(
                    f"global hourly cap {self.global_max_req_hour} reached",
                    g_h["reset_at"] - now,
                )

            g_d = self._get_global_day()
            if g_d["count"] >= self.global_max_req_day:
                return self._reject(
                    f"global daily cap {self.global_max_req_day} reached",
                    g_d["reset_at"] - now,
                )
            if g_d["active_sec"] >= self.global_max_active_sec_day:
                return self._reject(
                    f"global daily active time cap {self.global_max_active_sec_day} sec reached",
                    g_d["reset_at"] - now,
                )
            if g_d["cost"] >= self.daily_cost_limit:
                return self._reject(
                    f"global daily cost cap {self.daily_cost_limit} reached",
                    g_d["reset_at"] - now,
                )

            g_m_req = self._get_global_month_req()
            if g_m_req["count"] >= self.global_max_req_month:
                return self._reject(
                    f"global monthly cap {self.global_max_req_month} reached",
                    g_m_req["reset_at"] - now,
                )

            g_m_cost = self._get_global_month_cost()
            if g_m_cost["cost"] >= self.monthly_cost_limit:
                return self._reject(
                    f"global monthly cost cap {self.monthly_cost_limit} reached",
                    g_m_cost["reset_at"] - now,
                )

            # ---- all clear -> increment counters that are request-based ----
            for bucket in (ip_h, ip_d, g_h, g_d, g_m_req):
                bucket["count"] += 1

            # session count increment
            session_state["count"] = sess_count + 1
            # ensure started_at exists
            session_state.setdefault("started_at", now)
            return True, None

    def post_consume(
        self,
        ip: str,
        duration_sec: float,
    ) -> None:
        """
        Update time-based and cost-based buckets AFTER doing work.

        Adds ``duration_sec`` to the per-IP and global daily active time,
        and ``duration_sec * cost_per_sec`` to the global daily and monthly
        cost. Negative (or NaN) durations are counted as zero so that a bad
        measurement can never lower or corrupt the accumulated usage.
        """
        # max() keeps 0.0 when duration_sec is negative or NaN
        duration_sec = max(0.0, duration_sec)
        cost = duration_sec * self.cost_per_sec
        with self._lock:
            ip_d = self._get_ip_day_bucket(ip)
            ip_d["active_sec"] += duration_sec

            g_d = self._get_global_day()
            g_d["active_sec"] += duration_sec
            g_d["cost"] += cost

            g_m_cost = self._get_global_month_cost()
            g_m_cost["cost"] += cost