Srikasi commited on
Commit
9fe3cf1
·
verified ·
1 Parent(s): e910d0b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, base64, time, yaml, requests
2
+ from PIL import Image
3
+ import gradio as gr
4
+ from ratelimits import RateLimiter
5
+
6
+ BACKEND_URL = os.getenv("BACKEND_URL")
7
+
8
+ cfg = yaml.safe_load(open("configs/infer.yaml"))
9
+ limiter = RateLimiter()
10
+
11
+ def b64_to_img(s: str):
12
+ data = base64.b64decode(s)
13
+ return Image.open(io.BytesIO(data)).convert("RGB")
14
+
15
+ def _infer(p, st, sc, h, w, sd, et, sess_state, request: gr.Request):
16
+ # same session/IP logic as yours
17
+ if request is not None and request.headers:
18
+ hdrs = {k.lower(): v for k, v in request.headers.items()}
19
+ xff = hdrs.get("x-forwarded-for")
20
+ if xff:
21
+ ip = xff.split(",")[0].strip()
22
+ elif request.client:
23
+ ip = request.client.host
24
+ else:
25
+ ip = "unknown"
26
+ else:
27
+ ip = "unknown"
28
+
29
+ now = time.time()
30
+ if "started_at" not in sess_state:
31
+ sess_state["started_at"] = now
32
+ if "count" not in sess_state:
33
+ sess_state["count"] = 0
34
+ if now - sess_state["started_at"] > limiter.per_session_max_age:
35
+ sess_state["started_at"] = now
36
+ sess_state["count"] = 0
37
+
38
+ ok, reason = limiter.pre_check(ip, sess_state)
39
+ if not ok:
40
+ blank = Image.new("RGB", (int(w), int(h)), (30, 30, 30))
41
+ return blank, blank, f"Rate limited: {reason}", sess_state
42
+
43
+ payload = {
44
+ "prompt": p,
45
+ "steps": int(st),
46
+ "scale": float(sc),
47
+ "height": int(h),
48
+ "width": int(w),
49
+ "seed": str(sd),
50
+ "eta": float(et),
51
+ }
52
+ resp = requests.post(BACKEND_URL, json=payload, timeout=120)
53
+ resp.raise_for_status()
54
+ out = resp.json()
55
+ base_img = b64_to_img(out["base_image"])
56
+ lora_img = b64_to_img(out["lora_image"])
57
+
58
+ limiter.post_consume(ip, out.get("duration", 0.0))
59
+
60
+ return base_img, lora_img, out.get("status", "ok"), sess_state
61
+
62
+
63
+ with gr.Blocks(title="Astro-Diffusion: Base vs LoRA") as demo:
64
+ session_state = gr.State({"count": 0, "started_at": time.time()})
65
+
66
+ gr.HTML(
67
+ """
68
+ <style>
69
+ .astro-header {
70
+ background: linear-gradient(90deg, #0f172a 0%, #1d4ed8 50%, #0ea5e9 100%);
71
+ padding: 0.9rem 1rem 0.85rem 1rem;
72
+ border-radius: 0.6rem;
73
+ margin-bottom: 0.9rem;
74
+ display: flex;
75
+ justify-content: space-between;
76
+ align-items: center;
77
+ gap: 1rem;
78
+ }
79
+ .astro-title {
80
+ color: #ffffff !important;
81
+ margin: 0;
82
+ font-weight: 700;
83
+ letter-spacing: 0.01em;
84
+ }
85
+ .astro-sub {
86
+ color: #ffffff !important;
87
+ margin: 0.3rem 0 0 0;
88
+ font-style: italic;
89
+ font-size: 0.8rem;
90
+ }
91
+ .astro-badge {
92
+ background: #facc15;
93
+ color: #0f172a;
94
+ padding: 0.4rem 1.05rem;
95
+ border-radius: 9999px;
96
+ font-weight: 700;
97
+ white-space: nowrap;
98
+ font-size: 0.95rem;
99
+ }
100
+ .prompt-panel {
101
+ background: #e8fff4;
102
+ padding: 0.5rem 0.5rem 0.2rem 0.5rem;
103
+ border-radius: 0.5rem;
104
+ margin-bottom: 0.5rem;
105
+ }
106
+ .gradio-container label,
107
+ label,
108
+ .gradio-container [class*="label"],
109
+ .gradio-container [class^="svelte-"][class*="label"],
110
+ .gradio-container .block p > label {
111
+ color: #000000 !important;
112
+ font-weight: 600;
113
+ }
114
+ .gradio-container [data-testid="block-label"],
115
+ .gradio-container [data-testid="block-label"] * {
116
+ color: #000000 !important;
117
+ font-weight: 600;
118
+ }
119
+ </style>
120
+ <div class="astro-header">
121
+ <div>
122
+ <h2 class="astro-title">Astro-Diffusion : Base SD vs custom LoRA</h2>
123
+ <p class="astro-sub">Video generation and more features coming up..!</p>
124
+ </div>
125
+ <div class="astro-badge">by Srivatsava Kasibhatla</div>
126
+ </div>
127
+ """
128
+ )
129
+
130
+ with gr.Group(elem_classes=["prompt-panel"]):
131
+ prompt = gr.Textbox(
132
+ value="a high-resolution spiral galaxy with blue star-forming arms and a bright yellow core",
133
+ label="Prompt",
134
+ )
135
+
136
+ with gr.Row():
137
+ steps = gr.Slider(10, 60, value=cfg.get("num_inference_steps", 30), step=1, label="Steps")
138
+ scale = gr.Slider(1.0, 12.0, value=cfg.get("guidance_scale", 7.5), step=0.5, label="Guidance")
139
+ height = gr.Number(value=min(int(cfg.get("height", 512)), 512), label="Height", minimum=32, maximum=512)
140
+ width = gr.Number(value=min(int(cfg.get("width", 512)), 512), label="Width", minimum=32, maximum=512)
141
+ seed = gr.Textbox(value=str(cfg.get("seed", 1234)), label="Seed")
142
+ eta = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Eta")
143
+
144
+ btn = gr.Button("Generate")
145
+ out_base = gr.Image(label="Base Model Output")
146
+ out_lora = gr.Image(label="LoRA Model Output")
147
+ status = gr.Textbox(label="Status", interactive=False)
148
+
149
+ btn.click(
150
+ _infer,
151
+ [prompt, steps, scale, height, width, seed, eta, session_state],
152
+ [out_base, out_lora, status, session_state],
153
+ )
154
+
155
+
156
+ port = int(os.getenv("GRADIO_SERVER_PORT", "7861"))
157
+ share = os.getenv("GRADIO_PUBLIC_SHARE", "False") == "True"
158
+ demo.launch(
159
+ server_name="0.0.0.0",
160
+ server_port=port,
161
+ show_error=True,
162
+ share=share,
163
+ )
164
+