ariG23498 HF Staff committed on
Commit
e007554
·
1 Parent(s): 29750ea

init model and processor beforehand

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -11,9 +11,28 @@ def extract_model_short_name(model_id):
11
 
12
 
13
  model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
 
 
 
 
 
14
  model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
 
 
 
 
 
15
  model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
 
 
 
 
 
16
  model_owlv2_id = "google/owlv2-large-patch14-ensemble"
 
 
 
 
17
 
18
  model_llmdet_name = extract_model_short_name(model_llmdet_id)
19
  model_mm_grounding_name = extract_model_short_name(model_mm_grounding_id)
@@ -22,13 +41,10 @@ model_owlv2_name = extract_model_short_name(model_owlv2_id)
22
 
23
 
24
  @spaces.GPU
25
- def detect(model_id: str, image: Image.Image, prompts: list, threshold: float):
26
  t0 = time.perf_counter()
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
- processor = AutoProcessor.from_pretrained(model_id)
29
- model = (
30
- AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device).eval()
31
- )
32
  texts = [prompts]
33
  inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
34
  with torch.inference_mode():
@@ -38,7 +54,7 @@ def detect(model_id: str, image: Image.Image, prompts: list, threshold: float):
38
  )
39
  result = results[0]
40
  annotations = []
41
- for box, score, label_name in zip(result["boxes"], result["scores"], result["text_abels"]):
42
  if score >= threshold:
43
  xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
44
  annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {score:.2f}"))
@@ -51,10 +67,10 @@ def run_detection(
51
  image: Image.Image, prompts_str: str, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet,
52
  ):
53
  prompts = [p.strip() for p in prompts_str.split(",")]
54
- ann_llm, time_llm = detect(model_llmdet_id, image, prompts, threshold_llm)
55
- ann_mm, time_mm = detect(model_mm_grounding_name, image, prompts, threshold_mm)
56
- ann_owlv2, time_owlv2 = detect(model_omdet_id, image, prompts, threshold_owlv2)
57
- ann_omdet, time_omdet = detect(model_owlv2_name, image, prompts, threshold_omdet)
58
  return (
59
  (image, ann_llm),
60
  time_llm,
 
11
 
12
 
13
  model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
14
+ processor_llmdet = AutoProcessor.from_pretrained(model_llmdet_id)
15
+ model_llmdet = (
16
+ AutoModelForZeroShotObjectDetection.from_pretrained(model_llmdet_id)
17
+ )
18
+
19
  model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
20
+ processor_mm_grounding = AutoProcessor.from_pretrained(model_mm_grounding_id)
21
+ model_mm_grounding = (
22
+ AutoModelForZeroShotObjectDetection.from_pretrained(model_mm_grounding_id)
23
+ )
24
+
25
  model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
26
+ processor_omdet = AutoProcessor.from_pretrained(model_omdet_id)
27
+ model_omdet = (
28
+ AutoModelForZeroShotObjectDetection.from_pretrained(model_omdet_id)
29
+ )
30
+
31
  model_owlv2_id = "google/owlv2-large-patch14-ensemble"
32
+ processor_owlv2 = AutoProcessor.from_pretrained(model_owlv2_id)
33
+ model_owlv2 = (
34
+ AutoModelForZeroShotObjectDetection.from_pretrained(model_owlv2_id)
35
+ )
36
 
37
  model_llmdet_name = extract_model_short_name(model_llmdet_id)
38
  model_mm_grounding_name = extract_model_short_name(model_mm_grounding_id)
 
41
 
42
 
43
  @spaces.GPU
44
+ def detect(model, processor, image: Image.Image, prompts: list, threshold: float):
45
  t0 = time.perf_counter()
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
+ model = model.to(device).eval()
 
 
 
48
  texts = [prompts]
49
  inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
50
  with torch.inference_mode():
 
54
  )
55
  result = results[0]
56
  annotations = []
57
+ for box, score, label_name in zip(result["boxes"], result["scores"], result["text_labels"]):
58
  if score >= threshold:
59
  xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
60
  annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {score:.2f}"))
 
67
  image: Image.Image, prompts_str: str, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet,
68
  ):
69
  prompts = [p.strip() for p in prompts_str.split(",")]
70
+ ann_llm, time_llm = detect(model_llmdet, processor_llmdet, image, prompts, threshold_llm)
71
+ ann_mm, time_mm = detect(model_mm_grounding, processor_mm_grounding, image, prompts, threshold_mm)
72
+ ann_owlv2, time_owlv2 = detect(model_owlv2, processor_owlv2, image, prompts, threshold_owlv2)
73
+ ann_omdet, time_omdet = detect(model_omdet, processor_omdet, image, prompts, threshold_omdet)
74
  return (
75
  (image, ann_llm),
76
  time_llm,