Spaces: Running on Zero
| import gc | |
| from pathlib import Path | |
| import gradio as gr | |
| import matplotlib.cm as cm | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageOps | |
| from transformers import AutoImageProcessor, AutoModel | |
# Target device: the GPU when torch can see one, otherwise the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# Dropdown label -> Hugging Face Hub checkpoint id.
MODEL_MAP = dict(
    [
        ("DINOv3 ViT-L/16 Satellite", "facebook/dinov3-vitl16-pretrain-sat493m"),
        (
            "DINOv3 ViT-L/16 LVD (General Web)",
            "facebook/dinov3-vitl16-pretrain-lvd1689m",
        ),
        ("⚠️ DINOv3 ViT-7B/16 Satellite", "facebook/dinov3-vit7b16-pretrain-sat493m"),
    ]
)

# The first registered label is the one loaded at start-up.
DEFAULT_NAME = next(iter(MODEL_MAP))

# Module-level handles for the currently loaded image processor and model.
# Both stay None until load_model() succeeds; cleanup_memory() resets them.
processor = model = None
def cleanup_memory():
    """Release the loaded model and processor and reclaim memory.

    Rebinds the module-level ``model`` and ``processor`` to ``None``, forces a
    garbage-collection pass and, when CUDA is available, empties the CUDA
    cache and synchronizes the device.
    """
    global processor, model

    # Rebinding the global drops this module's reference to the old object.
    if model is not None:
        model = None
    if processor is not None:
        processor = None

    gc.collect()

    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def load_model(name):
    """Load the checkpoint registered under ``name`` and report the outcome.

    Any previously loaded model/processor is released first. The new
    processor and model are stored in the module-level ``processor`` and
    ``model`` globals.

    Args:
        name: A key of ``MODEL_MAP``.

    Returns:
        A status string: ``"✅ Loaded: ..."`` with parameter count, dtype and
        device on success, or ``"❌ Failed to load ..."`` with the error text
        on failure. Exceptions are not propagated; on failure both globals
        are left as ``None``.
    """
    global processor, model
    try:
        # Free the previous model before allocating the next one.
        cleanup_memory()
        model_id = MODEL_MAP[name]
        processor = AutoImageProcessor.from_pretrained(model_id)

        use_cuda = torch.cuda.is_available()
        # Only the 7B checkpoint is loaded in half precision (bfloat16 when
        # the GPU supports it, float16 otherwise); everything else is float32.
        if "7b" in model_id.lower() and use_cuda:
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            dtype = torch.float32

        model = AutoModel.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto" if use_cuda else None,
        )
        # Transformers records a device_map placement in `hf_device_map`
        # (there is no `device_map` attribute). A model that was already
        # dispatched must not be moved again with .to(); only move models
        # that were loaded without a device map.
        if DEVICE == "cuda" and getattr(model, "hf_device_map", None) is None:
            model = model.to(DEVICE)
        model.eval()

        # Summarize what was loaded for the status box.
        param_count = sum(p.numel() for p in model.parameters()) / 1e9
        dtype_str = str(dtype).split(".")[-1]
        return f"✅ Loaded: {name} | {param_count:.1f}B params | {dtype_str} | {DEVICE.upper()}"
    except Exception as e:
        # UI boundary: report the failure as text instead of raising, and do
        # not leave a half-initialized processor/model pair behind.
        cleanup_memory()
        return f"❌ Failed to load {name}: {e}"
# Initialize default model eagerly at import time. load_model() returns a
# status string, which is not used here.
load_model(DEFAULT_NAME)
def _extract_grid(img):
    """Run the loaded model on ``img`` and return its patch features as a grid.

    Args:
        img: An image accepted by the loaded image processor (callers in this
            file pass an RGB PIL image).

    Returns:
        A tuple ``(feats, gh, gw)``: ``feats`` is a float32 CPU tensor of
        shape ``(gh, gw, hidden)``, and ``gh``/``gw`` are the number of patch
        rows/columns of the preprocessed input.

    Raises:
        RuntimeError: If no model is currently loaded (for example after a
            failed ``load_model`` call, which leaves both globals ``None``).
    """
    if processor is None or model is None:
        raise RuntimeError(
            "No model is loaded; select a model and wait for it to finish loading."
        )

    with torch.inference_mode():
        pv = processor(images=img, return_tensors="pt").pixel_values
        # Follow the model's own placement and precision rather than the
        # global DEVICE, so the input matches a model loaded with
        # device_map="auto" or in half precision.
        pv = pv.to(device=model.device, dtype=model.dtype)
        out = model(pixel_values=pv)
        # First (only) batch item, promoted to float32 for the later
        # similarity math on CPU.
        last = out.last_hidden_state[0].to(torch.float32)

        num_reg = getattr(model.config, "num_register_tokens", 0)
        p = model.config.patch_size
        # Grid size comes from the preprocessed tensor, not the original image.
        gh, gw = pv.shape[-2] // p, pv.shape[-1] // p
        # Skip the leading 1 + num_reg tokens (presumably the CLS token and
        # the register tokens); the remainder is one token per patch.
        feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()
    return feats, gh, gw
