# -------------------------------
# Imports
# -------------------------------
import os
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st
from streamlit_plotly_events import plotly_events
import plotly.express as px
import plotly.graph_objects as go
import gdown
import torch
from sentence_transformers import SentenceTransformer, util

# -------------------------------
# Setup
# -------------------------------
st.set_page_config(page_title="Alibaba Semantic Search", layout="wide")

MODEL_DIR = Path("models")
MODEL_DIR.mkdir(exist_ok=True)

# Local cache paths for each downloaded asset.
embeddings_path = MODEL_DIR / 'desc_embeddings_Alibaba_20251016.npy'
umap_embeddings_path = MODEL_DIR / 'descs_umap_2d_AB_20251016.npy'
data_file_path = MODEL_DIR / 'full_df_minus_nan_descs.csv'
umap_model_path = MODEL_DIR / 'umap_2d_AB_written.pkl'
pca_model_path = MODEL_DIR / 'pca_AB_written.pkl'

# Google Drive file IDs, aligned 1:1 with `paths` below.
emb_ID = '1QQ_QfFTSzTLNkp6Sr4jux_ZTJjMhSyah'
umap_emb_ID = '1a5t5iWOAVgUmYXzrWXctATkDyx9rRF4F'
data_ID = '1tzM67Lg3R-rAvRtol0VGHx6zGW_tdx60'
umap_mod_ID = '1x8PK1Gn72YYBZ4po-0guZMUBtL8oSn1i'
pca_mod_ID = '1jIxBBAZOy8OAzGxBCG4jy7244Wb_TjP9'

paths = [embeddings_path, umap_embeddings_path, data_file_path,
         umap_model_path, pca_model_path]
ids = [emb_ID, umap_emb_ID, data_ID, umap_mod_ID, pca_mod_ID]
assets_links = [f"https://drive.google.com/uc?id={x}" for x in ids]

# -------------------------------
# Download + Load Data
# -------------------------------
@st.cache_resource(show_spinner="Downloading and loading assets (first run only)...")
def load_assets():
    """Fetch any missing assets from Google Drive, then load them from disk.

    Cached with ``st.cache_resource`` so that the downloads, the numpy/CSV
    reads, and the unpickling happen once per server process instead of on
    every Streamlit rerun.

    Returns:
        tuple: ``(embeddings, umap_2d, docs, umap_model, pca_model)`` where
            ``embeddings`` and ``umap_2d`` are ``np.ndarray``, ``docs`` is a
            ``pd.DataFrame``, and the last two are the unpickled PCA/UMAP
            dimensionality reducers used to project queries into 2D.
    """
    for url, path in zip(assets_links, paths):
        if not path.exists():
            gdown.download(url, str(path), quiet=False)

    embeddings = np.load(embeddings_path)
    umap_2d = np.load(umap_embeddings_path)
    docs = pd.read_csv(data_file_path)

    # NOTE: unpickling executes arbitrary code — acceptable only because
    # these files come from a Drive location controlled by the app authors.
    with open(umap_model_path, "rb") as f:
        umap_model = pickle.load(f)
    with open(pca_model_path, "rb") as f:
        pca_model = pickle.load(f)

    return embeddings, umap_2d, docs, umap_model, pca_model

embeddings, umap_2d, docs, umap_model, pca_model = load_assets()
# -------------------------------
# Load SentenceTransformer (cached)
# -------------------------------
@st.cache_resource
def load_text_encoder():
    """Load the GTE multilingual text encoder once per server process."""
    return SentenceTransformer('Alibaba-NLP/gte-multilingual-base',
                               trust_remote_code=True)

model = load_text_encoder()

# -------------------------------
# UI
# -------------------------------
st.title("Semantic Search — Alibaba Embeddings")
st.markdown("Enter a query to highlight semantically similar documents on the 2D UMAP plot.")

query = st.text_input("Enter search query:")
top_k = st.slider("Number of matches to highlight", min_value=10, max_value=2500, value=100)
similarity_measure = st.radio(
    "Similarity measure",
    ["Cosine", "Euclidean", "Manhattan (L1)"],
    horizontal=True
)

