"""Gradio front-end for Astro-Diffusion: compares base Stable Diffusion output
against a custom LoRA fine-tune by sending generation requests to a separate
inference backend.

Reads defaults from configs/infer.yaml and the environment variables
BACKEND_URL, PORT / GRADIO_SERVER_PORT, and GRADIO_PUBLIC_SHARE.
"""

import base64
import io
import os

import gradio as gr
import requests
import yaml
from PIL import Image

# Inference service that renders both the base and LoRA images for a prompt.
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:7861")

# Default values for the generation controls.
with open("configs/infer.yaml") as f:
    cfg = yaml.safe_load(f)
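# Illustrative shape of configs/infer.yaml (an assumption based on the keys this
# UI reads below; the real file may carry additional fields, which are ignored
# here, and the values shown are examples only):
#
#   num_inference_steps: 30
#   guidance_scale: 7.5
#   height: 512
#   width: 512
#   seed: 1234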

def b64_to_img(s: str) -> Image.Image:
    """Decode a base64-encoded image payload into an RGB PIL image."""
    data = base64.b64decode(s)
    return Image.open(io.BytesIO(data)).convert("RGB")


def _infer(p, st, sc, h, w, sd, et):
    """POST one generation request to the backend and return (base, lora, status)."""
    payload = {
        "prompt": p,
        "steps": int(st),
        "scale": float(sc),
        "height": int(h),
        "width": int(w),
        "seed": str(sd),
        "eta": float(et),
    }
    try:
        r = requests.post(BACKEND_URL, json=payload, timeout=120)
        if r.status_code == 429:  # backend rate limit
            blank = Image.new("RGB", (int(w), int(h)), (30, 30, 30))
            msg = r.json().get("error", "rate limited by backend")
            return blank, blank, msg
        r.raise_for_status()
        out = r.json()
        base_img = b64_to_img(out["base_image"])
        lora_img = b64_to_img(out["lora_image"])
        return base_img, lora_img, out.get("status", "ok")
    except Exception as e:
        # Network, HTTP, or decode failures degrade to placeholder images plus a message.
        blank = Image.new("RGB", (int(w), int(h)), (120, 50, 50))
        return blank, blank, f"Backend error: {e}"
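# Request/response contract assumed by _infer (inferred from the payload it sends
# and the fields it parses; the authoritative schema is defined by the backend):
#
#   request:  {"prompt", "steps", "scale", "height", "width", "seed", "eta"}
#   response: {"base_image": <base64 image>, "lora_image": <base64 image>,
#              "status": <optional str>}, or {"error": <str>} on HTTP 429.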

def build_ui() -> gr.Blocks:
    """Assemble the Gradio app that shows base and LoRA outputs for one prompt."""
    with gr.Blocks(title="Astro-Diffusion: Base vs LoRA") as demo:
        # Page header.
        gr.HTML(
            """
            <div style="text-align: center;">
                <h1>Astro-Diffusion: Base SD vs custom LoRA</h1>
                <p>Video generation and more features coming soon!</p>
                <p>by Srivatsava Kasibhatla</p>
            </div>
            """
        )
        with gr.Group(elem_classes=["prompt-panel"]):
            prompt = gr.Textbox(
                value="a high-resolution spiral galaxy with blue star-forming arms and a bright yellow core",
                label="Prompt",
            )
            with gr.Row():
                steps = gr.Slider(10, 60, value=cfg.get("num_inference_steps", 30), step=1, label="Steps")
                scale = gr.Slider(1.0, 12.0, value=cfg.get("guidance_scale", 7.5), step=0.5, label="Guidance")
                height = gr.Number(value=min(int(cfg.get("height", 512)), 512), label="Height", minimum=32, maximum=512)
                width = gr.Number(value=min(int(cfg.get("width", 512)), 512), label="Width", minimum=32, maximum=512)
                seed = gr.Textbox(value=str(cfg.get("seed", 1234)), label="Seed")
                eta = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Eta")
        btn = gr.Button("Generate")
        out_base = gr.Image(label="Base Model Output")
        out_lora = gr.Image(label="LoRA Model Output")
        status = gr.Textbox(label="Status", interactive=False)
        btn.click(
            _infer,
            [prompt, steps, scale, height, width, seed, eta],
            [out_base, out_lora, status],
        )
    return demo
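# The click handler degrades gracefully without a running backend: when the POST
# fails, _infer returns two placeholder images plus an error string, e.g.
#
#   >>> _, _, status = _infer("test prompt", 30, 7.5, 512, 512, "1234", 0.0)
#   >>> status.startswith("Backend error:")  # assuming nothing listens on BACKEND_URL
#   True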

if __name__ == "__main__":
    interface = build_ui()
    # Note: the default UI port matches BACKEND_URL's port; set PORT (or
    # GRADIO_SERVER_PORT) when the backend also runs on 7861 locally.
    port = int(os.getenv("PORT", os.getenv("GRADIO_SERVER_PORT", "7861")))
    share = os.getenv("GRADIO_PUBLIC_SHARE", "True").lower() == "true"
    interface.launch(server_name="0.0.0.0", server_port=port, share=share)