Spaces:
Runtime error
Runtime error
| from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline | |
| import torch | |
| from PIL import Image | |
| import gradio as gr | |
| import aiohttp | |
| import asyncio | |
| from io import BytesIO | |
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is only requested on CUDA. PyTorch has historically lacked
# float16 kernels for the linear/matmul ops ViT relies on when running on CPU,
# so CPU execution uses float32.
dtype = torch.float16 if device.startswith("cuda") else torch.float32


def _load_image_classifier(repo_id):
    """Build an image-classification pipeline for the Hub checkpoint *repo_id*.

    The model and its feature extractor are loaded from the same repo so the
    preprocessing always matches the checkpoint.
    """
    # NOTE(review): the model is instantiated before being handed to
    # pipeline(); confirm torch_dtype actually takes effect for the pinned
    # transformers version.
    return pipeline("image-classification",
                    model=AutoModelForImageClassification.from_pretrained(repo_id),
                    feature_extractor=AutoFeatureExtractor.from_pretrained(repo_id),
                    device=device,
                    torch_dtype=dtype)


# One classifier per Hub checkpoint, loaded once at import time.
nsfw_pipe = _load_image_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _load_image_classifier("cafeai/cafe_style")
aesthetic_pipe = _load_image_classifier("cafeai/cafe_aesthetic")
async def fetch_image(session, image_url):
    """Download *image_url* with *session* and decode it as an RGB PIL image.

    Args:
        session: an open ``aiohttp.ClientSession`` used for the request.
        image_url: URL of the image to fetch.

    Returns:
        The decoded image converted to RGB, or ``None`` when the request
        fails, the server does not answer 200, the response is not labelled
        as an image, or its bytes cannot be decoded. Returning ``None``
        instead of raising keeps one bad URL from aborting every other
        download gathered alongside it.
    """
    print(f"fetching image {image_url}")
    try:
        async with session.get(image_url) as response:
            # A missing Content-Type header must not raise KeyError.
            content_type = response.headers.get('content-type', '')
            if response.status != 200 or not content_type.startswith('image'):
                return None
            payload = await response.read()
    except (aiohttp.ClientError, asyncio.TimeoutError) as err:
        print(f"failed to fetch image {image_url}: {err}")
        return None
    try:
        return Image.open(BytesIO(payload)).convert('RGB')
    except OSError as err:
        # Pillow signals undecodable data with UnidentifiedImageError (an OSError).
        print(f"failed to decode image {image_url}: {err}")
        return None
async def fetch_images(image_urls):
    """Fetch every URL in *image_urls* concurrently over one shared HTTP session.

    Returns a list aligned with *image_urls*; each entry is whatever
    ``fetch_image`` produced for the URL at the same position.
    """
    async with aiohttp.ClientSession() as session:
        # gather() schedules each coroutine as a Task itself and preserves order.
        downloads = (fetch_image(session, url) for url in image_urls)
        return await asyncio.gather(*downloads)
def _open_local_rgb(path):
    """Open the image file at *path* and return an RGB copy, closing the file."""
    with Image.open(path) as img:
        return img.convert("RGB")


def _local_image_paths(image, files):
    """Return the file paths of the uploaded image(s).

    A single *image* takes precedence over *files*. Entries of *files* may be
    tempfile-like objects exposing ``.name`` or plain path strings (presumably
    depending on the Gradio version), so both forms are accepted.
    """
    if image:
        return [image]
    return [getattr(f, "name", f) for f in files]


async def predict(json=None, enable_gallery=True, image=None, files=None):
    """Classify images with the style, aesthetic and NSFW pipelines.

    Images come from one source, checked in this order: the uploaded *image*
    or *files* (local paths), otherwise the ``"urls"`` list in *json*, whose
    entries are downloaded concurrently. URLs that cannot be fetched or
    decoded are skipped.

    Args:
        json: mapping with a ``"urls"`` key listing image URLs, or ``None``.
        enable_gallery: when true, also return the decoded images for display.
        image: file path of a single uploaded image, or ``None``.
        files: uploaded files (objects with ``.name`` or path strings), or ``None``.

    Returns:
        ``(results, label_data, gallery)``:
        - ``results`` holds, per image, the style, aesthetic and NSFW
          prediction lists concatenated.
        - ``label_data`` maps label -> score for the single-``image`` case and
          is empty otherwise.
        - ``gallery`` is the list of images, or ``None`` when the gallery is
          disabled.

        With no usable images the result is ``([], {}, [])``, or
        ``([], {}, None)`` when the gallery is disabled.
    """
    print(json)
    if image or files:
        pil_images = [_open_local_rgb(path)
                      for path in _local_image_paths(image, files)]
    elif json is not None:
        fetched = await fetch_images(json.get("urls", []))
        # Failed downloads come back as None; the pipelines cannot take them.
        pil_images = [img for img in fetched if img is not None]
    else:
        pil_images = []

    if not pil_images:
        return [], {}, ([] if enable_gallery else None)

    # NOTE: the pipeline calls are synchronous, so they block the event loop
    # for the duration of inference.
    style = style_pipe(pil_images)
    aesthetic = aesthetic_pipe(pil_images)
    nsfw = nsfw_pipe(pil_images)
    results = [a + b + c for a, b, c in zip(style, aesthetic, nsfw)]

    label_data = {}
    if image:
        # Only the single-image upload path feeds the Label component.
        label_data = {row["label"]: row["score"] for row in results[0]}
    return results, label_data, (pil_images if enable_gallery else None)
with gr.Blocks() as blocks:
    with gr.Row():
        # Inputs: a single image, a batch of files, or a JSON list of image URLs.
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(label="Multiple Images", file_types=[
                "image"], file_count="multiple")
            enable_gallery = gr.Checkbox(label="Enable Gallery", value=True)
            json = gr.JSON(label="Image URLs", value={"urls": [
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/b9fb3257-6a54-455e-b636-9d61cf261676.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/062eb9be-76eb-4d7e-9299-d1ebea14b46f.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/8ff6d4f6-08d0-4a31-818c-4d32ab146f81.jpg']})
        # Outputs: single-image label scores, raw JSON results, decoded images.
        with gr.Column():
            label = gr.Label(label="style")
            results = gr.JSON(label="Results")
            if hasattr(gr.Gallery, "style"):
                # Gradio 3.x: gallery layout is configured through .style().
                gallery = gr.Gallery().style(grid=[2], height="auto")
            else:
                # Gradio 4+: .style() was removed; layout options are constructor args.
                gallery = gr.Gallery(columns=2, height="auto")
    btn = gr.Button("Run")
    btn.click(fn=predict, inputs=[json, enable_gallery, image, files],
              outputs=[results, label, gallery], api_name="inference")

blocks.queue()
blocks.launch(debug=True, inline=True)