# ---------------- Libraries ----------------
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
import lpips
import clip
from bert_score import score
import torchvision.transforms as T

# ---------------- Device Setup ----------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------- GPU Cache Free ----------------
def free_gpu_cache():
    torch.cuda.empty_cache()

# ---------------- Load SD Turbo & DreamShaper ----------------
gen_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sd-turbo",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)

dreamshaper_pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)

# ---------------- Load NLP Models ----------------
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=0 if device == "cuda" else -1
)
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0 if device == "cuda" else -1
)
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=0 if device == "cuda" else -1
)
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=0 if device == "cuda" else -1
)

# ---------------- Load VQA Model ----------------
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)

# ---------------- Load CLIP & LPIPS ----------------
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net='alex').to(device)

# ---------------- Style Map ----------------
style_map = {
    "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
    "Real Life": "natural lighting, true-to-life colors, DSLR",
    "Documentary": "documentary handheld muted colors",
    "iPhone Camera": "iPhone photo natural HDR",
    "Street Photography": "candid street ambient shadows",
    "Cinematic": "cinematic lighting dramatic depth",
    "Anime": "anime cel shaded vibrant",
    "Watercolor": "watercolor soft wash art",
    "Macro": "macro lens shallow DOF",
    "Cyberpunk": "neon cyberpunk futuristic",
}

lpips_transform = T.Compose([T.ToTensor(), T.Resize((256, 256))])

# ---------------- Image Generation Functions ----------------
def generate_image_and_store(prompt, negative, seed, style, images):
    images = images or []
    enhanced_prompt = f"{prompt}, {style_map.get(style, '')}"
    generator = torch.Generator(device=device).manual_seed(int(seed))
    ctx = torch.autocast("cuda") if device == "cuda" else torch.no_grad()
    with ctx:
        img = gen_pipe(prompt=enhanced_prompt, negative_prompt=negative, generator=generator).images[0]
    images.append(img)
    free_gpu_cache()
    return img, images

def generate_dreamshaper_image(prompt, negative, seed, style, images):
    images = images or []
    enhanced_prompt = f"{prompt}, {style_map.get(style, '')}"
    generator = torch.Generator(device=device).manual_seed(int(seed))
    ctx = torch.autocast("cuda") if device == "cuda" else torch.no_grad()
    with ctx:
        img = dreamshaper_pipe(prompt=enhanced_prompt, negative_prompt=negative, generator=generator).images[0]
    images.append(img)
    free_gpu_cache()
    return img, images

# ---------------- VQA ----------------
def answer_vqa(question, image):
    if image is None or question.strip() == "":
        return "Upload an image and enter a question."
    inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = vqa_model.generate(**inputs)
    answer = vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
    return answer
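
# ---------------- Optional Sanity Check (added sketch, not part of the original app) ----------------
# A minimal, hedged example of exercising the two generation helpers outside the Gradio UI,
# e.g. to confirm the pipelines load and produce images before launching the full interface.
# The prompt, seed, and output filenames below are illustrative assumptions.
def run_generation_sanity_check(prompt="a red bicycle leaning against a brick wall"):
    sd_img, imgs = generate_image_and_store(prompt, negative="blurry, low quality",
                                            seed=42, style="Photorealistic", images=[])
    ds_img, imgs = generate_dreamshaper_image(prompt, negative="blurry, low quality",
                                              seed=42, style="Photorealistic", images=imgs)
    sd_img.save("sanity_sd_turbo.png")
    ds_img.save("sanity_dreamshaper.png")
    return imgs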

# ---------------- Metrics Computation ----------------
def compute_metrics_button(images, captions, idx1, idx2):
    # CLIP cosine similarity between the two image embeddings
    img1_clip = clip_preprocess(images[idx1]).unsqueeze(0).to(device)
    img2_clip = clip_preprocess(images[idx2]).unsqueeze(0).to(device)
    with torch.no_grad():
        feat1 = clip_model.encode_image(img1_clip)
        feat2 = clip_model.encode_image(img2_clip)
        clip_sim = float(torch.cosine_similarity(feat1, feat2).item())

    # LPIPS perceptual distance (inputs rescaled from [0, 1] to [-1, 1])
    img1_lp = lpips_transform(images[idx1]).unsqueeze(0).to(device) * 2 - 1
    img2_lp = lpips_transform(images[idx2]).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        lpips_score = float(lpips_model(img1_lp, img2_lp).item())

    # BERTScore F1 between the two captions
    _, _, F1 = score([captions[idx1]], [captions[idx2]], lang="en", verbose=False)
    bert_f1 = float(F1.mean().item())

    return f"""
**Metrics Comparison**
- CLIP Similarity: {clip_sim:.4f}
- LPIPS Score: {lpips_score:.4f}
- BERTScore F1: {bert_f1:.4f}
"""
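
# ---------------- Optional Metrics Usage Sketch (added, not part of the original app) ----------------
# A minimal example of calling compute_metrics_button directly on two local images,
# outside the Gradio UI. The file paths are illustrative assumptions; captions are
# produced with the same BLIP captioner used by the app.
def compare_two_local_images(path_a="image_a.png", path_b="image_b.png"):
    from PIL import Image
    imgs = [Image.open(path_a).convert("RGB"), Image.open(path_b).convert("RGB")]
    caps = [captioner(im)[0]["generated_text"] for im in imgs]
    return compute_metrics_button(imgs, caps, 0, 1)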

# ---------------- Build Gradio UI ----------------
def build_ui_with_custom_ui():
    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        # ---------------- CSS Styling ----------------
        # Custom CSS for the "heading-orange", "orange-btn" and "teal-btn" classes goes here.
        gr.HTML(""" """)

        # ---------------- Heading ----------------
        gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective",
                    elem_classes="heading-orange")

        # ---------------- States ----------------
        images_state = gr.State([])
        captions_state = gr.State([])

        # ---------------- Step 1: Upload Reference Image ----------------
        gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1):
                upload_input = gr.Image(label="Drag & Drop Image", type="pil")
                upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
            with gr.Column(scale=1):
                upload_preview = gr.Image(label="Uploaded Image", interactive=False)
                caption_out = gr.Markdown(label="Generated Caption")

        def upload_and_generate_caption_ui(img, images_state, captions_state):
            images = [img]
            caption = captioner(img)[0]["generated_text"]
            captions = [caption]
            return img, caption, images, captions

        upload_btn.click(
            upload_and_generate_caption_ui,
            inputs=[upload_input, images_state, captions_state],
            outputs=[upload_preview, caption_out, images_state, captions_state]
        )

        # ---------------- Step 2: Generate SD-Turbo & DreamShaper ----------------
        gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
                sd_preview = gr.Image(label="SD-Turbo Image", interactive=False)
            with gr.Column(scale=1, min_width=300):
                ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
                ds_preview = gr.Image(label="DreamShaper Image", interactive=False)

        def generate_sd_from_caption_ui(caption, images_state, captions_state):
            img, images = generate_image_and_store(caption, negative="", seed=42,
                                                   style="Photorealistic", images=images_state)
            captions_state[1:2] = [captioner(img)[0]["generated_text"]]
            return img, images, captions_state

        def generate_ds_from_caption_ui(caption, images_state, captions_state):
            img, images = generate_dreamshaper_image(caption, negative="", seed=123,
                                                     style="Photorealistic", images=images_state)
            captions_state[2:3] = [captioner(img)[0]["generated_text"]]
            return img, images, captions_state

        sd_btn.click(generate_sd_from_caption_ui,
                     inputs=[caption_out, images_state, captions_state],
                     outputs=[sd_preview, images_state, captions_state])
        ds_btn.click(generate_ds_from_caption_ui,
                     inputs=[caption_out, images_state, captions_state],
                     outputs=[ds_preview, images_state, captions_state])

        # ---------------- Step 3: Compute Pairwise Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
        with gr.Row():
            metrics_A = gr.Markdown()
            metrics_B = gr.Markdown()
            metrics_C = gr.Markdown()

        def compute_metrics_all_pairs_ui(images, captions):
            # Interim status shown in all three panels while the metrics are computed.
            yield ("Computing metrics...", "Computing metrics...", "Computing metrics...")
            if len(images) < 3:
                msg = "All three images and captions are required to compute metrics."
                yield msg, msg, msg
            else:
                A = compute_metrics_button(images, captions, 0, 1)
                B = compute_metrics_button(images, captions, 0, 2)
                C = compute_metrics_button(images, captions, 1, 2)
                yield (f"**Reference ↔ SD-Turbo**\n{A}",
                       f"**Reference ↔ DreamShaper**\n{B}",
                       f"**SD-Turbo ↔ DreamShaper**\n{C}")

        metrics_btn.click(compute_metrics_all_pairs_ui,
                          inputs=[images_state, captions_state],
                          outputs=[metrics_A, metrics_B, metrics_C])

        # ---------------- Step 4: NLP Analysis ----------------
        gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        nlp_out = gr.HTML()
" if len(captions) < 3: yield "All three captions are required for NLP analysis." else: labels = ["Reference Image", "SD-Turbo", "DreamShaper"] blocks = [] for label, caption in zip(labels, captions): sentiment = "
".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)]) ents = "
".join([f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)]) or "None" topics_data = topic_model(caption, candidate_labels=['people','animals','objects','food','nature']) topics = "
".join([f"{l}: {sc:.2f}" for l, sc in zip(topics_data['labels'], topics_data['scores'])]) block = f"

{label}

Sentiment
{sentiment}

Entities
{ents}

Topics
{topics}
" blocks.append(block) yield f"
{''.join(blocks)}
" nlp_btn.click(analyze_caption_pipeline_ui, inputs=[captions_state], outputs=[nlp_out]) # ---------------- Step 5: Visual Question Answering ---------------- gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange") with gr.Row(): with gr.Column(scale=1): vqa_input = gr.Textbox(label="Enter a question about the reference image") vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn") with gr.Column(scale=1): vqa_out = gr.Markdown(label="VQA Output") def answer_vqa_ui(question, image): yield "
" ans = answer_vqa(question, image) yield ans vqa_btn.click(answer_vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_out]) return demo # Launch the interface demo = build_ui_with_custom_ui() demo.launch()