# Committed by sumittguptaa148
# Commit a452bfa: "Add Gradio application and dependencies for deployment"
# File size at this commit: 4.59 kB
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# --- Configuration ---
# Hugging Face Hub repository that holds the fine-tuned classifier and its tokenizer.
# NOTE: assumes the fine-tuned model was actually pushed to this repo during training.
HF_REPO_ID = "sumittguptaa148/DL-Gen-AI-Project"
# Backbone the classifier was fine-tuned from (informational; weights come from HF_REPO_ID).
MODEL_NAME = "roberta-large"
# Inputs longer than this many tokens are truncated before inference.
MAX_LEN = 256

# Emotion names; the prediction code pairs model output i with LABELS[i].
LABELS = "anger fear joy sadness surprise".split()

# Per-label probability cut-offs (optimized offline). A label is reported as
# present when its sigmoid probability is >= its threshold.
BEST_THRESHOLDS = dict(
    anger=0.64,
    fear=0.34,
    joy=0.79,
    sadness=0.78,
    surprise=0.48,
)
# --- Load Model and Tokenizer ---
# Pick the device once, outside the try, so `device` is always a torch.device
# (previously it was a plain "cpu" string on the failure path).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # Both the tokenizer and the fine-tuned sequence-classification model
    # (backbone + classification head) are loaded from the same Hub repo.
    tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
    model = AutoModelForSequenceClassification.from_pretrained(
        HF_REPO_ID, num_labels=len(LABELS)
    )
    model.to(device)
    model.eval()  # inference mode (e.g. disables dropout)
    print(f"Model and Tokenizer loaded from {HF_REPO_ID}")
except Exception as e:
    # Deliberate top-level fallback: keep the app launchable even if the Hub
    # download fails. The prediction function checks for `model is None` and
    # reports the failure to the user instead of crashing at import time.
    print(f"Error loading model from Hugging Face Hub: {e}")
    tokenizer = None
    model = None
# --- Prediction Function ---
def predict_emotion(text):
    """Score *text* for every emotion in LABELS and apply per-label thresholds.

    Args:
        text: The input string to classify.

    Returns:
        A dict with one entry per label, in LABELS order. Each key is
        ``"<Label> (Present)"`` or ``"<Label> (Absent)"``, depending on whether
        the label's probability reaches its threshold in BEST_THRESHOLDS. Each
        value is ``"<prob> (Threshold: <thr>)"``. If the model or tokenizer is
        unavailable (failed to load), a plain error string is returned instead.
    """
    if model is None or tokenizer is None:
        return "Model failed to load. Please check logs."

    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits

    # Multi-label task: every label gets an independent sigmoid probability.
    # torch.sigmoid is numerically stable, whereas 1 / (1 + np.exp(-x))
    # overflows (with a RuntimeWarning) for large negative logits.
    # Row 0 is the single input text.
    probs = torch.sigmoid(logits)[0].cpu().numpy()

    results = {}
    for i, label in enumerate(LABELS):
        prob = probs[i]
        threshold = BEST_THRESHOLDS[label]
        status = "Present" if prob >= threshold else "Absent"
        results[f"{label.capitalize()} ({status})"] = (
            f"{prob:.4f} (Threshold: {threshold:.2f})"
        )
    return results
# --- Gradio Interface ---
# NOTE(review): `output_components` and `output_keys` are not referenced by the
# rest of this script. The Interface gets its outputs from
# get_output_components(), and Gradio maps a function's returned tuple to the
# output components by position, not by matching labels to dictionary keys.
# Both names are kept unchanged only so any external importer keeps working.
output_components = [
    gr.Textbox(label=f"{label.capitalize()} (Classification & Prob)", lines=1)
    for label in LABELS
]
# One "<Label> ({})" template per label; "{}" stands for the Present/Absent status.
output_keys = [f"{label.capitalize()} ({{}})" for label in LABELS]
# Builds the Gradio output components, one per emotion label.
def get_output_components():
    """Return one single-line Textbox per emotion, in LABELS order.

    The order must line up with the tuple produced by predict_emotion_gradio,
    because Gradio assigns returned values to output components positionally.
    """
    return [
        gr.Textbox(label=f"{emotion.capitalize()} Emotion Result", lines=1)
        for emotion in LABELS
    ]
# Custom wrapper to ensure output matches the order of Gradio components
def predict_emotion_gradio(text):
    """Adapt predict_emotion's dict output to one string per Gradio output box.

    Returns a tuple with one entry per label, in LABELS order (the same order
    as get_output_components). Each entry shows the Present/Absent decision
    together with the probability and threshold, e.g.
    ``"Present: 0.8123 (Threshold: 0.64)"``. If predict_emotion returns an
    error string instead of a dict, that message is shown in every box.
    """
    raw_results = predict_emotion(text)

    # predict_emotion signals a load failure with a plain string. Iterating it
    # as if it were a dict would walk its characters and raise StopIteration.
    if isinstance(raw_results, str):
        return tuple(raw_results for _ in LABELS)

    ordered_results = []
    for label in LABELS:
        prefix = f"{label.capitalize()} ("
        key = next((k for k in raw_results if k.startswith(prefix)), None)
        if key is None:
            ordered_results.append("N/A")
            continue
        # Keys look like "<Label> (<status>)". The decision exists only in the
        # key, so carry it into the displayed text instead of dropping it.
        status = key[len(prefix):].rstrip(")").strip()
        ordered_results.append(f"{status}: {raw_results[key]}")
    return tuple(ordered_results)
title = "Multi-Label Emotion Classification with RoBERTa-Large"
description = "A DL/GenAI project classifying text into Anger, Fear, Joy, Sadness, and Surprise. The model uses a fine-tuned RoBERTa-Large with 5-Fold CV and dynamic threshold optimization."

# Keep a module-level handle named `demo` (the name Gradio's reload mode looks
# for) and launch only when run as a script, so importing this module has no
# server-starting side effect.
# NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
# `flagging_mode`, and string theme names differ across versions; confirm
# both against the pinned Gradio version.
demo = gr.Interface(
    fn=predict_emotion_gradio,
    inputs=gr.Textbox(lines=5, placeholder="Enter a sentence or short text here...", label="Input Text"),
    outputs=get_output_components(),  # one Textbox per label, in LABELS order
    title=title,
    description=description,
    allow_flagging="never",
    theme="huggingface",
)

if __name__ == "__main__":
    demo.launch()