import os

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Load the MetaCLIP 2 checkpoint in bfloat16 with PyTorch's SDPA attention backend.
model = AutoModel.from_pretrained(
    "facebook/metaclip-2-mt5-worldwide-s16",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")


def postprocess_metaclip(probs, labels):
    # Map each label to its probability so gr.Label can render the ranking.
    return {labels[i]: probs[0][i].item() for i in range(len(labels))}


def metaclip_detector(image, texts):
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # Cast the pixel values to the model's dtype; the processor returns float32
    # tensors, which would otherwise clash with the bfloat16 weights.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    return probs


def infer(image, candidate_labels):
    # Split the comma-separated label string and trim surrounding whitespace.
    candidate_labels = [label.strip() for label in candidate_labels.split(",")]
    probs = metaclip_detector(image, candidate_labels)
    return postprocess_metaclip(probs, labels=candidate_labels)


with gr.Blocks() as demo:
    gr.Markdown("# MetaCLIP 2 Zero-Shot Classification")
    gr.Markdown(
        "Test the performance of MetaCLIP 2 on zero-shot classification in this Space :point_down:"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels (comma separated)")
            run_button = gr.Button("Run")
        with gr.Column():
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)

    # It's recommended to have local images for the examples.
    # For demonstration purposes, we download them if they don't exist.
    def download_image(url, filename):
        if not os.path.exists(filename):
            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

    download_image("https://gradio-builds.s3.amazonaws.com/demo-files/baklava.jpg", "baklava.jpg")
    download_image("https://gradio-builds.s3.amazonaws.com/demo-files/cat.jpg", "cat.jpg")

    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["./cat.jpg", "a cat, two cats, three cats"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])

demo.launch()
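
# Optional sanity check (a minimal sketch; it assumes the example images above
# downloaded successfully). Run it in place of `demo.launch()` to verify the
# model wiring without starting the UI:
#
#     print(infer(Image.open("cat.jpg"), "a cat, a dog, a bird"))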