"""Gradio demo: click a patch of an image to visualise DINOv3 feature similarity.

A DINOv3 backbone is loaded once at import time (and re-loaded when the user
picks another checkpoint).  For an uploaded image the per-patch features are
extracted, and on every click the cosine similarity between the clicked patch
and all other patches is rendered as a heatmap over the image.
"""

import gc
from pathlib import Path

import gradio as gr
import matplotlib
import matplotlib.cm as cm
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageOps
from transformers import AutoImageProcessor, AutoModel

# Device configuration with memory management
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Display name -> Hugging Face checkpoint id.
MODEL_MAP = {
    "DINOv3 ViT-L/16 Satellite": "facebook/dinov3-vitl16-pretrain-sat493m",
    "DINOv3 ViT-L/16 LVD (General Web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "⚠️ DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}
DEFAULT_NAME = list(MODEL_MAP.keys())[0]

# File suffixes picked up from the working directory as example images.
EXAMPLE_SUFFIXES = {".jpg", ".png", ".webp"}

# Global model state (replaced by load_model / cleared by cleanup_memory)
processor = None
model = None


def cleanup_memory():
    """Drop the global model/processor and release cached memory.

    Sets both globals to ``None``, runs the garbage collector and, when CUDA
    is available, empties the CUDA cache and synchronizes.
    """
    global processor, model

    if model is not None:
        del model
        model = None
    if processor is not None:
        del processor
        processor = None

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def load_model(name):
    """Load the checkpoint registered under *name* into the global state.

    Args:
        name: A key of ``MODEL_MAP``.

    Returns:
        A human-readable status string: a success line with parameter count,
        dtype and device, or a failure line carrying the exception message.
        The function never raises; on failure the globals are reset to
        ``None`` via ``cleanup_memory``.
    """
    global processor, model

    try:
        # Free the previous model first so both never coexist in memory.
        cleanup_memory()

        model_id = MODEL_MAP[name]
        processor = AutoImageProcessor.from_pretrained(model_id)

        # Half precision only for the 7B checkpoint on GPU; everything else
        # stays in float32.
        if "7b" in model_id.lower() and torch.cuda.is_available():
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            dtype = torch.float32

        model = AutoModel.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
        )

        # Models placed via `device_map` expose `hf_device_map` (not
        # `device_map`), so the previous `hasattr(model, "device_map")` test
        # never matched.  Only skip the explicit move when the model really is
        # split across several devices, where `.to()` is not applicable.
        placement = getattr(model, "hf_device_map", None) or {}
        split_across_devices = len({str(dev) for dev in placement.values()}) > 1
        if DEVICE == "cuda" and not split_across_devices:
            model = model.to(DEVICE)

        model.eval()

        # Get model info
        param_count = sum(p.numel() for p in model.parameters()) / 1e9
        dtype_str = str(dtype).split(".")[-1]

        return f"✅ Loaded: {name} | {param_count:.1f}B params | {dtype_str} | {DEVICE.upper()}"

    except Exception as e:  # UI boundary: report the failure instead of crashing
        cleanup_memory()
        return f"❌ Failed to load {name}: {str(e)}"


# Initialize default model; keep the result so a failed start-up load can be
# surfaced in the status box instead of being silently dropped.
STARTUP_STATUS = load_model(DEFAULT_NAME)


def _get_colormap(name):
    """Return the Matplotlib colormap called *name*.

    Uses the ``matplotlib.colormaps`` registry when present and falls back to
    ``matplotlib.cm.get_cmap`` on Matplotlib versions that predate it.
    """
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        return registry[name]
    return cm.get_cmap(name)


@spaces.GPU(duration=60)
def _extract_grid(img):
    """Run the global model on *img* and return its patch-feature grid.

    Args:
        img: Image accepted by the global ``processor`` (callers pass an RGB
            PIL image).

    Returns:
        ``(feats, gh, gw)`` where ``feats`` is a float32 CPU tensor reshaped to
        ``(gh, gw, -1)`` and ``gh``/``gw`` are the processed image height and
        width divided by the model's patch size.

    Raises:
        gr.Error: If no model is currently loaded.
    """
    if model is None or processor is None:
        raise gr.Error(
            "No model is loaded; select a model and wait for it to finish loading."
        )

    with torch.inference_mode():
        pv = processor(images=img, return_tensors="pt").pixel_values
        # Match the model's device AND dtype: the 7B checkpoint is loaded in
        # bf16/fp16 while the processor always emits float32.
        pv = pv.to(device=model.device, dtype=model.dtype)

        out = model(pixel_values=pv)
        last = out.last_hidden_state[0].to(torch.float32)

        num_reg = getattr(model.config, "num_register_tokens", 0)
        p = model.config.patch_size
        _, _, Ht, Wt = pv.shape
        gh, gw = Ht // p, Wt // p

        # Skip the first token plus the register tokens; the remainder is
        # reshaped into the patch grid.
        feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()

    return feats, gh, gw


