Spaces:
Runtime error
Runtime error
| import os, io, base64, time, yaml, requests | |
| from PIL import Image | |
| import gradio as gr | |
# Where inference requests are POSTed; override via the BACKEND_URL env var.
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:7861")

# Default widget values for the UI. The file is opened with a context manager
# so the handle is closed, and ``yaml.safe_load`` returns None for an empty
# file, so fall back to an empty dict to keep ``cfg.get(...)`` calls working.
with open("configs/infer.yaml", encoding="utf-8") as _cfg_file:
    cfg = yaml.safe_load(_cfg_file) or {}
def b64_to_img(s: str):
    """Decode a base64 string into a PIL image converted to RGB mode.

    Args:
        s: Base64-encoded image bytes (any format PIL can open).

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    raw_bytes = base64.b64decode(s)
    stream = io.BytesIO(raw_bytes)
    decoded = Image.open(stream)
    return decoded.convert("RGB")
def _placeholder(w, h, color):
    """Return a solid ``color`` RGB image of size (w, h) for the output slots.

    Falls back to 512x512 when the requested size is missing, non-numeric or
    non-positive, so building the placeholder never raises on bad user input
    (which would otherwise escape the caller's error handling).
    """
    try:
        size = (int(w), int(h))
    except (TypeError, ValueError, OverflowError):
        size = (512, 512)
    if size[0] <= 0 or size[1] <= 0:
        size = (512, 512)
    return Image.new("RGB", size, color)


def _rate_limit_message(resp):
    """Return the error text for a backend HTTP 429 response.

    Uses the ``"error"`` field when the body is a JSON object. Falls back to a
    generic message when the body is not JSON or not an object.
    """
    default = "rate limited by backend"
    try:
        body = resp.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return default
    return body.get("error", default) if isinstance(body, dict) else default


def _infer(p, st, sc, h, w, sd, et):
    """Ask the backend to render prompt ``p`` with the base and LoRA models.

    Args:
        p: Text prompt.
        st: Number of inference steps (coerced with ``int``).
        sc: Guidance scale (coerced with ``float``).
        h: Output height in pixels (coerced with ``int``).
        w: Output width in pixels (coerced with ``int``).
        sd: Seed, sent to the backend as a string.
        et: Eta (coerced with ``float``).

    Returns:
        ``(base_image, lora_image, status)``. On success, the images are
        decoded from the backend's ``base_image`` / ``lora_image`` base64 fields.
        On invalid input, backend rate limiting (HTTP 429) or any request or
        decoding failure, both images are solid-colour placeholders and
        ``status`` describes the problem.
    """
    try:
        payload = {
            "prompt": p,
            "steps": int(st),
            "scale": float(sc),
            "height": int(h),
            "width": int(w),
            "seed": str(sd),
            "eta": float(et),
        }
    except (TypeError, ValueError, OverflowError) as e:
        # An empty or non-numeric field value: report it in the status box
        # instead of letting the exception escape the click handler.
        blank = _placeholder(w, h, (120, 50, 50))
        return blank, blank, f"Invalid input: {e}"

    try:
        r = requests.post(BACKEND_URL, json=payload, timeout=120)
        if r.status_code == 429:
            # backend rate limit
            blank = _placeholder(w, h, (30, 30, 30))
            return blank, blank, _rate_limit_message(r)
        r.raise_for_status()
        out = r.json()
        base_img = b64_to_img(out["base_image"])
        lora_img = b64_to_img(out["lora_image"])
        return base_img, lora_img, out.get("status", "ok")
    except Exception as e:  # UI boundary: surface any failure in the status box
        blank = _placeholder(w, h, (120, 50, 50))
        return blank, blank, f"Backend error: {e}"
# Banner rendered at the top of the page: inline CSS (header gradient, badge,
# prompt panel, forced-black widget labels) followed by the title markup.
_HEADER_HTML = """
<style>
.astro-header {
background: linear-gradient(90deg, #0f172a 0%, #1d4ed8 50%, #0ea5e9 100%);
padding: 0.9rem 1rem 0.85rem 1rem;
border-radius: 0.6rem;
margin-bottom: 0.9rem;
display: flex;
justify-content: space-between;
align-items: center;
gap: 1rem;
}
.astro-title {
color: #ffffff !important;
margin: 0;
font-weight: 700;
letter-spacing: 0.01em;
}
.astro-sub {
color: #ffffff !important;
margin: 0.3rem 0 0 0;
font-style: italic;
font-size: 0.8rem;
}
.astro-badge {
background: #facc15;
color: #0f172a;
padding: 0.4rem 1.05rem;
border-radius: 9999px;
font-weight: 700;
white-space: nowrap;
font-size: 0.95rem;
}
.prompt-panel {
background: #e8fff4;
padding: 0.5rem 0.5rem 0.2rem 0.5rem;
border-radius: 0.5rem;
margin-bottom: 0.5rem;
}
.gradio-container label,
label,
.gradio-container [class*="label"],
.gradio-container [class^="svelte-"][class*="label"],
.gradio-container .block p > label {
color: #000000 !important;
font-weight: 600;
}
.gradio-container [data-testid="block-label"],
.gradio-container [data-testid="block-label"] * {
color: #000000 !important;
font-weight: 600;
}
</style>
<div class="astro-header">
<div>
<h2 class="astro-title">Astro-Diffusion : Base SD vs custom LoRA</h2>
<p class="astro-sub">Video generation and more features coming up..!</p>
</div>
<div class="astro-badge">by Srivatsava Kasibhatla</div>
</div>
"""


def _clamp(value, lo, hi):
    """Return ``value`` limited to the closed interval [lo, hi]."""
    return max(lo, min(value, hi))


# ``cfg`` is loaded from configs/infer.yaml. ``yaml.safe_load`` yields None for
# an empty file, so fall back to the built-in defaults used below.
_ui_cfg = cfg or {}

with gr.Blocks(title="Astro-Diffusion: Base vs LoRA") as demo:
    gr.HTML(_HEADER_HTML)
    with gr.Group(elem_classes=["prompt-panel"]):
        prompt = gr.Textbox(
            value="a high-resolution spiral galaxy with blue star-forming arms and a bright yellow core",
            label="Prompt",
        )
    with gr.Row():
        # Config-supplied defaults are clamped into each widget's own range so
        # an out-of-range YAML value cannot produce an invalid initial value
        # (previously height/width were only capped at 512, never raised to 32).
        steps = gr.Slider(
            10,
            60,
            value=_clamp(int(_ui_cfg.get("num_inference_steps", 30)), 10, 60),
            step=1,
            label="Steps",
        )
        scale = gr.Slider(
            1.0,
            12.0,
            value=_clamp(float(_ui_cfg.get("guidance_scale", 7.5)), 1.0, 12.0),
            step=0.5,
            label="Guidance",
        )
        height = gr.Number(
            value=_clamp(int(_ui_cfg.get("height", 512)), 32, 512),
            label="Height",
            minimum=32,
            maximum=512,
        )
        width = gr.Number(
            value=_clamp(int(_ui_cfg.get("width", 512)), 32, 512),
            label="Width",
            minimum=32,
            maximum=512,
        )
        seed = gr.Textbox(value=str(_ui_cfg.get("seed", 1234)), label="Seed")
        eta = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Eta")
    btn = gr.Button("Generate")
    out_base = gr.Image(label="Base Model Output")
    out_lora = gr.Image(label="LoRA Model Output")
    status = gr.Textbox(label="Status", interactive=False)
    # Input order must match _infer(p, st, sc, h, w, sd, et); outputs match
    # its (base_image, lora_image, status) return tuple.
    btn.click(
        _infer,
        [prompt, steps, scale, height, width, seed, eta],
        [out_base, out_lora, status],
    )
if __name__ == "__main__":
    # The UI is the module-level ``demo`` built by the ``gr.Blocks`` context in
    # this file. This module defines no ``build_ui`` factory (calling one raises
    # NameError at startup), so launch ``demo`` directly.
    interface = demo
    # PORT takes precedence over GRADIO_SERVER_PORT.
    # NOTE(review): the fallback UI port (7861) equals the port of the default
    # BACKEND_URL (http://localhost:7861); set BACKEND_URL or PORT explicitly
    # so the UI does not POST inference requests to itself.
    port = int(os.getenv("PORT", os.getenv("GRADIO_SERVER_PORT", "7861")))
    # Public share links are enabled unless GRADIO_PUBLIC_SHARE is set to a
    # value other than "true" (case-insensitive).
    share = os.getenv("GRADIO_PUBLIC_SHARE", "True").lower() == "true"
    interface.launch(server_name="0.0.0.0", server_port=port, share=share)