import os
import subprocess

import torch
import gradio as gr
from clip_interrogator import Config, Interrogator
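# Precomputed CLIP term-embedding caches (artists, flavors, mediums, movements,
# trendings) for the ViT-H-14 / laion2b_s32b_b79k model, published in the
# pharma/ci-preprocess repo. Downloading them into the Interrogator's cache
# directory lets it load these pickles instead of recomputing the embeddings
# on first startup.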
CACHE_URLS = [
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
]
os.makedirs('cache', exist_ok=True)
for url in CACHE_URLS:
    subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE)
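# A wget-free alternative would be the huggingface_hub client (a sketch only,
# assuming the package is installed; filenames taken from the URLs above):
#
#   from huggingface_hub import hf_hub_download
#   for name in ['artists', 'flavors', 'mediums', 'movements', 'trendings']:
#       hf_hub_download(repo_id='pharma/ci-preprocess',
#                       filename=f'ViT-H-14_laion2b_s32b_b79k_{name}.pkl',
#                       local_dir='cache')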
# Configure the Interrogator: pick the device automatically; the remaining
# settings control CLIP batching, flavor candidate counts, and BLIP beam search.
config = Config()
config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
config.blip_offload = False if torch.cuda.is_available() else True
config.chunk_size = 2048
config.flavor_intermediate_count = 512
config.blip_num_beams = 64

ci = Interrogator(config)
#@spaces.GPU
def inference(image, mode, best_max_flavors):
    """Return a prompt describing `image` using the selected interrogation mode."""
    image = image.convert('RGB')
    if mode == 'best':
        prompt_result = ci.interrogate(image, max_flavors=int(best_max_flavors))
    elif mode == 'classic':
        prompt_result = ci.interrogate_classic(image)
    else:
        prompt_result = ci.interrogate_fast(image)
    return prompt_result
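# Minimal usage sketch outside the UI (assumes a local file 'example.jpg',
# which is not part of this Space):
#   from PIL import Image
#   print(inference(Image.open('example.jpg'), 'fast', 32))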
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# CLIP Interrogator")
        input_image = gr.Image(type='pil', elem_id="input-img")
        with gr.Row():
            mode_input = gr.Radio(['best', 'classic', 'fast'], label='Select mode', value='best')
            flavor_input = gr.Slider(minimum=2, maximum=48, step=2, value=32, label='best mode max flavors')
        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Description Output")
        submit_btn.click(
            fn=inference,
            inputs=[input_image, mode_input, flavor_input],
            outputs=[output_text],
            concurrency_limit=10
        )

#demo.launch(server_name="0.0.0.0")
demo.queue().launch()
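# Whether launched on Spaces or locally (e.g. `python app.py`), queue() enables
# Gradio's request queue so long-running interrogations are processed in order
# rather than blocking the interface.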