import io
import os
import tempfile
from typing import Optional

import requests
import streamlit as st
from PIL import Image

from two_stage_inference import CascadeClassifier

# Fixed configuration: model repositories and inference options.
DEFAULT_BINARY_DIR = "KomaAl/level_1_binary_model"
DEFAULT_MULTICLASS_DIR = "KomaAl/level_2_multiclass_model"
DEFAULT_DEVICE = None  # let CascadeClassifier auto-select
DEFAULT_AMP = True


@st.cache_resource
def get_cascade() -> CascadeClassifier:
    """Build the two-stage cascade classifier.

    Wrapped in ``st.cache_resource`` so the (expensive) model load happens
    once per Streamlit server process, not once per rerun.
    """
    return CascadeClassifier(
        bin_dir=DEFAULT_BINARY_DIR,
        multi_dir=DEFAULT_MULTICLASS_DIR,
        device=DEFAULT_DEVICE,
        amp=DEFAULT_AMP,
    )


def _image_to_temp_png(img: Image.Image) -> str:
    """Persist a PIL image to a temporary PNG file and return its path.

    The NamedTemporaryFile handle is closed *before* PIL writes to the
    path: leaving it open leaked a file descriptor per call and fails on
    Windows, where an open file cannot be reopened by name. The caller is
    responsible for deleting the file (``delete=False``).
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    tmp.close()  # release the handle so img.save can reopen the path (Windows-safe)
    img.save(tmp.name)
    return tmp.name


def download_image(url: str) -> Optional[str]:
    """Download an image from *url* into a temporary PNG file.

    Returns the temp file path, or None when the URL is unreachable, the
    response content-type is not an image, or the payload cannot be
    decoded. Best-effort by design: any failure maps to None so the UI
    shows a friendly error instead of a traceback.
    """
    try:
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        ct = resp.headers.get("content-type", "").lower()
        if "image" not in ct:
            return None
        img = Image.open(io.BytesIO(resp.content)).convert("RGB")
        return _image_to_temp_png(img)
    except Exception:
        return None


def save_uploaded_to_temp(uploaded_file) -> Optional[str]:
    """Decode a Streamlit upload into a temporary PNG; None on failure."""
    try:
        img = Image.open(uploaded_file).convert("RGB")
        return _image_to_temp_png(img)
    except Exception:
        return None


def main() -> None:
    """Streamlit entry point: collect an image, run the cascade, render results."""
    st.set_page_config(page_title="AI Image Detection", page_icon="🖼️", layout="centered")
    st.title("Two-Stage AI Image Detection")
    st.write(
        "Upload an image or paste an image URL. We'll predict if it's real or fake; if fake, we'll also predict its origin."
    )

    col1, col2 = st.columns(2)
    with col1:
        uploaded = st.file_uploader("Upload image", type=["png", "jpg", "jpeg", "webp"])
    with col2:
        url = st.text_input("...or paste image URL")

    run = st.button("Run Inference")
    if not run:
        return

    # An uploaded file takes precedence over a URL when both are given.
    tmp_path = None
    if uploaded is not None:
        tmp_path = save_uploaded_to_temp(uploaded)
    elif url:
        tmp_path = download_image(url)

    if not tmp_path or not os.path.exists(tmp_path):
        st.error("Please provide a valid image (upload or URL).")
        return

    # Show the image (use_column_width for Streamlit 1.x)
    st.image(tmp_path, caption="Input Image", use_column_width=True)

    st.info("Loading models (first time may take a bit)...")
    cascade = get_cascade()

    try:
        result = cascade.predict(tmp_path)
    except Exception as e:
        st.error(f"Error during inference: {e}")
        return
    finally:
        # Always remove the temp file, whether or not inference succeeded.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

    binary = result.get("binary", {})
    multi = result.get("multiclass", {})
    final = result.get("final", {})

    st.subheader("Inference Summary")
    st.write("Stage 1 (Real/Fake):")
    st.write(f"- Label: {binary.get('label')} | Confidence: {binary.get('confidence', 0):.4f}")
    if multi:
        st.write("\nStage 2 (Origin):")
        st.write(f"- Label: {multi.get('label')} | Confidence: {multi.get('confidence', 0):.4f}")
    st.write("\nFinal Decision:")
    st.json(final)


if __name__ == "__main__":
    main()