update app.py
- .gitignore +4 -0
- app.py +15 -16
- requirements.txt +1 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+__pycache__
+*.swp
+hf_models/
+pretrained_models/
app.py
CHANGED
@@ -6,11 +6,12 @@ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
 import os
 import time
 from argparse import ArgumentParser
+import json
 
 import numpy as np
 import torch
 import gradio as gr
-
+import faiss
 
 from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
 from models import build_text_encoder, Phi, PIC2WORD
@@ -19,6 +20,7 @@ import transformers
 from huggingface_hub import hf_hub_url, cached_download
 
 
+
 def parse_args():
     parser = ArgumentParser()
     parser.add_argument("--lincir_ckpt_path", default=None, type=str,
@@ -100,6 +102,7 @@ def load_models(args):
     }
 
 
+@torch.no_grad()
 def predict(images, input_text, model_name):
     start_time = time.time()
     input_images = model_dict['clip_preprocess'](images, return_tensors='pt')['pixel_values'].to(model_dict['device'])
@@ -125,18 +128,15 @@ def predict(images, input_text, model_name):
     clip_text_time = time.time() - start_time
 
     start_time = time.time()
-    try:
-        results = client.query(embedding_input=text_embeddings[0].tolist())
-        output = ''
-    except:
-        results = []
-        output = 'The server for image retrieval is not working. Please try again later.'
-    retrieval_time = time.time() - start_time
 
+    _, results = faiss_index.search(text_embeddings.cpu().numpy(), k=10)
+
+    retrieval_time = time.time() - start_time
 
+    output = ''
 
-    for idx,
-        image_url =
+    for idx, retrieved_idx in enumerate(results[0]):
+        image_url = image_urls[retrieved_idx]
         output += f'\n'
 
     time_output = {'CLIP visual extractor': clip_image_time,
@@ -180,7 +180,7 @@ def test_fps(batch_size=1):
 if __name__ == '__main__':
     args = parse_args()
 
-    global model_dict,
+    global model_dict, faiss_index, image_urls
 
     model_dict = load_models(args)
 
@@ -189,19 +189,18 @@ if __name__ == '__main__':
         test_fps(1)
         exit()
 
+    faiss_index = faiss.read_index('./clip_large.index', faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
 
-
-        indice_name="laion5B-H-14" if args.clip_model_name == "huge" else "laion5B-L-14",
-    )
+    image_urls = json.load(open('./image_urls.json'))
 
-    title = 'Zeroshot CIR demo'
+    title = 'Zeroshot CIR demo to search high-quality AI images'
 
     md_title = f'''# {title}
 [LinCIR](https://arxiv.org/abs/2312.01998): Language-only Training of Zero-shot Composed Image Retrieval
 [SEARLE](https://arxiv.org/abs/2303.15247): Zero-shot Composed Image Retrieval with Textual Inversion
 [Pic2Word](https://arxiv.org/abs/2302.03084): Mapping Pictures to Words for Zero-shot Composed Image Retrieval
 
-K-NN index for the retrieval results are entirely trained using the
+K-NN index for the retrieval results are entirely trained using [the upscaled midjourney v5 images (444,901)](https://huggingface.co/datasets/wanng/midjourney-v5-202304-clean).
 '''
 
     with gr.Blocks(title=title) as demo:
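Note on the new retrieval path: `predict` now runs the k-NN lookup in-process with FAISS instead of querying the remote clip-retrieval service, so the app expects two local artifacts that this commit does not create: `./clip_large.index` (a FAISS index over CLIP image embeddings) and `./image_urls.json` (the URL for each indexed row, in the same order). A minimal offline indexing sketch, with hypothetical function and variable names, might look like this:

```python
# Hypothetical offline step that produces the files app.py reads at startup.
# Assumes `image_embeddings` holds CLIP image features (one row per image) and
# `urls` lists the corresponding image URLs in the same order.
import json

import faiss
import numpy as np


def build_retrieval_artifacts(image_embeddings: np.ndarray, urls: list,
                              index_path: str = './clip_large.index',
                              urls_path: str = './image_urls.json') -> None:
    feats = np.ascontiguousarray(image_embeddings, dtype='float32')
    faiss.normalize_L2(feats)                      # inner product == cosine similarity

    index = faiss.IndexFlatIP(int(feats.shape[1]))  # exact (brute-force) k-NN index
    index.add(feats)
    faiss.write_index(index, index_path)

    # predict() maps search results back to URLs positionally:
    # image_urls[retrieved_idx] assumes row i of the index is urls[i].
    with open(urls_path, 'w') as f:
        json.dump(urls, f)
```

Also worth noting: `faiss_index.search` is batched and returns `(scores, ids)` arrays of shape `(num_queries, k)`, which is why `predict` iterates over `results[0]` for its single text-embedding query.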
requirements.txt
CHANGED
@@ -6,3 +6,4 @@ accelerate
 datasets
 spacy
 git+https://github.com/rom1504/clip-retrieval
+faiss