Streamlit: fix image arg and indentation; integrate Imagen override rule into final decision
Browse files- streamlit_app.py +2 -2
- two_stage_inference.py +24 -2
streamlit_app.py
CHANGED
|
@@ -76,8 +76,8 @@ def main():
|
|
| 76 |
st.error("Please provide a valid image (upload or URL).")
|
| 77 |
return
|
| 78 |
|
| 79 |
-
# Show the image
|
| 80 |
-
st.image(tmp_path, caption="Input Image",
|
| 81 |
|
| 82 |
st.info("Loading models (first time may take a bit)...")
|
| 83 |
cascade = get_cascade()
|
|
|
|
| 76 |
st.error("Please provide a valid image (upload or URL).")
|
| 77 |
return
|
| 78 |
|
| 79 |
+
# Show the image (use_column_width for Streamlit 1.x)
|
| 80 |
+
st.image(tmp_path, caption="Input Image", use_column_width=True)
|
| 81 |
|
| 82 |
st.info("Loading models (first time may take a bit)...")
|
| 83 |
cascade = get_cascade()
|
two_stage_inference.py
CHANGED
|
@@ -293,16 +293,38 @@ class CascadeClassifier:
|
|
| 293 |
m_conf, m_idx, m_probs = softmax_confidence(logits)
|
| 294 |
m_label = self.multi_id2label.get(int(m_idx), str(m_idx))
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
result['multiclass'] = {
|
| 297 |
'label': m_label,
|
| 298 |
'index': int(m_idx),
|
| 299 |
'confidence': float(m_conf),
|
| 300 |
'probs': m_probs,
|
| 301 |
}
|
|
|
|
|
|
|
| 302 |
result['final'] = {
|
| 303 |
'decision': 'fake',
|
| 304 |
-
'origin':
|
| 305 |
-
'confidence': float(
|
| 306 |
}
|
| 307 |
return result
|
| 308 |
|
|
|
|
| 293 |
m_conf, m_idx, m_probs = softmax_confidence(logits)
|
| 294 |
m_label = self.multi_id2label.get(int(m_idx), str(m_idx))
|
| 295 |
|
| 296 |
+
# Apply Imagen override rule for the final decision:
|
| 297 |
+
# If top-1 is 'Imagen' and any other class prob > 1.7e-5, choose the second-highest instead.
|
| 298 |
+
override_idx = int(m_idx)
|
| 299 |
+
override_conf = float(m_conf)
|
| 300 |
+
try:
|
| 301 |
+
# Find index for label 'Imagen' (case-insensitive)
|
| 302 |
+
imagen_idx = next((i for i, v in self.multi_id2label.items() if str(v).strip().lower() == 'imagen'), None)
|
| 303 |
+
except Exception:
|
| 304 |
+
imagen_idx = None
|
| 305 |
+
if imagen_idx is not None:
|
| 306 |
+
# Sort indices by probability descending
|
| 307 |
+
order = sorted(range(len(m_probs)), key=lambda i: m_probs[i], reverse=True)
|
| 308 |
+
if order and order[0] == int(imagen_idx):
|
| 309 |
+
# Check if any non-imagen prob exceeds threshold
|
| 310 |
+
THRESH = 1.7e-5
|
| 311 |
+
any_other_above = any((i != imagen_idx) and (m_probs[i] > THRESH) for i in range(len(m_probs)))
|
| 312 |
+
if any_other_above and len(order) > 1:
|
| 313 |
+
override_idx = int(order[1])
|
| 314 |
+
override_conf = float(m_probs[override_idx])
|
| 315 |
+
|
| 316 |
result['multiclass'] = {
|
| 317 |
'label': m_label,
|
| 318 |
'index': int(m_idx),
|
| 319 |
'confidence': float(m_conf),
|
| 320 |
'probs': m_probs,
|
| 321 |
}
|
| 322 |
+
# Final decision reflects the override rule if triggered
|
| 323 |
+
final_label = self.multi_id2label.get(int(override_idx), str(override_idx))
|
| 324 |
result['final'] = {
|
| 325 |
'decision': 'fake',
|
| 326 |
+
'origin': final_label,
|
| 327 |
+
'confidence': float(override_conf),
|
| 328 |
}
|
| 329 |
return result
|
| 330 |
|