def _overlay(orig, heat01, alpha=0.55, box=None, colormap="turbo"):
    """Blend a colour-mapped heatmap over *orig* and optionally mark a box.

    Args:
        orig: PIL image used as the background.
        heat01: 2-D array of values in ``[0, 1]``; it is resized to the size
            of *orig* before colouring, so any grid resolution is accepted.
        alpha: Uniform opacity of the heatmap layer, ``0``..``1``.
        box: Optional ``(x0, y0, x1, y1)`` rectangle to outline.
        colormap: Name of the Matplotlib colormap to use.

    Returns:
        A new RGBA PIL image with the size of *orig*.
    """
    W, H = orig.width, orig.height

    # Upsample the heatmap to image resolution first: alpha_composite requires
    # both layers to have identical size.
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize(
        (W, H), resample=Image.LANCZOS
    )
    rgba = (_get_colormap(colormap)(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    ov = Image.fromarray(rgba).convert("RGBA")
    ov.putalpha(int(alpha * 255))

    base = orig.copy().convert("RGBA")
    out = Image.alpha_composite(base, ov)

    if box:
        draw = ImageDraw.Draw(out, "RGBA")
        # White box with a thin dark outer border so it stays visible on any
        # heatmap colour.
        draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
        draw.rectangle(
            (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
            outline=(0, 0, 0, 200),
            width=1,
        )

    return out


def prepare(img):
    """Build the per-image state used by ``click``.

    Args:
        img: PIL image or ``None``.

    Returns:
        ``None`` when *img* is ``None``; otherwise a dict with keys ``orig``
        (EXIF-transposed RGB image), ``feats``, ``gh`` and ``gw``.
    """
    if img is None:
        return None

    base = ImageOps.exif_transpose(img.convert("RGB"))
    feats, gh, gw = _extract_grid(base)
    return {"orig": base, "feats": feats, "gh": gh, "gw": gw}


def click(state, opacity, colormap, img_value, evt: gr.SelectData):
    """Render the similarity heatmap for the patch under the click.

    Args:
        state: Dict produced by ``prepare`` or ``None``.
        opacity: Heatmap opacity from the slider.
        colormap: Colormap name from the dropdown (lower-cased before lookup).
        img_value: Current value of the image component; used to build the
            state lazily when it has not been prepared yet.
        evt: Gradio select event; ``evt.index`` holds the clicked ``(x, y)``.

    Returns:
        ``(image, state, stats_markdown)``.  When there is nothing to compute
        the image is returned unchanged and the stats are ``None``.
    """
    # If state wasn't prepared, build it now
    if state is None and img_value is not None:
        state = prepare(img_value)

    if not state or evt.index is None:
        return img_value, state, None

    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]

    # Map the clicked pixel to a patch index, clamped to the grid.
    x, y = evt.index
    px_x, px_y = base.width / gw, base.height / gh
    i = min(max(int(x // px_x), 0), gw - 1)
    j = min(max(int(y // px_y), 0), gh - 1)

    # Cosine similarity between the selected patch and every patch.
    d = feats.shape[-1]
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = F.normalize(feats[j, i].reshape(1, d), dim=1)
    sims = (grid @ v.T).reshape(gh, gw).numpy()

    # Min-max normalise; the epsilon guards against a constant map.
    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)

    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))

    # Delegate to _overlay, which resizes the grid-sized heatmap to the image
    # size; compositing the raw grid-sized layer would fail on size mismatch.
    out = _overlay(
        base,
        heat01,
        alpha=opacity,
        box=box,
        colormap=(colormap or "turbo").lower(),
    )

    # Stats for info panel
    stats = (
        "📊 **Similarity Statistics**\n"
        f"- Min: {smin:.3f}\n"
        f"- Max: {smax:.3f}\n"
        f"- Range: {smax - smin:.3f}\n"
        f"- Patch: ({i}, {j})\n"
        f"- Grid: {gw}×{gh}"
    )

    return out, state, stats


def reset():
    """Reset the interface: clear the image, the state and the stats panel."""
    return None, None, None


# Build the interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
        neutral_hue="gray",
        font=gr.themes.GoogleFont("Inter"),
    ),
    css="""
    .container {max-width: 1200px; margin: auto;}
    .header {text-align: center; padding: 20px;}
    .info-box {
        background: rgba(0,0,0,0.03);
        border-radius: 8px;
        padding: 12px;
        margin: 10px 0;
        border-left: 4px solid #2563eb;
    }
    """,
) as demo:
    gr.HTML(
        """

        🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity

        Explore how DINOv3 models trained on satellite imagery understand visual patterns

        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown(
                """
                ### How it works
                1. **Select a model** - Satellite-pretrained models are optimized for aerial/satellite imagery
                2. **Upload or select an image** - Works best with satellite, aerial, or outdoor scenes
                3. **Click any region** - See how similar other patches are to your selection
                4. **Adjust visualization** - Fine-tune opacity and colormap for clarity
                """
            )
        with gr.Column(scale=2):
            gr.HTML(
                """
                💡 Model Info:
                Satellite models: Trained on 493M satellite images
                LVD model: Trained on 1.7B diverse images
                7B model: Massive capacity, slower but more nuanced
                """
            )

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_MAP.keys()),
                value=DEFAULT_NAME,
                label="🤖 Model Selection",
                info="Satellite models excel at geographic and structural patterns",
            )
            status = gr.Textbox(
                label="📡 Model Status",
                # Show the start-up failure message when the default model
                # could not be loaded, rather than claiming it is ready.
                value=f"Ready: {DEFAULT_NAME}" if model is not None else STARTUP_STATUS,
                interactive=False,
                lines=1,
            )

            with gr.Row():
                opacity = gr.Slider(
                    0.2,
                    0.9,
                    0.55,
                    step=0.05,
                    label="🎨 Heatmap Opacity",
                    info="Balance between image and similarity map",
                )
                colormap = gr.Dropdown(
                    choices=["Turbo", "Inferno", "Viridis", "Plasma", "Magma", "Jet"],
                    value="Turbo",
                    label="🌈 Colormap",
                    info="Different maps for different contrasts",
                )

            info_panel = gr.Markdown(value=None, label="Statistics", visible=True)

            with gr.Row():
                reset_btn = gr.Button("🔄 Reset", variant="secondary", scale=1)
                clear_btn = gr.ClearButton(value="🗑️ Clear All", scale=1)

        with gr.Column(scale=2):
            img = gr.Image(
                type="pil",
                label="Interactive Canvas (Click to explore)",
                interactive=True,
                height=600,
                show_download_button=True,
                show_share_button=False,
            )

    state = gr.State()

    # Examples focused on satellite-relevant imagery
    gr.Examples(
        examples=[
            [_filepath.name]
            for _filepath in sorted(Path.cwd().iterdir())
            if _filepath.suffix.lower() in EXAMPLE_SUFFIXES
        ],
        inputs=img,
        fn=prepare,
        outputs=[state],
        label="Example Images",
        examples_per_page=6,
        cache_examples=False,
        # Without caching, `fn` only runs on selection when this is set;
        # otherwise `state` would keep the features of the previous image.
        run_on_click=True,
    )

    # Event handlers
    model_choice.change(
        load_model, inputs=model_choice, outputs=status, show_progress="full"
    )

    img.upload(prepare, inputs=img, outputs=state, show_progress="minimal")

    img.select(
        click,
        inputs=[state, opacity, colormap, img],
        outputs=[img, state, info_panel],
        show_progress="minimal",
    )

    reset_btn.click(reset, outputs=[img, state, info_panel], show_progress=False)
    clear_btn.add([img, state, info_panel])

    gr.Markdown(
        """
        ---
        Performance Notes: Satellite models are optimized for geographic patterns, land use classification, and structural analysis. The 7B model provides exceptional detail but requires significant compute.

        Built with DINOv3 | Optimized for satellite and aerial imagery analysis
        """
    )

if __name__ == "__main__":
    demo.launch(share=False, show_error=True)