# -------------------------------
# Search logic
# -------------------------------
if query:
    with st.spinner("Encoding and searching..."):
        query_embedding = model.encode(query, convert_to_tensor=True)
        query_numpy = query_embedding.cpu().numpy().reshape(1, -1)
        # Project the query through the same PCA -> UMAP pipeline that
        # produced the corpus points, so it lands in the same 2D space.
        query_pca = pca_model.transform(query_numpy)
        query_umap = umap_model.transform(query_pca)

        # BUG FIX: torch.cdist requires two 2-D tensors on the same device;
        # the stored embeddings are a numpy array and the query is 1-D, so
        # lift both explicitly.
        emb_tensor = torch.as_tensor(
            embeddings, dtype=query_embedding.dtype, device=query_embedding.device
        )
        if similarity_measure == "Cosine":
            scores = util.cos_sim(query_embedding, emb_tensor)[0]
        elif similarity_measure == "Euclidean":
            # Distances are negated so "larger score = better match" holds
            # for every measure and topk() can be applied uniformly.
            scores = -torch.cdist(query_embedding.unsqueeze(0), emb_tensor, p=2)[0]
        else:  # Manhattan (L1)
            scores = -torch.cdist(query_embedding.unsqueeze(0), emb_tensor, p=1)[0]

        top_k_scores, top_k_indices = torch.topk(scores, top_k)
        top_k_positions = top_k_indices.cpu().numpy()
        # Set gives O(1) membership tests below instead of O(k) array scans.
        highlight_indices = set(top_k_positions.tolist())

        documents = docs.title_narrative
        labels = ["Match" if i in highlight_indices else "Other"
                  for i in range(len(documents))]

        # BUG FIX: drop the CSV artifact *column* (the original `drop(...)`
        # defaulted to axis=0 and tried to drop a row label). errors='ignore'
        # keeps this safe if the column is absent, and .copy() avoids
        # SettingWithCopy warnings when adding the score column.
        top_results_df = (
            docs.iloc[top_k_positions]
            .drop(columns='Unnamed: 0', errors='ignore')
            .copy()
        )
        top_results_df['similarity_score'] = top_k_scores.cpu().numpy()

        df = pd.DataFrame({
            "UMAP_1": umap_2d[:, 0],
            "UMAP_2": umap_2d[:, 1],
            "Label": labels,
            "Text": documents
        })
        df["Title"] = df["Text"].str.slice(0, 100) + "..."
        df["Index"] = df.index

    color_discrete_map = {"Match": "red", "Other": "lightgray"}
    fig = px.scatter(
        df, x="UMAP_1", y="UMAP_2",
        color="Label",
        color_discrete_map=color_discrete_map,
        hover_data={"Text": False, "Title": True, "Index": True,
                    "UMAP_1": False, "UMAP_2": False},
        opacity=0.7,
        title=f"Top {top_k} semantic matches for: '{query}' ({similarity_measure})",
        width=900, height=700
    )

    # Mark the projected query location on top of the corpus scatter.
    fig.add_trace(go.Scatter(
        x=[query_umap[0][0]],
        y=[query_umap[0][1]],
        mode='markers+text',
        marker=dict(size=10, color='blue', symbol='x'),
        name='Query',
        text=['Query'],
        textposition='top center'
    ))
    fig.update_traces(marker=dict(size=4))

    selected_points = plotly_events(
        fig,
        click_event=True,
        hover_event=False,
        select_event=True,
        override_height=700,
        override_width="100%"
    )

    # Display info for selected points
    if selected_points:
        st.subheader("Selected Project Details")
        for pt in selected_points:
            # NOTE(review): "pointIndex" is the index *within the clicked
            # trace*; px.scatter splits Match/Other into separate traces, so
            # this may not equal the row position in `docs` — verify and
            # consider passing the row index via customdata instead.
            idx = pt["pointIndex"]
            st.markdown(f"**Index:** {idx}")
            st.markdown(f"**Title:** {docs.iloc[idx]['title_narrative']}")
            st.markdown(f"**Description:** {docs.iloc[idx]['descriptive_general_description']}")
            st.markdown(f"**Sector Name:**({docs.iloc[idx]['sector_name']})")
            st.markdown(f"**Funding:**({docs.iloc[idx]['Funding']})")
            st.markdown(f"**Reporting Org:**({docs.iloc[idx]['reporting_org_name']})")
            st.divider()
    else:
        st.info("Click a point on the scatter plot to see its details.")

    # -------------------------------
    # Update User Interface for Displaying Results
    # -------------------------------
    # BUG FIX: offer exactly the columns present in the results table
    # (the original offered docs.columns, which includes the dropped
    # 'Unnamed: 0' column and misses 'similarity_score').
    possible_columns = list(top_results_df.columns)

    # Allow the user to select which columns to display via sidebar widget.
    selected_columns = st.sidebar.multiselect(
        'Select columns to display:', possible_columns, default=possible_columns
    )

    st.subheader("Top Matched Documents in Interactive Table")
    top_results_dataframe = top_results_df[selected_columns]
    # Assigning a blank column as index hides the numeric row index.
    st.dataframe(top_results_dataframe.assign(hack='').set_index('hack'))
else:
    st.info("Enter a search query to begin.")