"""Gradio app: identify individual jaguars with a DINOv2 backbone + linear classifier."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr
from transformers import AutoModel

# === Path to model file (in same folder) ===
path_to_model = "dino2_classifier_cropped_body_v1.pth"


# === Padding function ===
def pad_to_square(image, fill=0):
    """Pad a PIL image to a square canvas, centering the original content.

    Args:
        image: PIL.Image to pad.
        fill: pixel value used for the padding border (default black).

    Returns:
        A new PIL.Image whose width == height == max(original w, h).
    """
    w, h = image.size
    max_dim = max(w, h)
    padded = Image.new(image.mode, (max_dim, max_dim), fill)
    # Paste centered so the subject stays in the middle of the square.
    padded.paste(image, ((max_dim - w) // 2, (max_dim - h) // 2))
    return padded


# === Transform setup ===
def setup_transform(use_padding=True, use_augmentation=False):
    """Build the image preprocessing pipeline.

    Args:
        use_padding: if True, pad to square then resize to 224 (no cropping);
            otherwise resize to 256 and center-crop to 224.
        use_augmentation: if True, insert light training-time augmentations
            (rotation, color jitter, affine shift, random crop, blur).

    Returns:
        A torchvision `transforms.Compose` producing a normalized tensor
        (ImageNet mean/std, as expected by the DINOv2 backbone).
    """
    base_transforms = []
    if use_padding:
        # `pad_to_square` is a plain callable; Compose accepts it directly.
        base_transforms.append(pad_to_square)
        base_transforms.append(transforms.Resize(224))
    else:
        base_transforms.extend([
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ])

    augmentation_transforms = []
    if use_augmentation:
        augmentation_transforms.extend([
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),
        ])

    final_transforms = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(base_transforms + augmentation_transforms + final_transforms)


# === Custom DINOv2 classifier ===
class DINOv2ArcFace(nn.Module):
    """DINOv2 backbone with a task head selected by `usage`.

    Modes:
        'finetune'   - trainable backbone + embedding projection (training only).
        'classifier' - frozen backbone + linear classification head.
        'embeddings' - L2-normalized embedding output.

    `margin` and `scale` are ArcFace hyperparameters kept for checkpoint /
    training-code compatibility; they are not used in the inference paths here.
    """

    def __init__(self, usage='classifier', num_classes=33, embedding_dim=512,
                 margin=0.5, scale=64.0):
        super().__init__()
        self.usage = usage
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.margin = margin
        self.scale = scale
        self.dropout = nn.Dropout(p=0.5)
        self.backbone = AutoModel.from_pretrained("facebook/dinov2-base")
        if self.usage == 'finetune':
            self.embedding = nn.Linear(self.backbone.config.hidden_size, self.embedding_dim)
        elif self.usage == 'classifier':
            # Backbone is frozen: only the linear head was trained.
            self.backbone.requires_grad_(False)
            in_features = self.backbone.config.hidden_size
            self.classifier = nn.Linear(in_features, self.num_classes)
        elif self.usage == 'embeddings':
            self.embedding = nn.Linear(self.backbone.config.hidden_size, self.embedding_dim)

    def forward(self, x, labels=None):
        """Run inference; `labels` is accepted for training-API compatibility but unused here.

        Returns logits in 'classifier' mode, L2-normalized embeddings in
        'embeddings' mode; raises ValueError for any other mode.
        """
        features = self.backbone(x).last_hidden_state[:, 0, :]  # CLS token
        if self.usage == 'classifier':
            features = self.dropout(features)
            logits = self.classifier(features)
            return logits
        elif self.usage == 'embeddings':
            embeddings = F.normalize(self.embedding(features), p=2, dim=1)
            return embeddings
        else:
            raise ValueError("Use mode 'classifier' or 'embeddings' for inference")


# === Load model ===
NUM_CLASSES = 33
model = DINOv2ArcFace(usage="classifier", num_classes=NUM_CLASSES)
# weights_only=True: load tensors only, never execute pickled code from the file.
model.load_state_dict(torch.load(path_to_model, map_location="cpu", weights_only=True))
model.eval()

# === Class mapping ===
class_names = {
    0: 'Abril', 1: 'Akaloi', 2: 'Alira', 3: 'Apeiara', 4: 'Ariely',
    5: 'Bagua', 6: 'Benita', 7: 'Bernard', 8: 'Bororo', 9: 'Estella',
    10: 'Guaraci', 11: 'Ipepo', 12: 'Jaju', 13: 'Kamaikua', 14: 'Kasimir',
    15: 'Katniss', 16: 'Kwang', 17: 'Lua', 18: 'Madalena', 19: 'Marcela',
    20: 'Medrosa', 21: 'Ousado', 22: 'Overa', 23: 'Oxum', 24: 'Patricia',
    25: 'Pixana', 26: 'Pollyanna', 27: 'Pyte', 28: 'Saseka', 29: 'Solar',
    30: 'Ti', 31: 'Tomas', 32: 'unknown',
}

# === Apply your transform ===
transform = setup_transform(use_padding=True, use_augmentation=False)


# === Gradio prediction function ===
def predict(image):
    """Classify a PIL image and return {individual name: probability} for all classes."""
    image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(img_tensor)
        probs = F.softmax(logits[0], dim=0)
    return {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}


# === Gradio UI ===
if __name__ == "__main__":
    gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(num_top_classes=5),
        title="🐆 Jaguar Identifier (DINOv2 + ArcFace)",
        description="Upload an image of a jaguar. The model will classify it among 33 known individuals.",
    ).launch()