"""Streamlit app for multi-label emotion classification.

A fine-tuned encoder plus a small feed-forward head are fetched from the
Hugging Face Hub repository named by ``HF_REPO_ID``. Each input line is
scored independently against every label in ``EMOTIONS`` (sigmoid per
label), and labels whose probability reaches the chosen threshold are
reported as predicted.
"""

import json
import logging
from typing import List

import altair as alt
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)

st.set_page_config(page_title="Emotion Classifier", layout="wide")

EMOTIONS = ["anger", "fear", "joy", "sadness", "surprise"]
BASE_MODEL_ID = "FacebookAI/roberta-large"
HF_REPO_ID = "JashMevada/emotion-classifier"
CLASSIFIER_WEIGHTS = "classifier.pth"
CLASSIFIER_CONFIG = "classifier_config.json"
DEFAULT_MAX_LEN = 512
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_classifier(hidden_size: int, num_labels: int, dropout: float = 0.2) -> torch.nn.Sequential:
    """Build the classification head applied to the pooled encoder output.

    Args:
        hidden_size: Width of the pooled representation (input features).
        num_labels: Number of output logits, one per label.
        dropout: Dropout probability between the two linear layers.

    Returns:
        ``Linear -> ReLU -> Dropout -> Linear`` as a ``torch.nn.Sequential``.
    """
    return torch.nn.Sequential(
        torch.nn.Linear(hidden_size, hidden_size),
        torch.nn.ReLU(),
        torch.nn.Dropout(dropout),
        torch.nn.Linear(hidden_size, num_labels),
    )


@st.cache_resource(show_spinner=False)
def load_fine_tuned_components():
    """Load the tokenizer, encoder, classifier head and max sequence length.

    The result is memoised by ``st.cache_resource``, so the Hub downloads and
    model construction are not repeated on every script rerun.

    Returns:
        A ``(tokenizer, encoder, head, max_length)`` tuple. ``encoder`` and
        ``head`` are on ``DEVICE`` and in eval mode. ``max_length`` comes from
        the optional ``CLASSIFIER_CONFIG`` file and falls back to
        ``DEFAULT_MAX_LEN`` when that file is missing or unreadable.
    """
    tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
    encoder = AutoModel.from_pretrained(HF_REPO_ID, add_pooling_layer=False).to(DEVICE)
    encoder.eval()

    hidden_size = encoder.config.hidden_size
    head = build_classifier(hidden_size, len(EMOTIONS)).to(DEVICE)
    weights_path = hf_hub_download(HF_REPO_ID, CLASSIFIER_WEIGHTS)
    # NOTE(security): torch.load unpickles the file. The checkpoint comes from
    # a remote repository, so consider passing weights_only=True once the
    # minimum supported torch version is confirmed to accept it.
    head.load_state_dict(torch.load(weights_path, map_location="cpu"))
    head.eval()

    max_length = DEFAULT_MAX_LEN
    try:
        cfg_path = hf_hub_download(HF_REPO_ID, CLASSIFIER_CONFIG)
        with open(cfg_path, "r", encoding="utf-8") as cfg_file:
            cfg = json.load(cfg_file)
        max_length = int(cfg.get("max_length", DEFAULT_MAX_LEN))
    except Exception as exc:
        # The config file is optional, so any failure (missing file, network
        # error, malformed JSON, non-integer value) keeps the default. The
        # catch is deliberately broad to stay best-effort, but the reason is
        # logged instead of being silently discarded.
        logger.warning(
            "Could not read %s from %s; using max_length=%d (%s)",
            CLASSIFIER_CONFIG,
            HF_REPO_ID,
            DEFAULT_MAX_LEN,
            exc,
        )
        max_length = DEFAULT_MAX_LEN

    return tokenizer, encoder, head, max_length


