Srikasi committed
Commit c861631 · verified · 1 Parent(s): 26842b0

Update app.py

Files changed (1): app.py +104 -19
app.py CHANGED
@@ -1,11 +1,20 @@
 import os, io, base64, time, yaml, requests
 from PIL import Image
 import gradio as gr
+from requests.exceptions import ConnectionError, Timeout, HTTPError
 
 # frontend-only: call your backend (RunPod/pod/etc.)
-BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:7861")
+BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:7861").rstrip("/")
+print(f"[HF] BACKEND_URL resolved to: {BACKEND_URL}")
+
+# sample prompts
+SAMPLE_PROMPTS = [
+    "a high-resolution spiral galaxy with blue star-forming arms and a bright yellow core",
+    "a crimson emission nebula with dark dust lanes and scattered newborn stars",
+    "a ringed gas giant with visible storm bands and subtle shadow on rings",
+    "an accretion disk around a black hole with relativistic jets, high contrast",
+]
 
-# cfg = yaml.safe_load(open("configs/infer.yaml"))
 # default UI values if no YAML
 cfg = {
     "height": 512,
@@ -13,43 +22,88 @@ cfg = {
     "num_inference_steps": 30,
     "guidance_scale": 7.5,
     "seed": 1234,
-    "eta" : 0,
+    "eta": 0,
 }
 
 
+# ---- health check ----
+def check_backend():
+    try:
+        r = requests.get(f"{BACKEND_URL}/health", timeout=5)
+        r.raise_for_status()
+        data = r.json()
+        if data.get("status") == "ok":
+            return "backend=READY"
+    except Exception:
+        pass
+    return "backend=DOWN"
+
+
 def b64_to_img(s: str):
     data = base64.b64decode(s)
     return Image.open(io.BytesIO(data)).convert("RGB")
 
 
-def _infer(p, st, sc, h, w, sd, et):
+def _infer(p, st, sc, h, w, sd, et, session_id):
+    # make sure we always have ints for blank images
+    h = int(h)
+    w = int(w)
+
     payload = {
         "prompt": p,
         "steps": int(st),
         "scale": float(sc),
-        "height": int(h),
-        "width": int(w),
+        "height": h,
+        "width": w,
         "seed": str(sd),
         "eta": float(et),
     }
+
+    # send session_id if we have one
+    if session_id:
+        payload["session_id"] = session_id
+
     try:
-        r = requests.post(BACKEND_URL, json=payload, timeout=120)
+        r = requests.post(f"{BACKEND_URL}/infer", json=payload, timeout=120)
         if r.status_code == 429:
-            blank = Image.new("RGB", (int(w), int(h)), (30, 30, 30))
-            msg = r.json().get("error", "rate limited by backend")
-            return blank, blank, msg
+            blank = Image.new("RGB", (w, h), (30, 30, 30))
+            out = r.json()
+            # backend also returns session_id on 429
+            new_sid = out.get("session_id", session_id)
+            msg = out.get("error", "rate limited by backend")
+            return blank, blank, msg, new_sid
         r.raise_for_status()
         out = r.json()
         base_img = b64_to_img(out["base_image"])
         lora_img = b64_to_img(out["lora_image"])
-        return base_img, lora_img, out.get("status", "ok")
+        new_sid = out.get("session_id", session_id)
+        return base_img, lora_img, out.get("status", "ok"), new_sid
+
+    except ConnectionError:
+        blank = Image.new("RGB", (w, h), (120, 50, 50))
+        return blank, blank, "Backend not reachable (connection refused). Start the backend and retry.", session_id
+
+    except Timeout:
+        blank = Image.new("RGB", (w, h), (120, 50, 50))
+        return blank, blank, "Backend took too long. Please try again later.", session_id
+
+    except HTTPError as e:
+        blank = Image.new("RGB", (w, h), (120, 50, 50))
+        return blank, blank, f"Backend returned HTTP Error: {e.response.status_code}", session_id
+
     except Exception as e:
-        blank = Image.new("RGB", (int(w), int(h)), (120, 50, 50))
-        return blank, blank, f"Backend error: {e}"
+        blank = Image.new("RGB", (w, h), (120, 50, 50))
+        return blank, blank, f"Unknown client error: {e}", session_id
 
 
 def build_ui():
     with gr.Blocks(title="Astro-Diffusion: Base vs LoRA") as demo:
+        # session state lives in the browser/tab
+        session_state = gr.State(value="")
+
+        # header + status
+        status_lbl = gr.Markdown("checking backend...")
+
         gr.HTML(
             """
             <style>
@@ -68,12 +122,27 @@ def build_ui():
                 margin: 0;
                 font-weight: 700;
                 letter-spacing: 0.01em;
+                font-size: 1.4rem; /* added */
             }
             .astro-sub {
                 color: #ffffff !important;
                 margin: 0.3rem 0 0 0;
                 font-style: italic;
+                font-size: 0.9rem;
+            }
+            .astro-note {
+                color: #ffffff !important;
+                margin: 0.25rem 0 0.25rem 0;
                 font-size: 0.8rem;
+                opacity: 0.9;
+            }
+            .astro-link {
+                margin-top: 0.55rem;
+            }
+            .astro-link a {
+                color: #ffffff !important;
+                text-decoration: underline;
+                font-size: 0.78rem;
             }
             .astro-badge {
                 background: #facc15;
@@ -90,7 +159,7 @@ def build_ui():
                 border-radius: 0.5rem;
                 margin-bottom: 0.5rem;
             }
-            .gradio-container label,
+            .gradio-container label,
             label,
             .gradio-container [class*="label"],
             .gradio-container [class^="svelte-"][class*="label"],
@@ -108,6 +177,12 @@ def build_ui():
             <div>
               <h2 class="astro-title">Astro-Diffusion : Base SD vs custom LoRA</h2>
               <p class="astro-sub">Video generation and more features coming up..!</p>
+              <p class="astro-note">Shared hourly/daily limits globally for this demo. Please use sparingly.</p>
+              <p class="astro-link">
+                <a href="https://github.com/KSV2001/astro_diffusion" target="_blank" rel="noreferrer noopener">
+                  Visit Srivatsava's GitHub repo
+                </a>
+              </p>
             </div>
             <div class="astro-badge">by Srivatsava Kasibhatla</div>
             </div>
@@ -115,14 +190,20 @@ def build_ui():
         )
 
         with gr.Group(elem_classes=["prompt-panel"]):
+            sample_dropdown = gr.Dropdown(
+                choices=SAMPLE_PROMPTS,
+                value=SAMPLE_PROMPTS[0],
+                label="Sample prompts",
+            )
             prompt = gr.Textbox(
-                value="a high-resolution spiral galaxy with blue star-forming arms and a bright yellow core",
+                value=SAMPLE_PROMPTS[0],
                 label="Prompt",
            )
 
+        # when user picks a sample, copy it into the textbox
+        sample_dropdown.change(fn=lambda x: x, inputs=sample_dropdown, outputs=prompt)
 
         with gr.Row():
-
            steps = gr.Slider(10, 60, value=cfg.get("num_inference_steps", 30), step=1, label="Steps")
            scale = gr.Slider(1.0, 12.0, value=cfg.get("guidance_scale", 7.5), step=0.5, label="Guidance")
            height = gr.Number(value=min(int(cfg.get("height", 512)), 512), label="Height", minimum=32, maximum=512)
@@ -135,12 +216,16 @@ def build_ui():
             out_lora = gr.Image(label="LoRA Model Output")
             status = gr.Textbox(label="Status", interactive=False)
 
+        # send session_state, receive updated session_state
         btn.click(
             _infer,
-            [prompt, steps, scale, height, width, seed, eta],
-            [out_base, out_lora, status],
+            [prompt, steps, scale, height, width, seed, eta, session_state],
+            [out_base, out_lora, status, session_state],
         )
 
+        # ping once when UI loads
+        demo.load(fn=check_backend, inputs=None, outputs=status_lbl)
+
     return demo
 
 
@@ -148,4 +233,4 @@ if __name__ == "__main__":
     interface = build_ui()
     port = int(os.getenv("PORT", os.getenv("GRADIO_SERVER_PORT", "8080")))
     share = os.getenv("GRADIO_PUBLIC_SHARE", "True").lower() == "true"
-    interface.launch(server_name="0.0.0.0", server_port=port, share=share)
+    interface.launch(server_name="0.0.0.0", server_port=port, share=share)
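For reference, the backend contract this frontend now assumes: GET {BACKEND_URL}/health returns {"status": "ok"}, and POST {BACKEND_URL}/infer accepts the payload built in _infer and returns base_image and lora_image as base64-encoded images plus status and session_id (session_id is also expected on HTTP 429 responses). A minimal stub satisfying that contract might look like the sketch below; Flask, the file name backend_stub.py, and the placeholder images are illustrative assumptions, not part of this commit.

# backend_stub.py -- hypothetical stand-in for the real RunPod backend
import base64, io, uuid

from flask import Flask, jsonify, request
from PIL import Image

app = Flask(__name__)

def img_to_b64(img: Image.Image) -> str:
    # encode a PIL image as base64 PNG, the format b64_to_img() decodes
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")

@app.get("/health")
def health():
    # check_backend() looks for {"status": "ok"} to report backend=READY
    return jsonify({"status": "ok"})

@app.post("/infer")
def infer():
    req = request.get_json(force=True)
    # echo the caller's session_id, or mint one (assumption: backend owns IDs)
    sid = req.get("session_id") or str(uuid.uuid4())
    w = int(req.get("width", 512))
    h = int(req.get("height", 512))
    # placeholder images where the real service would run base SD and LoRA
    img = img_to_b64(Image.new("RGB", (w, h), (0, 0, 64)))
    return jsonify({
        "base_image": img,
        "lora_image": img,
        "status": "ok",
        "session_id": sid,
    })

if __name__ == "__main__":
    # matches the frontend's default BACKEND_URL of http://localhost:7861
    app.run(host="0.0.0.0", port=7861)

Running this stub with BACKEND_URL left unset lets the Space come up showing backend=READY and exercise the full click-to-images path, including the session_id round trip through gr.State, without a GPU backend.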