astro_diffusion / app.py
Srikasi's picture
Update app.py
3092e54 verified
raw
history blame
4.72 kB
import os, io, base64, time, yaml, requests
from PIL import Image
import gradio as gr
# URL that generation requests are POSTed to; override with the BACKEND_URL
# environment variable.
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:7861")

# Inference defaults read once at import time; the UI controls take their
# initial values from this mapping via cfg.get(...).
# A context manager closes the file handle deterministically (the original
# bare open() left it to the garbage collector).
with open("configs/infer.yaml", encoding="utf-8") as _cfg_file:
    # yaml.safe_load returns None for an empty document; fall back to an empty
    # dict so the cfg.get(...) lookups keep working and use their defaults.
    cfg = yaml.safe_load(_cfg_file) or {}
def b64_to_img(s: str):
    """Decode a base64 string holding an encoded image and return it in RGB mode."""
    raw_bytes = base64.b64decode(s)
    # Wrap the bytes in an in-memory file object so PIL can read them.
    buffer = io.BytesIO(raw_bytes)
    decoded = Image.open(buffer)
    return decoded.convert("RGB")
def _blank_image(w, h, color):
    """Return a solid-color RGB placeholder image of size (w, h).

    Falls back to 512x512 when the requested size cannot be converted to
    positive integers (e.g. a cleared number field arrives as None), so that
    building the placeholder can never raise from inside an error path.
    """
    try:
        size = (int(w), int(h))
    except (TypeError, ValueError, OverflowError):
        size = (512, 512)
    if size[0] < 1 or size[1] < 1:
        size = (512, 512)
    return Image.new("RGB", size, color)


def _infer(p, st, sc, h, w, sd, et):
    """Send one generation request to the backend and return its two images.

    Args:
        p: Prompt text.
        st: Number of steps (coerced to int).
        sc: Guidance scale (coerced to float).
        h: Image height (coerced to int).
        w: Image width (coerced to int).
        sd: Seed; sent to the backend as a string.
        et: Eta value (coerced to float).

    Returns:
        A ``(base_image, lora_image, status)`` tuple. On any failure both
        images are the same solid-color placeholder and ``status`` carries the
        error text, so the UI always receives three values instead of an
        exception.
    """
    # Coerce inputs inside a guarded section: UI fields can deliver None or
    # non-numeric text, which previously raised before the try block started.
    try:
        payload = {
            "prompt": p,
            "steps": int(st),
            "scale": float(sc),
            "height": int(h),
            "width": int(w),
            "seed": str(sd),
            "eta": float(et),
        }
    except (TypeError, ValueError) as e:
        blank = _blank_image(w, h, (120, 50, 50))
        return blank, blank, f"Invalid input: {e}"

    try:
        r = requests.post(BACKEND_URL, json=payload, timeout=120)
        if r.status_code == 429:
            # Backend rate limit. The body may not be a JSON object, so fall
            # back to the default message rather than reporting a parse error.
            try:
                msg = r.json().get("error", "rate limited by backend")
            except (ValueError, AttributeError):
                msg = "rate limited by backend"
            blank = _blank_image(w, h, (30, 30, 30))
            return blank, blank, msg
        r.raise_for_status()
        out = r.json()
        base_img = b64_to_img(out["base_image"])
        lora_img = b64_to_img(out["lora_image"])
        return base_img, lora_img, out.get("status", "ok")
    except Exception as e:
        # Deliberate catch-all at the UI boundary: network errors, HTTP
        # errors, malformed responses and image decode failures all surface
        # as a status message next to a red-tinted placeholder.
        blank = _blank_image(w, h, (120, 50, 50))
        return blank, blank, f"Backend error: {e}"
