ariG23498 HF Staff committed on
Commit 75df5eb · 1 Parent(s): cb90111
Files changed (1)
  1. app.py +171 -155
app.py CHANGED
@@ -1,230 +1,246 @@
 import gradio as gr
 import spaces
 import torch
 from transformers import (
     AutoProcessor,
     AutoModelForZeroShotObjectDetection,
-    Owlv2ForObjectDetection,
-    OmDetTurboForObjectDetection,
 )
-from PIL import Image
-import time
 
 
-def extract_model_short_name(model_id):
     return model_id.split("/")[-1].replace("-", " ").replace("_", " ")
 
 
 model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
 processor_llmdet = AutoProcessor.from_pretrained(model_llmdet_id)
-model_llmdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_llmdet_id)
 
 model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
 processor_mm_grounding = AutoProcessor.from_pretrained(model_mm_grounding_id)
-model_mm_grounding = AutoModelForZeroShotObjectDetection.from_pretrained(
-    model_mm_grounding_id
 )
 
 model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
 processor_omdet = AutoProcessor.from_pretrained(model_omdet_id)
-model_omdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_omdet_id)
 
 model_owlv2_id = "google/owlv2-large-patch14-ensemble"
 processor_owlv2 = AutoProcessor.from_pretrained(model_owlv2_id)
-model_owlv2 = AutoModelForZeroShotObjectDetection.from_pretrained(model_owlv2_id)
-
-model_llmdet_name = extract_model_short_name(model_llmdet_id)
-model_mm_grounding_name = extract_model_short_name(model_mm_grounding_id)
-model_omdet_name = extract_model_short_name(model_omdet_id)
-model_owlv2_name = extract_model_short_name(model_owlv2_id)
 
 
 @spaces.GPU
-def detect(model, processor, image: Image.Image, prompts: list, threshold: float):
     t0 = time.perf_counter()
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device).eval()
     texts = [prompts]
-    inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
     with torch.inference_mode():
-        outputs = model(**inputs)
-    results = processor.post_process_grounded_object_detection(
         outputs, threshold=threshold, target_sizes=[image.size[::-1]]
-    )
-    result = results[0]
     annotations = []
 
-    if isinstance(model, Owlv2ForObjectDetection) or isinstance(
-        model, OmDetTurboForObjectDetection
-    ):
-        key = "labels"
-        check = True
-    else:
-        key = "text_labels"
-        check = False
-
-    for box, score, label in zip(result["boxes"], result["scores"], result[key]):
-        if score >= threshold:
-            if check:
-                label_id = label
-                label_name = prompts[label_id]
             else:
-                label_name = label
-            xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
-            annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {score:.2f}"))
     elapsed_ms = (time.perf_counter() - t0) * 1000
-    time_taken = f"**Inference time ({model_omdet_name}):** {elapsed_ms:.0f} ms"
     return annotations, time_taken
 
 
 def run_detection(
     image: Image.Image,
     prompts_str: str,
-    threshold_llm,
-    threshold_mm,
-    threshold_owlv2,
-    threshold_omdet,
 ):
-    prompts = [p.strip() for p in prompts_str.split(",")]
-    ann_llm, time_llm = detect(
-        model_llmdet, processor_llmdet, image, prompts, threshold_llm
-    )
-    ann_mm, time_mm = detect(
-        model_mm_grounding, processor_mm_grounding, image, prompts, threshold_mm
-    )
-    ann_owlv2, time_owlv2 = detect(
-        model_owlv2, processor_owlv2, image, prompts, threshold_owlv2
-    )
-    ann_omdet, time_omdet = detect(
-        model_omdet, processor_omdet, image, prompts, threshold_omdet
-    )
     return (
-        (image, ann_llm),
-        time_llm,
-        (image, ann_mm),
-        time_mm,
-        (image, ann_owlv2),
-        time_owlv2,
-        (image, ann_omdet),
-        time_omdet,
     )
 
 
 with gr.Blocks() as app:
-    gr.Markdown("# Zero-Shot Object Detection Arena")
-    gr.Markdown(
-        "### Compare different zero-shot object detection models on the same image and prompts."
-    )
     with gr.Row():
         with gr.Column(scale=1):
             image = gr.Image(type="pil", label="Upload an image", height=400)
             prompts = gr.Textbox(
-                label="Prompts (comma-separated)", value="a cat, a remote control"
             )
             with gr.Accordion("Per-model confidence thresholds", open=True):
-                threshold_llm = gr.Slider(
-                    label="Threshold for LLMDet", minimum=0.0, maximum=1.0, value=0.3
-                )
-                threshold_mm = gr.Slider(
-                    label="Threshold for MM GroundingDINO Tiny",
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.3,
-                )
-                threshold_owlv2 = gr.Slider(
-                    label="Threshold for OwlV2 Large",
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.1,
-                )
-                threshold_omdet = gr.Slider(
-                    label="Threshold for OMDet Turbo Swin Tiny",
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.2,
-                )
             generate_btn = gr.Button(value="Detect")
     with gr.Row():
         with gr.Column(scale=2):
-            output_image_llm = gr.AnnotatedImage(
-                label=f"Annotated image for {model_llmdet_name}", height=400
-            )
             output_time_llm = gr.Markdown()
         with gr.Column(scale=2):
-            output_image_mm = gr.AnnotatedImage(
-                label=f"Annotated image for {model_mm_grounding_name}", height=400
-            )
             output_time_mm = gr.Markdown()
     with gr.Row():
         with gr.Column(scale=2):
-            output_image_owlv2 = gr.AnnotatedImage(
-                label=f"Annotated image for {model_owlv2_name}", height=400
-            )
             output_time_owlv2 = gr.Markdown()
         with gr.Column(scale=2):
-            output_image_omdet = gr.AnnotatedImage(
-                label=f"Annotated image for {model_omdet_name}", height=400
-            )
             output_time_omdet = gr.Markdown()
     gr.Markdown("### Examples")
     example_data = [
-        [
-            "http://images.cocodataset.org/val2017/000000039769.jpg",
-            "a cat, a remote control",
-            0.30,
-            0.30,
-            0.10,
-            0.30,
-        ],
-        [
-            "http://images.cocodataset.org/val2017/000000000139.jpg",
-            "a person, a tv, a remote",
-            0.35,
-            0.30,
-            0.12,
-            0.30,
-        ],
     ]
 
     gr.Examples(
         examples=example_data,
-        inputs=[
-            image,
-            prompts,
-            threshold_llm,
-            threshold_mm,
-            threshold_owlv2,
-            threshold_omdet,
-        ],
         label="Click an example to populate the inputs",
     )
-    inputs = [
-        image,
-        prompts,
-        threshold_llm,
-        threshold_mm,
-        threshold_owlv2,
-        threshold_omdet,
-    ]
     outputs = [
-        output_image_llm,
-        output_time_llm,
-        output_image_mm,
-        output_time_mm,
-        output_image_owlv2,
-        output_time_owlv2,
-        output_image_omdet,
-        output_time_omdet,
     ]
-    generate_btn.click(
-        fn=run_detection,
-        inputs=inputs,
-        outputs=outputs,
-    )
-    image.upload(
-        fn=run_detection,
-        inputs=inputs,
-        outputs=outputs,
-    )
 
-app.launch()
+import time
+from dataclasses import dataclass
+from typing import List, Tuple
+
 import gradio as gr
 import spaces
 import torch
+from PIL import Image
 from transformers import (
     AutoProcessor,
     AutoModelForZeroShotObjectDetection,
 )
 
+# ---------------------------
+# Setup
+# ---------------------------
 
+def extract_model_short_name(model_id: str) -> str:
     return model_id.split("/")[-1].replace("-", " ").replace("_", " ")
 
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
+# (Optional) modest speed-ups
+torch.set_grad_enabled(False)
+
+# Model bundles for cleaner wiring
+@dataclass
+class ZSDetBundle:
+    model_id: str
+    model_name: str
+    processor: AutoProcessor
+    model: AutoModelForZeroShotObjectDetection
+    use_label_ids: bool  # True for OWLv2/OMDet (labels are indices), False for others
+
+# LLMDet
 model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
 processor_llmdet = AutoProcessor.from_pretrained(model_llmdet_id)
+model_llmdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_llmdet_id).to(DEVICE).eval()
+bundle_llmdet = ZSDetBundle(
+    model_id=model_llmdet_id,
+    model_name=extract_model_short_name(model_llmdet_id),
+    processor=processor_llmdet,
+    model=model_llmdet,
+    use_label_ids=False,
+)
 
+# MM GroundingDINO
 model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
 processor_mm_grounding = AutoProcessor.from_pretrained(model_mm_grounding_id)
+model_mm_grounding = AutoModelForZeroShotObjectDetection.from_pretrained(model_mm_grounding_id).to(DEVICE).eval()
+bundle_mm_grounding = ZSDetBundle(
+    model_id=model_mm_grounding_id,
+    model_name=extract_model_short_name(model_mm_grounding_id),
+    processor=processor_mm_grounding,
+    model=model_mm_grounding,
+    use_label_ids=False,
 )
 
+# OMDet Turbo
 model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
 processor_omdet = AutoProcessor.from_pretrained(model_omdet_id)
+model_omdet = AutoModelForZeroShotObjectDetection.from_pretrained(model_omdet_id).to(DEVICE).eval()
+bundle_omdet = ZSDetBundle(
+    model_id=model_omdet_id,
+    model_name=extract_model_short_name(model_omdet_id),
+    processor=processor_omdet,
+    model=model_omdet,
+    use_label_ids=True,  # returns label indices
+)
 
+# OWLv2
 model_owlv2_id = "google/owlv2-large-patch14-ensemble"
 processor_owlv2 = AutoProcessor.from_pretrained(model_owlv2_id)
+model_owlv2 = AutoModelForZeroShotObjectDetection.from_pretrained(model_owlv2_id).to(DEVICE).eval()
+bundle_owlv2 = ZSDetBundle(
+    model_id=model_owlv2_id,
+    model_name=extract_model_short_name(model_owlv2_id),
+    processor=processor_owlv2,
+    model=model_owlv2,
+    use_label_ids=True,  # returns label indices
+)
 
+# ---------------------------
+# Inference
+# ---------------------------
 
 @spaces.GPU
+def detect(
+    bundle: ZSDetBundle,
+    image: Image.Image,
+    prompts: List[str],
+    threshold: float,
+) -> Tuple[List[Tuple[Tuple[int, int, int, int], str]], str]:
+    """
+    Returns [(bbox, label_score_str), ...], time_str
+    """
     t0 = time.perf_counter()
+
+    # HF zero-shot OD expects list-of-list text
     texts = [prompts]
+    inputs = bundle.processor(images=image, text=texts, return_tensors="pt").to(DEVICE)
+
     with torch.inference_mode():
+        if DEVICE == "cuda":
+            # Use autocast on CUDA to speed up mixed-precision-friendly ops
+            with torch.amp.autocast("cuda"):
+                outputs = bundle.model(**inputs)
+        else:
+            outputs = bundle.model(**inputs)
+
+    results = bundle.processor.post_process_grounded_object_detection(
         outputs, threshold=threshold, target_sizes=[image.size[::-1]]
+    )[0]
+
     annotations = []
+    key = "labels" if bundle.use_label_ids else "text_labels"
+
+    for box, score, label in zip(results["boxes"], results["scores"], results[key]):
+        if float(score) < threshold:
+            continue
 
+        if bundle.use_label_ids:
+            # Map label index -> prompt string
+            label_idx = int(label)
+            if 0 <= label_idx < len(prompts):
+                label_name = prompts[label_idx]
             else:
+                label_name = str(label_idx)
+        else:
+            # Direct text label
+            label_name = label if isinstance(label, str) else str(label)
+
+        xmin, ymin, xmax, ymax = map(int, box.tolist())
+        annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {float(score):.2f}"))
+
     elapsed_ms = (time.perf_counter() - t0) * 1000
+    time_taken = f"**Inference time ({bundle.model_name}):** {elapsed_ms:.0f} ms"
     return annotations, time_taken
 
+def parse_prompts(prompts_str: str) -> List[str]:
+    return [p.strip() for p in prompts_str.split(",") if p.strip()]
 
 def run_detection(
     image: Image.Image,
     prompts_str: str,
+    threshold_llm: float,
+    threshold_mm: float,
+    threshold_owlv2: float,
+    threshold_omdet: float,
 ):
+    prompts = parse_prompts(prompts_str)
+
+    ann_llm, time_llm = detect(bundle_llmdet, image, prompts, threshold_llm)
+    ann_mm, time_mm = detect(bundle_mm_grounding, image, prompts, threshold_mm)
+    ann_owlv2, time_owlv2 = detect(bundle_owlv2, image, prompts, threshold_owlv2)
+    ann_omdet, time_omdet = detect(bundle_omdet, image, prompts, threshold_omdet)
+
     return (
+        (image, ann_llm), time_llm,
+        (image, ann_mm), time_mm,
+        (image, ann_owlv2), time_owlv2,
+        (image, ann_omdet), time_omdet,
     )
 
+# ---------------------------
+# Description
+# ---------------------------
+
+description_md = """
+# Zero-Shot Object Detection Arena
+
+Compare **four zero-shot object detectors** on the same image and prompts.
+Upload an image (or pick an example), add **comma-separated prompts**, tweak the per-model **thresholds**, and hit **Detect**.
+You'll see bounding boxes, scores, and **per-model inference time**.
+
+**Models**
+- LLMDet Tiny — [`iSEE-Laboratory/llmdet_tiny`](https://huggingface.co/iSEE-Laboratory/llmdet_tiny)
+- MM GroundingDINO Tiny O365v1 GoldG — [`rziga/mm_grounding_dino_tiny_o365v1_goldg`](https://huggingface.co/rziga/mm_grounding_dino_tiny_o365v1_goldg)
+- OMDet Turbo Swin Tiny — [`omlab/omdet-turbo-swin-tiny-hf`](https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf)
+- OWLv2 Large Patch14 Ensemble — [`google/owlv2-large-patch14-ensemble`](https://huggingface.co/google/owlv2-large-patch14-ensemble)
+
+**Tip:** Lower thresholds increase recall but may also increase false positives.
+"""
+
+# ---------------------------
+# UI
+# ---------------------------
 
 with gr.Blocks() as app:
+    gr.Markdown(description_md)
+
     with gr.Row():
         with gr.Column(scale=1):
             image = gr.Image(type="pil", label="Upload an image", height=400)
             prompts = gr.Textbox(
+                label="Prompts (comma-separated)",
+                value="a cat, a remote control",
+                placeholder="e.g., a cat, a remote control",
             )
             with gr.Accordion("Per-model confidence thresholds", open=True):
+                threshold_llm = gr.Slider(label=f"Threshold — {bundle_llmdet.model_name}", minimum=0.0, maximum=1.0, value=0.3)
+                threshold_mm = gr.Slider(label=f"Threshold — {bundle_mm_grounding.model_name}", minimum=0.0, maximum=1.0, value=0.3)
+                threshold_owlv2 = gr.Slider(label=f"Threshold — {bundle_owlv2.model_name}", minimum=0.0, maximum=1.0, value=0.1)
+                threshold_omdet = gr.Slider(label=f"Threshold — {bundle_omdet.model_name}", minimum=0.0, maximum=1.0, value=0.2)
             generate_btn = gr.Button(value="Detect")
+
     with gr.Row():
         with gr.Column(scale=2):
+            output_image_llm = gr.AnnotatedImage(label=f"Annotated — {bundle_llmdet.model_name}", height=400)
             output_time_llm = gr.Markdown()
         with gr.Column(scale=2):
+            output_image_mm = gr.AnnotatedImage(label=f"Annotated — {bundle_mm_grounding.model_name}", height=400)
             output_time_mm = gr.Markdown()
+
     with gr.Row():
         with gr.Column(scale=2):
+            output_image_owlv2 = gr.AnnotatedImage(label=f"Annotated — {bundle_owlv2.model_name}", height=400)
             output_time_owlv2 = gr.Markdown()
         with gr.Column(scale=2):
+            output_image_omdet = gr.AnnotatedImage(label=f"Annotated — {bundle_omdet.model_name}", height=400)
             output_time_omdet = gr.Markdown()
+
     gr.Markdown("### Examples")
     example_data = [
+        ["https://images.cocodataset.org/val2017/000000039769.jpg", "a cat, a remote control", 0.30, 0.30, 0.10, 0.30],
+        ["https://images.cocodataset.org/val2017/000000000139.jpg", "a person, a tv, a remote", 0.35, 0.30, 0.12, 0.30],
     ]
 
     gr.Examples(
         examples=example_data,
+        inputs=[image, prompts, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet],
         label="Click an example to populate the inputs",
     )
+
+    inputs = [image, prompts, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet]
     outputs = [
+        output_image_llm, output_time_llm,
+        output_image_mm, output_time_mm,
+        output_image_owlv2, output_time_owlv2,
+        output_image_omdet, output_time_omdet,
     ]
+    generate_btn.click(fn=run_detection, inputs=inputs, outputs=outputs)
+    image.upload(fn=run_detection, inputs=inputs, outputs=outputs)
 
+# Optional: queue requests so multiple users are handled gracefully (tune as needed).
+# Gradio 4+ replaced queue()'s concurrency_count with default_concurrency_limit.
+app.queue(max_size=16, default_concurrency_limit=1).launch()
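
For readers who want to try one of these checkpoints outside the Space, below is a minimal, standalone sketch of the zero-shot detection flow that `detect()` wraps. It assumes a recent `transformers` release (one that ships LLMDet support) and uses `requests` to fetch the example image; neither the helper names nor the snippet itself are part of the commit.

```python
import requests
import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

model_id = "iSEE-Laboratory/llmdet_tiny"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).eval()

# Same example image and prompts as the Space's first example row
url = "https://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompts = ["a cat", "a remote control"]

# The processor expects a list of prompt lists (one list per image)
inputs = processor(images=image, text=[prompts], return_tensors="pt")
with torch.inference_mode():
    outputs = model(**inputs)

# Convert raw outputs to pixel-space boxes; 0.3 mirrors the app's default threshold
results = processor.post_process_grounded_object_detection(
    outputs, threshold=0.3, target_sizes=[image.size[::-1]]
)[0]

for box, score, label in zip(results["boxes"], results["scores"], results["text_labels"]):
    print(label, f"{float(score):.2f}", [int(v) for v in box.tolist()])
```

The OWLv2 and OMDet Turbo checkpoints follow the same pattern, except that, as the app's `use_label_ids` flag reflects, their post-processed `labels` are prompt indices rather than strings.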