"""Gradio front-end comparing base Stable Diffusion output with a custom LoRA.

Reads sampling defaults from ``configs/infer.yaml``, rate-limits callers per
session and per IP, POSTs the request payload to an inference backend over
HTTP, and displays the base-model and LoRA-model images side by side.
"""

import base64
import io
import os
import time

import gradio as gr
import requests
import yaml
from PIL import Image

from ratelimits import RateLimiter

# HTTP endpoint of the inference backend; must be supplied via environment.
# NOTE(review): no fallback — if unset, requests.post receives None. Confirm
# deployment always sets BACKEND_URL.
BACKEND_URL = os.getenv("BACKEND_URL")

# Load sampling defaults with a context manager so the config file handle is
# closed promptly (the previous `yaml.safe_load(open(...))` leaked it).
with open("configs/infer.yaml") as _cfg_file:
    cfg = yaml.safe_load(_cfg_file)

limiter = RateLimiter()


def b64_to_img(s: str) -> Image.Image:
    """Decode a base64-encoded image payload into an RGB PIL image."""
    data = base64.b64decode(s)
    return Image.open(io.BytesIO(data)).convert("RGB")


def _client_ip(request) -> str:
    """Best-effort extraction of the caller's IP from a ``gr.Request``.

    Prefers the first entry of X-Forwarded-For (set by reverse proxies),
    falls back to the direct client address, then to ``"unknown"``.
    """
    if request is not None and request.headers:
        hdrs = {k.lower(): v for k, v in request.headers.items()}
        xff = hdrs.get("x-forwarded-for")
        if xff:
            return xff.split(",")[0].strip()
        if request.client:
            return request.client.host
    return "unknown"


def _infer(p, st, sc, h, w, sd, et, sess_state, request: gr.Request):
    """Gradio click handler: run one base-vs-LoRA comparison.

    Args:
        p: text prompt.
        st: number of inference steps.
        sc: guidance scale.
        h, w: output image height and width in pixels.
        sd: random seed (sent to the backend as a string).
        et: DDIM eta.
        sess_state: per-session dict holding ``"count"`` and ``"started_at"``.
        request: Gradio request object, used for per-IP rate limiting.

    Returns:
        Tuple of (base_image, lora_image, status_message, sess_state).
    """
    ip = _client_ip(request)

    # Reset the per-session window once it is older than the limiter allows.
    now = time.time()
    sess_state.setdefault("started_at", now)
    sess_state.setdefault("count", 0)
    if now - sess_state["started_at"] > limiter.per_session_max_age:
        sess_state["started_at"] = now
        sess_state["count"] = 0

    def _fallback(msg):
        # Placeholder output so the UI always receives valid images.
        blank = Image.new("RGB", (int(w), int(h)), (30, 30, 30))
        return blank, blank, msg, sess_state

    ok, reason = limiter.pre_check(ip, sess_state)
    if not ok:
        return _fallback(f"Rate limited: {reason}")

    payload = {
        "prompt": p,
        "steps": int(st),
        "scale": float(sc),
        "height": int(h),
        "width": int(w),
        "seed": str(sd),
        "eta": float(et),
    }

    # Surface backend failures as a status message instead of letting the
    # exception crash the callback (previously raise_for_status / connection
    # errors propagated to Gradio as an unhandled traceback).
    try:
        resp = requests.post(BACKEND_URL, json=payload, timeout=120)
        resp.raise_for_status()
        out = resp.json()
    except requests.RequestException as exc:
        return _fallback(f"Backend error: {exc}")

    base_img = b64_to_img(out["base_image"])
    lora_img = b64_to_img(out["lora_image"])
    limiter.post_consume(ip, out.get("duration", 0.0))
    return base_img, lora_img, out.get("status", "ok"), sess_state


with gr.Blocks(title="Astro-Diffusion: Base vs LoRA") as demo:
    session_state = gr.State({"count": 0, "started_at": time.time()})
    # HTML banner kept byte-for-byte (left unindented so the rendered string
    # content is unchanged).
    gr.HTML(
        """

Astro-Diffusion : Base SD vs custom LoRA

Video generation and more features coming up..!

by Srivatsava Kasibhatla
"""
    )
    with gr.Group(elem_classes=["prompt-panel"]):
        prompt = gr.Textbox(
            value="a high-resolution spiral galaxy with blue star-forming arms and a bright yellow core",
            label="Prompt",
        )
        with gr.Row():
            steps = gr.Slider(10, 60, value=cfg.get("num_inference_steps", 30), step=1, label="Steps")
            scale = gr.Slider(1.0, 12.0, value=cfg.get("guidance_scale", 7.5), step=0.5, label="Guidance")
            # Height/width clamped to 512 regardless of the configured default.
            height = gr.Number(value=min(int(cfg.get("height", 512)), 512), label="Height", minimum=32, maximum=512)
            width = gr.Number(value=min(int(cfg.get("width", 512)), 512), label="Width", minimum=32, maximum=512)
            seed = gr.Textbox(value=str(cfg.get("seed", 1234)), label="Seed")
            eta = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Eta")
        btn = gr.Button("Generate")
    out_base = gr.Image(label="Base Model Output")
    out_lora = gr.Image(label="LoRA Model Output")
    status = gr.Textbox(label="Status", interactive=False)

    btn.click(
        _infer,
        [prompt, steps, scale, height, width, seed, eta, session_state],
        [out_base, out_lora, status, session_state],
    )


if __name__ == "__main__":
    # Guarded so importing this module (e.g. by tooling) does not start the
    # server; previously launch() ran unconditionally at import time.
    port = int(os.getenv("GRADIO_SERVER_PORT", "7861"))
    # Accept common truthy spellings; previously only the exact string "True"
    # enabled sharing ("True" / "False" still map as before).
    share = os.getenv("GRADIO_PUBLIC_SHARE", "False").strip().lower() in {"1", "true", "yes"}

    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        show_error=True,
        share=share,
    )