Spaces:
Running
on
Zero
Running
on
Zero
| import gc | |
| from pathlib import Path | |
| import gradio as gr | |
| import matplotlib.cm as cm | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageOps | |
| from transformers import AutoImageProcessor, AutoModel | |
# Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dropdown label -> Hugging Face Hub checkpoint id.
MODEL_MAP = {
    "DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
    "DINOv3 ViT-L/16 LVD (1.7B web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}

# The first registered entry is the one loaded at start-up.
DEFAULT_NAME = next(iter(MODEL_MAP))

# Module-wide handles to the active image processor and model.
# Both stay None until a checkpoint has been loaded.
processor = model = None
def cleanup_memory():
    """Drop the active model and processor and reclaim the memory they held.

    Both module-level handles are reset to ``None``, a garbage-collection
    pass is forced, and, when CUDA is available, PyTorch's unused cached GPU
    memory is released.
    """
    global processor, model

    # Rebinding to None removes the module-level references, which is what
    # lets the collector below actually free the old objects.
    model = None
    processor = None

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def load_model(name):
    """Load the checkpoint registered under ``name`` and make it the active model.

    The previously active model and processor are released before the new
    weights are loaded, so only one checkpoint is held at a time.

    Args:
        name: A key of ``MODEL_MAP`` (the label shown in the model dropdown).

    Returns:
        A one-line status string with the label and the parameter count in
        billions.

    Raises:
        KeyError: If ``name`` is not a key of ``MODEL_MAP``. The currently
            loaded model is left untouched in that case.
    """
    global processor, model

    # Resolve the checkpoint id first: an unknown label must fail before the
    # working model is torn down.
    model_id = MODEL_MAP[name]

    cleanup_memory()

    # Load into locals and publish both handles together, so a failure part
    # way through leaves the globals consistently empty rather than pairing a
    # new processor with no model.
    new_processor = AutoImageProcessor.from_pretrained(model_id)
    new_model = AutoModel.from_pretrained(
        model_id,
        torch_dtype="auto",  # keep the dtype the checkpoint was saved with
    ).eval()
    processor, model = new_processor, new_model

    param_count = sum(p.numel() for p in model.parameters()) / 1e9
    return f"Loaded: {name} | {param_count:.1f}B params | Ready"


# Load the default checkpoint at import time so the first request finds a
# model ready.
load_model(DEFAULT_NAME)
def _extract_grid(img):
    """Run the active model on ``img`` and return its patch features as a grid.

    The model is moved to ``DEVICE`` only for the duration of the call and is
    always moved back to the CPU afterwards, even when inference fails.

    Args:
        img: The image handed to the active processor (``prepare`` passes an
            RGB PIL image).

    Returns:
        A tuple ``(feats, gh, gw)`` where ``feats`` is a float32 CPU tensor
        reshaped to ``(gh, gw, hidden_size)`` and ``gh`` / ``gw`` are the
        number of patch rows / columns of the processed image.

    Raises:
        RuntimeError: If no model or processor is currently loaded.
    """
    global model

    if model is None or processor is None:
        raise RuntimeError("No model is loaded; select a model before running inference.")

    # NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    # decorator is visible here - confirm whether this function needs one on
    # ZeroGPU hardware.
    with torch.inference_mode():
        try:
            # Use the configured device instead of assuming CUDA exists.
            model = model.to(DEVICE)

            # Preprocess the image and place it on the same device as the model.
            pv = processor(images=img, return_tensors="pt").pixel_values.to(model.device)

            out = model(pixel_values=pv)
            last = out.last_hidden_state[0].to(torch.float32)

            # Skip the leading token (presumably CLS) and any register tokens
            # so that only the per-patch tokens remain.
            num_reg = getattr(model.config, "num_register_tokens", 0)
            p = model.config.patch_size
            _, _, Ht, Wt = pv.shape
            gh, gw = Ht // p, Wt // p
            feats = last[1 + num_reg:, :].reshape(gh, gw, -1).cpu()
        finally:
            # Always hand the accelerator back, including on failure.
            model = model.cpu()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return feats, gh, gw
def _overlay(orig, heat01, alpha=0.55, box=None):
    """Blend a colour-mapped heatmap over ``orig`` and optionally outline a box.

    Args:
        orig: PIL image used as the background.
        heat01: 2-D array of heat values, expected in ``[0, 1]``; values
            outside that range are clipped.
        alpha: Heatmap opacity; clamped to ``[0, 1]``.
        box: Optional ``(x0, y0, x1, y1)`` pixel rectangle to outline.

    Returns:
        An RGBA PIL image with the same size as ``orig``.
    """
    W, H = orig.width, orig.height

    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    heat_u8 = (np.clip(heat01, 0.0, 1.0) * 255).astype(np.uint8)
    heat = Image.fromarray(heat_u8).resize((W, H))

    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; prefer the
    # colormap registry and fall back only on releases that predate it.
    try:
        from matplotlib import colormaps

        cmap = colormaps["turbo"]
    except ImportError:
        cmap = cm.get_cmap("turbo")

    # Turbo colormap - better for satellite imagery.
    rgba = (cmap(np.asarray(heat) / 255.0) * 255).astype(np.uint8)

    # An (H, W, 4) uint8 array is read as RGBA without passing the
    # deprecated `mode` argument.
    ov = Image.fromarray(rgba)
    ov.putalpha(int(min(max(alpha, 0.0), 1.0) * 255))

    base = orig.copy().convert("RGBA")
    out = Image.alpha_composite(base, ov)

    if box:
        from PIL import ImageDraw

        draw = ImageDraw.Draw(out, "RGBA")
        # White frame with a thin dark rim so the box stays visible on both
        # light and dark regions of the heatmap.
        draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
        draw.rectangle(
            (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
            outline=(0, 0, 0, 200),
            width=1,
        )
    return out
def prepare(img):
    """Build the cached state for ``img``: the upright RGB image and its features.

    Returns ``None`` when no image is given; otherwise a dict with the keys
    ``"orig"``, ``"feats"``, ``"gh"`` and ``"gw"``.
    """
    if img is None:
        return None

    rgb = img.convert("RGB")
    # Apply the EXIF orientation tag, if any.
    base = ImageOps.exif_transpose(rgb)

    grid_feats, grid_h, grid_w = _extract_grid(base)
    return dict(orig=base, feats=grid_feats, gh=grid_h, gw=grid_w)
def click(state, opacity, img_value, evt: gr.SelectData):
    """Show how similar every patch is to the patch under the click.

    Args:
        state: Cached dict produced by ``prepare`` (or ``None``).
        opacity: Heatmap opacity forwarded to ``_overlay``.
        img_value: The image currently shown in the image component.
        evt: Gradio selection event; ``evt.index`` holds the clicked
            ``(x, y)`` pixel position.

    Returns:
        ``(image, state)``: the heatmap overlay (or ``img_value`` unchanged
        when there is nothing to compute) and the possibly rebuilt state.
    """
    # Rebuild the cache when it is missing (e.g. an example was selected) or
    # when it clearly belongs to another image: the overlay written back to
    # the component always has the cached original's size, so a size mismatch
    # means a new image was loaded without refreshing the state.
    if img_value is not None:
        stale = bool(state) and img_value.size != state["orig"].size
        if state is None or stale:
            state = prepare(img_value)

    if not state or evt.index is None:
        # Just show whatever is currently in the image component
        return img_value, state

    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]

    # Map the clicked pixel to its patch cell, clamped to the grid.
    x, y = evt.index
    px_x, px_y = base.width / gw, base.height / gh
    i = min(max(int(x // px_x), 0), gw - 1)
    j = min(max(int(y // px_y), 0), gh - 1)

    # Cosine similarity between the clicked patch and every patch. The grid
    # is L2-normalised once and the query row is taken from it.
    d = feats.shape[-1]
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = grid[j * gw + i].reshape(1, d)
    sims = (grid @ v.T).reshape(gh, gw).numpy()

    # Min-max rescale to [0, 1]; the epsilon guards a constant map.
    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)

    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
    overlay = _overlay(base, heat01, alpha=opacity, box=box)
    return overlay, state
def reset():
    """Return cleared values for the image component and the cached state."""
    cleared_image, cleared_state = None, None
    return cleared_image, cleared_state
def _refresh_state(cached):
    """Recompute the cached features of the current original image.

    Used after a model switch so later clicks compare features produced by
    the newly loaded model. Returns ``None`` when nothing is cached.
    """
    if not cached:
        return None
    return prepare(cached["orig"])


with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
    .container {max-width: 1200px; margin: auto;}
    .header {text-align: center; padding: 20px;}
    .info-box {
    background: rgba(0,0,0,0.03);
    border-radius: 8px;
    padding: 12px;
    margin: 10px 0;
    border-left: 4px solid #2563eb;
    }
    """,
) as demo:
    gr.HTML(
        """
        <div class="header">
        <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
        <p style="font-size: 1.1em; color: #666;">
        Click any region to visualize feature similarities across the image
        </p>
        </div>
        """
    )
    with gr.Row():
        # Left column: controls.
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_MAP.keys()),
                value=DEFAULT_NAME,
                label="Model Selection",
                info="Select a model (size/pretraining dataset)",
            )
            status = gr.Textbox(
                label="Model Status",
                value=f"Loaded: {DEFAULT_NAME}",
                interactive=False,
                lines=1,
            )
            opacity = gr.Slider(
                0.0,
                1.0,
                0.55,
                step=0.05,
                label="Heatmap Opacity",
                info="Balance between image and similarity map",
            )
            with gr.Row():
                reset_btn = gr.Button("Reset", variant="secondary", scale=1)
                clear_btn = gr.ClearButton(value="Clear All", scale=1)
        # Right column: the clickable image.
        with gr.Column(scale=2):
            img = gr.Image(
                type="pil",
                label="Interactive Canvas (Click to explore)",
                interactive=True,
                height=600,
                show_download_button=True,
                show_share_button=False,
            )

    # Per-session cache: original image, its feature grid and grid size.
    state = gr.State()

    # After a successful model switch, recompute the cached features so they
    # come from the new model rather than the previous one.
    model_choice.change(
        load_model, inputs=model_choice, outputs=status, show_progress="full"
    ).success(_refresh_state, inputs=state, outputs=state)

    img.upload(prepare, inputs=img, outputs=state)
    img.select(
        click,
        inputs=[state, opacity, img],
        outputs=[img, state],
        show_progress="minimal",
    )
    reset_btn.click(reset, outputs=[img, state])
    clear_btn.add([img, state])

    # Examples from current directory. Sorted so the gallery order is stable,
    # and restricted to regular files.
    example_files = sorted(
        f.name
        for f in Path.cwd().iterdir()
        if f.is_file() and f.suffix.lower() in {".jpg", ".jpeg", ".png", ".webp"}
    )
    if example_files:
        gr.Examples(
            examples=[[f] for f in example_files],
            inputs=img,
            fn=prepare,
            outputs=[state],
            label="Example Images",
            examples_per_page=4,
            cache_examples=False,
            # Without caching, `fn` only runs on selection when this is set;
            # otherwise the state keeps the previous image's features.
            run_on_click=True,
        )

    gr.Markdown(
        """
        ---
        <div style="text-align: center; color: #666; font-size: 0.9em;">
        <b>Performance Notes:</b> Satellite models are optimized for geographic patterns, land use classification,
        and structural analysis. The 7B model provides exceptional detail but requires significant compute.
        <br><br>
        Built with DINOv3 | Optimized for satellite and aerial imagery analysis
        </div>
        """
    )

if __name__ == "__main__":
    demo.launch(share=False, debug=True)