def _overlay(orig, heat01, alpha=0.55, box=None, colormap="turbo"):
    """Blend a colour-mapped heatmap over ``orig`` and optionally outline a box.

    Args:
        orig: PIL image used as the background.
        heat01: 2-D numpy array of values in [0, 1]; it is scaled to 8 bits
            and resized to the size of ``orig``.
        alpha: Opacity of the heatmap layer in [0, 1].
        box: Optional ``(x0, y0, x1, y1)`` rectangle in pixel coordinates of
            ``orig`` to outline; nothing is drawn when falsy.
        colormap: Name of a Matplotlib colormap (default ``"turbo"``).

    Returns:
        A new RGBA PIL image the same size as ``orig``.
    """
    import matplotlib
    from PIL import ImageDraw

    # Upsample the low-resolution heatmap to the image size before colouring.
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize(
        orig.size, resample=Image.LANCZOS
    )

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; prefer the
    # colormap registry and fall back only on versions that predate it.
    registry = getattr(matplotlib, "colormaps", None)
    cmap = registry[colormap] if registry is not None else cm.get_cmap(colormap)

    rgba = (cmap(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    ov = Image.fromarray(rgba, "RGBA")
    ov.putalpha(int(alpha * 255))

    base = orig.copy().convert("RGBA")
    out = Image.alpha_composite(base, ov)

    if box:
        draw = ImageDraw.Draw(out, "RGBA")
        # White 3px outline plus a 1px dark outline just outside it, so the
        # box stays visible on both light and dark backgrounds.
        draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
        draw.rectangle(
            (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
            outline=(0, 0, 0, 200),
            width=1,
        )
    return out
def prepare(img):
    """Build the per-image state used by the click handler.

    Returns ``None`` when ``img`` is ``None``. Otherwise the image is
    converted to RGB, EXIF-transposed, and passed through ``_extract_grid``;
    the result is a dict with keys ``"orig"`` (the prepared image),
    ``"feats"`` (the feature grid), ``"gh"`` and ``"gw"`` (grid rows/columns).
    """
    if img is None:
        return None

    rgb = img.convert("RGB")
    upright = ImageOps.exif_transpose(rgb)
    grid, rows, cols = _extract_grid(upright)
    return dict(orig=upright, feats=grid, gh=rows, gw=cols)
def click(state, opacity, colormap, img_value, evt: gr.SelectData):
    """Render the patch-similarity heatmap for the clicked location.

    Args:
        state: Dict produced by ``prepare`` (keys ``orig``, ``feats``, ``gh``,
            ``gw``) or ``None`` if the image has not been prepared yet.
        opacity: Heatmap opacity in [0, 1].
        colormap: Colormap label; it is lower-cased and looked up in
            Matplotlib.
        img_value: Current value of the image component; used to build
            ``state`` lazily and returned unchanged when nothing can be drawn.
        evt: Gradio select event; ``evt.index`` is the clicked ``(x, y)``.

    Returns:
        ``(image, state, stats)``: the overlay image, the (possibly newly
        built) state, and a Markdown statistics string. When there is no
        state or no click position, ``(img_value, state, None)``.
    """
    import matplotlib
    from PIL import ImageDraw

    # If state wasn't prepared, build it now
    if state is None and img_value is not None:
        state = prepare(img_value)
    if not state or evt.index is None:
        return img_value, state, None

    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]

    # Map the clicked pixel to a patch index, clamped to the grid.
    x, y = evt.index
    px_x, px_y = base.width / gw, base.height / gh
    i = min(int(x // px_x), gw - 1)
    j = min(int(y // px_y), gh - 1)

    # Cosine similarity between the clicked patch and every patch.
    d = feats.shape[-1]
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = F.normalize(feats[j, i].reshape(1, d), dim=1)
    sims = (grid @ v.T).reshape(gh, gw).numpy()

    # Min-max normalize; the epsilon avoids 0/0 when all values are equal.
    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)

    # The heatmap has one value per patch (gh x gw). It must be upsampled to
    # the image size first: Image.alpha_composite rejects layers whose size
    # differs from the base image.
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize(
        base.size, resample=Image.LANCZOS
    )

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; prefer the
    # colormap registry and fall back only on versions that predate it.
    cmap_name = colormap.lower()
    registry = getattr(matplotlib, "colormaps", None)
    cm_func = registry[cmap_name] if registry is not None else cm.get_cmap(cmap_name)

    rgba = (cm_func(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    ov = Image.fromarray(rgba, "RGBA")
    ov.putalpha(int(opacity * 255))
    base_rgba = base.copy().convert("RGBA")
    out = Image.alpha_composite(base_rgba, ov)

    # Outline the selected patch: white 3px box with a 1px dark rim.
    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
    draw = ImageDraw.Draw(out, "RGBA")
    draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
    draw.rectangle(
        (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
        outline=(0, 0, 0, 200),
        width=1,
    )

    # Stats for info panel
    stats = f"""📊 **Similarity Statistics**
- Min: {smin:.3f}
- Max: {smax:.3f}
- Range: {smax - smin:.3f}
- Patch: ({i}, {j})
- Grid: {gw}×{gh}"""
    return out, state, stats
def reset():
    """Return three ``None`` values, one per output wired to the reset button
    (image, state, statistics panel)."""
    cleared_image = cleared_state = cleared_stats = None
    return cleared_image, cleared_state, cleared_stats
# Build the interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
        neutral_hue="gray",
        font=gr.themes.GoogleFont("Inter"),
    ),
    css="""
    .container {max-width: 1200px; margin: auto;}
    .header {text-align: center; padding: 20px;}
    .info-box {
        background: rgba(0,0,0,0.03);
        border-radius: 8px;
        padding: 12px;
        margin: 10px 0;
        border-left: 4px solid #2563eb;
    }
    """,
) as demo:
    # Page header.
    gr.HTML(
        """
        <div class="header">
        <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
        <p style="font-size: 1.1em; color: #666;">
        Explore how DINOv3 models trained on satellite imagery understand visual patterns
        </p>
        </div>
        """
    )

    # Usage instructions (left) and model summary (right).
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown(
                """
                ### How it works
                1. **Select a model** - Satellite-pretrained models are optimized for aerial/satellite imagery
                2. **Upload or select an image** - Works best with satellite, aerial, or outdoor scenes
                3. **Click any region** - See how similar other patches are to your selection
                4. **Adjust visualization** - Fine-tune opacity and colormap for clarity
                """
            )
        with gr.Column(scale=2):
            gr.HTML(
                """
                <div class="info-box">
                <b>💡 Model Info:</b><br>
                • <b>Satellite models</b>: Trained on 493M satellite images<br>
                • <b>LVD model</b>: Trained on 1.7B diverse images<br>
                • <b>7B model</b>: Massive capacity, slower but more nuanced
                </div>
                """
            )

    # Controls (left column) and the clickable image (right column).
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_MAP.keys()),
                value=DEFAULT_NAME,
                label="🤖 Model Selection",
                info="Satellite models excel at geographic and structural patterns",
            )
            status = gr.Textbox(
                label="📡 Model Status",
                value=f"Ready: {DEFAULT_NAME}",
                interactive=False,
                lines=1,
            )
            with gr.Row():
                opacity = gr.Slider(
                    0.2,
                    0.9,
                    0.55,
                    step=0.05,
                    label="🎨 Heatmap Opacity",
                    info="Balance between image and similarity map",
                )
                colormap = gr.Dropdown(
                    choices=["Turbo", "Inferno", "Viridis", "Plasma", "Magma", "Jet"],
                    value="Turbo",
                    label="🌈 Colormap",
                    info="Different maps for different contrasts",
                )
            info_panel = gr.Markdown(value=None, label="Statistics", visible=True)
            with gr.Row():
                reset_btn = gr.Button("🔄 Reset", variant="secondary", scale=1)
                clear_btn = gr.ClearButton(value="🗑️ Clear All", scale=1)
        with gr.Column(scale=2):
            img = gr.Image(
                type="pil",
                label="Interactive Canvas (Click to explore)",
                interactive=True,
                height=600,
                show_download_button=True,
                show_share_button=False,
            )

    # Per-session cache of the prepared image and its feature grid.
    state = gr.State()

    # Examples focused on satellite-relevant imagery: every .jpg/.png/.webp
    # file found in the current working directory.
    gr.Examples(
        examples=[
            [_filepath.name]
            for _filepath in Path.cwd().iterdir()
            if _filepath.suffix.lower() in [".jpg", ".png", ".webp"]
        ],
        inputs=img,
        fn=prepare,
        outputs=[state],
        label="Example Images",
        examples_per_page=6,
        cache_examples=False,
        # Without caching, clicking an example only fills `img`; `prepare`
        # is not called unless run_on_click is set. Running it keeps `state`
        # in sync with the selected example instead of leaving the features
        # of a previously uploaded image in place.
        run_on_click=True,
    )

    # Event handlers
    model_choice.change(
        load_model, inputs=model_choice, outputs=status, show_progress="full"
    )
    img.upload(prepare, inputs=img, outputs=state, show_progress="minimal")
    img.select(
        click,
        inputs=[state, opacity, colormap, img],
        outputs=[img, state, info_panel],
        show_progress="minimal",
    )
    reset_btn.click(reset, outputs=[img, state, info_panel], show_progress=False)
    clear_btn.add([img, state, info_panel])

    # Footer.
    gr.Markdown(
        """
        ---
        <div style="text-align: center; color: #666; font-size: 0.9em;">
        <b>Performance Notes:</b> Satellite models are optimized for geographic patterns, land use classification,
        and structural analysis. The 7B model provides exceptional detail but requires significant compute.
        <br><br>
        Built with DINOv3 | Optimized for satellite and aerial imagery analysis
        </div>
        """
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {"share": False, "show_error": True}
    demo.launch(**launch_options)