def predict_emotions(texts: List[str], threshold: float) -> pd.DataFrame:
    """Score each text against every emotion label.

    Args:
        texts: Input strings, one prediction row per element.
        threshold: A label is predicted when its probability is
            greater than or equal to this value.

    Returns:
        A DataFrame with a ``Text`` column, one probability column per entry
        in ``EMOTIONS``, and a ``Predicted Labels`` column holding the
        comma-separated labels that met the threshold (``"None"`` if no label
        did). An empty ``texts`` yields an empty frame with the same columns.
    """
    if not texts:
        # Nothing to score: skip model loading and tokenisation entirely and
        # hand back a frame with the usual column layout.
        return pd.DataFrame(columns=["Text", *EMOTIONS, "Predicted Labels"])

    tokenizer, encoder, head, max_length = load_fine_tuned_components()
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = encoder(**encoded).last_hidden_state
        # Mean-pool over the sequence dimension, using the attention mask so
        # padded positions do not contribute to the sum or the divisor.
        mask = encoded["attention_mask"].unsqueeze(-1).float()
        pooled = (outputs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        logits = head(pooled)
        # Sigmoid per label: labels are scored independently (multi-label).
        probs = torch.sigmoid(logits).cpu().numpy()

    binary = (probs >= threshold).astype(int)
    df = pd.DataFrame({emotion: probs[:, idx] for idx, emotion in enumerate(EMOTIONS)})
    df["Predicted Labels"] = [
        ", ".join([emo for emo, flag in zip(EMOTIONS, row) if flag]) or "None"
        for row in binary
    ]
    df.insert(0, "Text", texts)
    return df


st.title("Multi-label Emotion Classifier (Fine-tuned RoBERTa)")
st.write(
    # Trailing space keeps the two implicitly concatenated sentences apart.
    "Uploads from Hugging Face (`JashMevada/emotion-classifier`) are used to run the fine-tuned head. "
    "Adjust the threshold to suit your precision/recall trade-off."
)

with st.sidebar:
    st.subheader("Inference Settings")
    decision_threshold = st.slider("Decision threshold", 0.1, 0.9, 0.86, 0.01)
    st.caption("Threshold was tuned on validation data; tweak it to prioritize precision or recall.")

sample_texts = {
    "Customer Support": "Your service saved my day and I couldn't be happier, thank you team!",
    "Crisis Post": "I feel overwhelmed and scared about what tomorrow brings.",
    "Mixed Emotions": "The movie made me cry and laugh at the same time—such a wild ride!",
}

selected_sample = st.selectbox("Sample snippets", list(sample_texts.keys()))
user_text = st.text_area(
    "Enter one or multiple texts (separate lines). Empty lines are ignored.",
    value=sample_texts[selected_sample],
    height=160,
)
submitted = st.button("Run inference", type="primary")

if submitted:
    # One prediction per non-blank line of the text area.
    inputs = [line.strip() for line in user_text.splitlines() if line.strip()]
    if not inputs:
        st.warning("Provide at least one non-empty line.")
    else:
        with st.spinner("Predicting emotions..."):
            results_df = predict_emotions(inputs, decision_threshold)

        st.subheader("Predictions")
        # Shorten long texts for the table only; results_df keeps the full text.
        display_df = results_df.copy()
        display_df["Text"] = display_df["Text"].apply(
            lambda t: t if len(t) <= 200 else f"{t[:197]}..."
        )
        st.dataframe(display_df, width="stretch")

        # Long format for Altair: one row per (text index, emotion) pair.
        chart_df = (
            results_df.reset_index()
            .rename(columns={"index": "Text #"})
            .loc[:, ["Text #", *EMOTIONS]]
            .melt(id_vars="Text #", var_name="Emotion", value_name="Probability")
        )
        chart = (
            alt.Chart(chart_df)
            .mark_bar()
            .encode(
                x=alt.X("Emotion:N", title="", sort=EMOTIONS),
                y=alt.Y("Probability:Q", scale=alt.Scale(domain=[0, 1])),
                color="Text #:N",
                tooltip=["Text #:N", "Emotion:N", alt.Tooltip("Probability:Q", format=".3f")],
            )
            .properties(height=320)
        )
        st.altair_chart(chart, width="stretch")

st.info(
    "Weights download happens once per session (cached). Set the `HF_TOKEN` environment variable "
    "if the repository is private."
)