# get-c-image / app.py
# vinithius - "Update app.py" (commit f7d8642, verified) - 1.3 kB
import torch
from torch import nn
from transformers import AutoImageProcessor, AutoModel
import gradio as gr
import numpy as np
from PIL import Image
# Name of the pretrained backbone on the Hugging Face Hub.
MODEL_NAME = "facebook/dinov2-small"

# Width of the embedding vectors produced by the projection layer.
EMBEDDING_DIM = 512

# Fixed seed used to initialise the projection layer (see below).
PROJECTION_SEED = 0

# Load the image processor and the backbone once, at import time.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# The model is only used for inference; make eval mode explicit so layers
# such as dropout can never be active.
model.eval()

# Linear projection from the backbone's hidden size to EMBEDDING_DIM. It maps
# hidden_size -> 512 whether that is a reduction or an expansion.
#
# The layer is never trained, so its weights are whatever the initialiser
# draws. Without a seed they would differ on every process start, and so
# would every embedding returned by the API. Seeding inside fork_rng() makes
# the weights reproducible while restoring the global RNG state afterwards.
with torch.random.fork_rng():
    torch.manual_seed(PROJECTION_SEED)
    projection = nn.Linear(model.config.hidden_size, EMBEDDING_DIM)
projection.eval()
# Frozen: no gradients are ever needed for this layer.
projection.requires_grad_(False)
def get_embedding(image: Image.Image) -> list:
    """Return the projected [CLS] embedding of a single image.

    The image is preprocessed with the module-level ``processor``, passed
    through the module-level ``model``, and the hidden state at sequence
    position 0 is fed to the module-level ``projection`` layer.

    Args:
        image: The PIL image to embed.

    Returns:
        The projected embedding as a flat Python list of floats (one entry
        per output feature of ``projection``).

    Raises:
        gr.Error: If ``image`` is ``None`` (no image was supplied).
    """
    # A request without an image would otherwise fail inside the processor
    # with an opaque error; report it clearly instead.
    if image is None:
        raise gr.Error("No image provided.")

    # Preprocess the image into PyTorch tensors for the model.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: keep both the backbone and the projection out of
    # autograd so no graph is built per request.
    with torch.no_grad():
        outputs = model(**inputs)
        # last_hidden_state is (batch, seq_len, hidden); index 0 along the
        # sequence axis is used as the image embedding ([CLS] token).
        cls_embedding = outputs.last_hidden_state[:, 0]
        projected = projection(cls_embedding)

    # Drop only the batch axis (size 1 for a single image) and convert to a
    # plain list so the result is JSON-serialisable.
    return projected.squeeze(0).tolist()
# Wire get_embedding into Gradio: a PIL image comes in, the embedding goes
# out as JSON, and the API endpoint is registered under the name "embed".
_image_input = gr.Image(type="pil")
_json_output = gr.JSON()
iface = gr.Interface(
    fn=get_embedding,
    inputs=_image_input,
    outputs=_json_output,
    live=False,
    api_name="embed",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()