KomaAl commited on
Commit
d1bcd45
·
1 Parent(s): 709ee68

Streamlit: fix image arg and indentation; integrate Imagen override rule into final decision

Browse files
Files changed (2) hide show
  1. streamlit_app.py +2 -2
  2. two_stage_inference.py +24 -2
streamlit_app.py CHANGED
@@ -76,8 +76,8 @@ def main():
76
  st.error("Please provide a valid image (upload or URL).")
77
  return
78
 
79
- # Show the image
80
- st.image(tmp_path, caption="Input Image", use_container_width=True)
81
 
82
  st.info("Loading models (first time may take a bit)...")
83
  cascade = get_cascade()
 
76
  st.error("Please provide a valid image (upload or URL).")
77
  return
78
 
79
+ # Show the image (use_column_width for Streamlit 1.x)
80
+ st.image(tmp_path, caption="Input Image", use_column_width=True)
81
 
82
  st.info("Loading models (first time may take a bit)...")
83
  cascade = get_cascade()
two_stage_inference.py CHANGED
@@ -293,16 +293,38 @@ class CascadeClassifier:
293
  m_conf, m_idx, m_probs = softmax_confidence(logits)
294
  m_label = self.multi_id2label.get(int(m_idx), str(m_idx))
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  result['multiclass'] = {
297
  'label': m_label,
298
  'index': int(m_idx),
299
  'confidence': float(m_conf),
300
  'probs': m_probs,
301
  }
 
 
302
  result['final'] = {
303
  'decision': 'fake',
304
- 'origin': m_label,
305
- 'confidence': float(m_conf),
306
  }
307
  return result
308
 
 
293
  m_conf, m_idx, m_probs = softmax_confidence(logits)
294
  m_label = self.multi_id2label.get(int(m_idx), str(m_idx))
295
 
296
+ # Apply Imagen override rule for the final decision:
297
+ # If top-1 is 'Imagen' and any other class prob > 1.7e-5, choose the second-highest instead.
298
+ override_idx = int(m_idx)
299
+ override_conf = float(m_conf)
300
+ try:
301
+ # Find index for label 'Imagen' (case-insensitive)
302
+ imagen_idx = next((i for i, v in self.multi_id2label.items() if str(v).strip().lower() == 'imagen'), None)
303
+ except Exception:
304
+ imagen_idx = None
305
+ if imagen_idx is not None:
306
+ # Sort indices by probability descending
307
+ order = sorted(range(len(m_probs)), key=lambda i: m_probs[i], reverse=True)
308
+ if order and order[0] == int(imagen_idx):
309
+ # Check if any non-imagen prob exceeds threshold
310
+ THRESH = 1.7e-5
311
+ any_other_above = any((i != imagen_idx) and (m_probs[i] > THRESH) for i in range(len(m_probs)))
312
+ if any_other_above and len(order) > 1:
313
+ override_idx = int(order[1])
314
+ override_conf = float(m_probs[override_idx])
315
+
316
  result['multiclass'] = {
317
  'label': m_label,
318
  'index': int(m_idx),
319
  'confidence': float(m_conf),
320
  'probs': m_probs,
321
  }
322
+ # Final decision reflects the override rule if triggered
323
+ final_label = self.multi_id2label.get(int(override_idx), str(override_idx))
324
  result['final'] = {
325
  'decision': 'fake',
326
+ 'origin': final_label,
327
+ 'confidence': float(override_conf),
328
  }
329
  return result
330