prithivMLmods committed on
Commit
0a3df57
·
verified ·
1 Parent(s): 3163833

update app

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Gradio Space: zero-shot image classification with MetaCLIP 2."""
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Checkpoint is loaded once at import time; bfloat16 weights plus SDPA
# attention keep memory use and latency down on Space hardware.
MODEL_ID = "facebook/metaclip-2-mt5-worldwide-s16"
model = AutoModel.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
def postprocess_metaclip(probs, labels):
    """Map each candidate label to its predicted probability.

    Args:
        probs: Tensor of shape (1, len(labels)) holding softmaxed scores;
            the single row corresponds to the one input image.
        labels: Sequence of label strings aligned with the columns of probs.

    Returns:
        dict mapping label -> float probability, the format gr.Label expects.
    """
    # zip pairs each label with its score directly — no index bookkeeping
    # (replaces the range(len(...)) loop).
    return {label: prob.item() for label, prob in zip(labels, probs[0])}
def metaclip_detector(image, texts):
    """Score an image against each candidate text with MetaCLIP 2.

    Args:
        image: PIL image to classify.
        texts: list of candidate label strings.

    Returns:
        Tensor of shape (1, len(texts)): softmax over the image-to-text
        logits, i.e. a probability per candidate label.
    """
    batch = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        result = model(**batch)
    return result.logits_per_image.softmax(dim=1)
def infer(image, candidate_labels):
    """Run zero-shot classification and format the result for gr.Label.

    Args:
        image: PIL image from the gr.Image component.
        candidate_labels: comma-separated label string from the textbox.

    Returns:
        dict mapping each parsed label to its probability.
    """
    # strip() (not lstrip) also removes trailing whitespace, and empty
    # entries left by stray commas are dropped so the text encoder never
    # receives an empty string.
    labels = [part.strip() for part in candidate_labels.split(",") if part.strip()]
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
# UI layout: image + comma-separated labels on the left, top-3 label
# probabilities on the right; examples are downloaded on first launch.
with gr.Blocks() as demo:
    gr.Markdown("# MetaCLIP 2 Zero-Shot Classification")
    gr.Markdown(
        "Test the performance of MetaCLIP 2 on zero-shot classification in this Space :point_down:"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            # Typo fixed in the visible label: "seperated" -> "separated".
            text_input = gr.Textbox(label="Input a list of labels (comma separated)")
            run_button = gr.Button("Run")  # visible=True is the default
        with gr.Column():
            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)

    def download_image(url, filename):
        """Fetch `url` into `filename` unless it already exists (example assets)."""
        import os
        if not os.path.exists(filename):
            # Stream the body in chunks; the timeout ensures a slow CDN
            # cannot hang Space startup indefinitely.
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            with open(filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

    download_image("https://gradio-builds.s3.amazonaws.com/demo-files/baklava.jpg", "baklava.jpg")
    download_image("https://gradio-builds.s3.amazonaws.com/demo-files/cat.jpg", "cat.jpg")

    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["./cat.jpg", "a cat, two cats, three cats"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[metaclip_output],
        fn=infer,
    )
    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])
demo.launch()