codewithRiz committed on
Commit
498d1bb
·
verified ·
1 Parent(s): 3762343

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from IPython.display import display, JSON
2
+ import matplotlib.pyplot as plt
3
+ from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet
4
+ import numpy as np
5
+ import time
6
+ import gradio as gr
7
+ import json
8
+ import cv2
9
+
10
+
11
# --- Load SpeciesNet model ---
# Report the model variants this speciesnet build knows about, then
# instantiate the default one at import time so the app is ready to serve.
print(f"Default SpeciesNet model: {DEFAULT_MODEL}")
print(f"Supported SpeciesNet models: {SUPPORTED_MODELS}")
model = SpeciesNet(DEFAULT_MODEL)
15
+
16
+
17
+ # --- Visualization Function ---
18
def draw_predictions(image_path, predictions_dict):
    """Draw detection boxes and stacked classification labels onto an image.

    Args:
        image_path: Path to the image file on disk.
        predictions_dict: SpeciesNet output dict with a "predictions" list;
            each entry may carry "detections" (dicts with "bbox"/"conf"/
            "label") and "classifications" (parallel "classes"/"scores"
            lists). Schema assumed from this file's usage — TODO confirm
            against the speciesnet package docs.

    Returns:
        The annotated image as an RGB numpy array of shape (H, W, 3).

    Raises:
        ValueError: If the image cannot be read from image_path.
        KeyError: If a detection is missing "bbox", "conf", or "label".
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")

    img_h, img_w, _ = img.shape
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2

    for pred in predictions_dict.get("predictions", []):
        detections = pred.get("detections", [])
        classifications = pred.get("classifications", {})

        classes = classifications.get("classes", [])
        scores = classifications.get("scores", [])

        # Top classification is image-level; it is repeated on every box.
        # Fix: guard BOTH lists — the original indexed scores[0] whenever
        # classes was non-empty, raising IndexError on an empty scores list.
        top_class_name = None
        top_score = None
        if classes and scores:
            top_class_name = classes[0].split(";")[-1]  # last segment = readable species name
            top_score = scores[0]

        for det in detections:
            # Relative [x, y, w, h] bbox — presumably normalized to [0, 1];
            # verify against the speciesnet detector output spec.
            x, y, w, h = det["bbox"]
            conf = det["conf"]
            label = det["label"]

            # Relative bbox -> absolute pixel coordinates.
            x1 = int(x * img_w)
            y1 = int(y * img_h)
            x2 = int((x + w) * img_w)
            y2 = int((y + h) * img_h)

            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 3)

            # Stack the classification (if any) above the detection label.
            text_lines = []
            if top_class_name is not None:
                text_lines.append(f"{top_class_name} ({top_score:.2f})")
            text_lines.append(f"{label} ({conf:.2f})")

            # Measure each line exactly once (fix: the original re-measured
            # every line a second time inside the drawing loop).
            sizes = [
                cv2.getTextSize(line, font, font_scale, thickness)[0]
                for line in text_lines
            ]
            total_text_height = sum(th + 5 for _, th in sizes)
            max_text_width = max(tw for tw, _ in sizes)

            # Anchor the label block above the box, but push it down when the
            # box touches the top edge (fix: the original clipped the text
            # off-image even though it clamped the background rectangle).
            y_anchor = max(y1, total_text_height + 10)

            # Filled background so the black text stays readable.
            cv2.rectangle(
                img,
                (x1, max(y_anchor - total_text_height - 10, 0)),
                (x1 + max_text_width + 10, y_anchor),
                (0, 255, 0),
                -1,
            )

            # Draw bottom-up: detection label nearest the box, then the
            # classification stacked above it.
            y_text = y_anchor - 5
            for line, (_, text_h) in zip(text_lines[::-1], sizes[::-1]):
                cv2.putText(
                    img,
                    line,
                    (x1 + 5, y_text),
                    font,
                    font_scale,
                    (0, 0, 0),
                    thickness,
                    cv2.LINE_AA,
                )
                y_text -= text_h + 5

    # OpenCV decodes to BGR; Gradio/matplotlib expect RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
109
+
110
+
111
+ # --- Inference Function ---
112
def inference(image):
    """Run SpeciesNet on an uploaded image and return the annotated result.

    Args:
        image: A PIL.Image as provided by the Gradio "pil" input component.

    Returns:
        Tuple of (annotated RGB numpy image, pretty-printed JSON string of
        the raw predictions dict).
    """
    # Stdlib-only, function-scoped imports keep the fix self-contained.
    import os
    import tempfile

    # Fix: the original saved to a fixed "temp_image.jpg" in the CWD, which
    # races under concurrent Gradio requests and leaked the file. mkstemp
    # gives each call a unique path; the finally block cleans it up.
    fd, filepath = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        image.save(filepath)

        start_time = time.time()
        predictions_dict = model.predict(
            instances_dict={
                "instances": [
                    {
                        "filepath": filepath,
                        "country": "VNM",  # optional geo hint for the classifier
                    }
                ]
            }
        )
        end_time = time.time()
        print(f"Inference Time: {end_time - start_time:.2f} seconds")

        # Draw detection + classification labels stacked on each box.
        annotated_image = draw_predictions(filepath, predictions_dict)
    finally:
        # Best-effort removal; never let cleanup mask a real error.
        try:
            os.remove(filepath)
        except OSError:
            pass

    formatted_json = json.dumps(predictions_dict, indent=4)
    return annotated_image, formatted_json
136
+
137
+
138
+ # --- Gradio Interface ---
139
# --- Gradio Interface ---
# Wire the inference function into an upload-in, image+JSON-out UI.
app_description = (
    "Upload a wildlife camera image. The model detects animals and shows both "
    "the detection label (e.g., 'animal 0.97') and classification result "
    "(e.g., 'white-tailed deer 0.99') on each bounding box."
)

output_components = [
    gr.Image(label="Detection + Classification Visualization"),
    gr.JSON(label="Prediction Details"),
]

iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=output_components,
    title="🐾 SpeciesNet Wildlife Detector + Classifier",
    description=app_description,
)

iface.launch()