# Build the Gradio UI at import time; the resulting Blocks object is `demo`.
# NOTE(review): the original indentation was lost, so the nesting below (which
# controls sit inside the Group and the Row) is reconstructed — confirm against
# the intended layout.
with gr.Blocks(title="Astro-Diffusion: Base vs LoRA") as demo:
    # Inline stylesheet plus the header banner markup.
    gr.HTML(
        """
<style>
.astro-header {
background: linear-gradient(90deg, #0f172a 0%, #1d4ed8 50%, #0ea5e9 100%);
padding: 0.9rem 1rem 0.85rem 1rem;
border-radius: 0.6rem;
margin-bottom: 0.9rem;
display: flex;
justify-content: space-between;
align-items: center;
gap: 1rem;
}
.astro-title {
color: #ffffff !important;
margin: 0;
font-weight: 700;
letter-spacing: 0.01em;
}
.astro-sub {
color: #ffffff !important;
margin: 0.3rem 0 0 0;
font-style: italic;
font-size: 0.8rem;
}
.astro-badge {
background: #facc15;
color: #0f172a;
padding: 0.4rem 1.05rem;
border-radius: 9999px;
font-weight: 700;
white-space: nowrap;
font-size: 0.95rem;
}
.prompt-panel {
background: #e8fff4;
padding: 0.5rem 0.5rem 0.2rem 0.5rem;
border-radius: 0.5rem;
margin-bottom: 0.5rem;
}
.gradio-container label,
label,
.gradio-container [class*="label"],
.gradio-container [class^="svelte-"][class*="label"],
.gradio-container .block p > label {
color: #000000 !important;
font-weight: 600;
}
.gradio-container [data-testid="block-label"],
.gradio-container [data-testid="block-label"] * {
color: #000000 !important;
font-weight: 600;
}
</style>
<div class="astro-header">
<div>
<h2 class="astro-title">Astro-Diffusion : Base SD vs custom LoRA</h2>
<p class="astro-sub">Video generation and more features coming up..!</p>
</div>
<div class="astro-badge">by Srivatsava Kasibhatla</div>
</div>
"""
    )
    # Prompt input, tagged with the "prompt-panel" class styled above.
    with gr.Group(elem_classes=["prompt-panel"]):
        prompt = gr.Textbox(
            value="a high-resolution spiral galaxy with blue star-forming arms and a bright yellow core",
            label="Prompt",
        )
    # Generation parameters; initial values come from cfg with fallbacks.
    with gr.Row():
        steps = gr.Slider(10, 60, value=cfg.get("num_inference_steps", 30), step=1, label="Steps")
        scale = gr.Slider(1.0, 12.0, value=cfg.get("guidance_scale", 7.5), step=0.5, label="Guidance")
        # Configured height/width are capped at 512 to match the field maximum.
        height = gr.Number(value=min(int(cfg.get("height", 512)), 512), label="Height", minimum=32, maximum=512)
        width = gr.Number(value=min(int(cfg.get("width", 512)), 512), label="Width", minimum=32, maximum=512)
        # Seed is a free-text field; _infer forwards it as a string.
        seed = gr.Textbox(value=str(cfg.get("seed", 1234)), label="Seed")
        eta = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Eta")
    btn = gr.Button("Generate")
    # Output widgets: one image per model plus a read-only status line.
    out_base = gr.Image(label="Base Model Output")
    out_lora = gr.Image(label="LoRA Model Output")
    status = gr.Textbox(label="Status", interactive=False)
    # Input order must match _infer's positional parameters
    # (p, st, sc, h, w, sd, et); outputs match its 3-tuple return.
    btn.click(
        _infer,
        [prompt, steps, scale, height, width, seed, eta],
        [out_base, out_lora, status],
    )
if __name__ == "__main__":
    # Launch the module-level `demo` Blocks object. The original called
    # build_ui(), which is not defined in this file and raised NameError.
    # Port resolution order: PORT, then GRADIO_SERVER_PORT, then 7861.
    # NOTE(review): the 7861 fallback is the same port as the default
    # BACKEND_URL (http://localhost:7861); if both defaults are used on one
    # host the UI and backend would contend for the port — confirm intent.
    port = int(os.getenv("PORT", os.getenv("GRADIO_SERVER_PORT", "7861")))
    # A public share link is requested unless GRADIO_PUBLIC_SHARE is set to
    # something other than "true" (case-insensitive).
    share = os.getenv("GRADIO_PUBLIC_SHARE", "True").lower() == "true"
    # Bind on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=port, share=share)