Spaces:
Running
Running
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
| from typing import List | |
| import torch | |
| from functools import lru_cache | |
| import logging | |
| from datetime import datetime | |
| from collections import defaultdict | |
# Configure logging: INFO level on the root logger, plus a logger named after
# this module that the rest of the file uses.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the FastAPI application object.
app = FastAPI()
logger.info("Starting FastAPI application")
# Load the SentenceTransformer dense encoders once, at import time, on the CPU.
# bge_small_model serves model key "BG" and the unknown-key fallback;
# all_mp_net_model serves model key "BM".
logger.info("Loading BGE small model...")
bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
logger.info("Loaded BGE small model")
logger.info("Loading All-MPNet model...")
all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
logger.info("Loaded All-MPNet model")
# Load the SPLADE masked-LM and its tokenizer once, at import time.
#
# `trust_remote_code` is intentionally NOT enabled: that flag lets
# `from_pretrained` execute Python shipped inside the Hub repository, and the
# tokenizer for the same checkpoint already loads without it.
# NOTE(review): this assumes the checkpoint loads with the stock masked-LM
# classes (no custom modelling code) -- confirm before deploying.
_SPLADE_CHECKPOINT = "naver/splade-cocondenser-ensembledistil"

logger.info("Loading SPLADE model...")
SPLADE_MODEL = AutoModelForMaskedLM.from_pretrained(_SPLADE_CHECKPOINT)
SPLADE_TOKENIZER = AutoTokenizer.from_pretrained(_SPLADE_CHECKPOINT)
# Switch to evaluation mode; the model is only used for inference below.
SPLADE_MODEL.eval()
logger.info("Loaded SPLADE model")
# Request and response models
class TextInput(BaseModel):
    """Request body accepted by the embedding endpoint."""

    # Texts to embed; the handler returns one embedding per entry.
    text: List[str]
    # Encoder selector. The handler upper-cases it and recognises "BM", "BG"
    # and "SPLADE"; any other value falls back to the BGE small model.
    model_name: str
class SparseVector(BaseModel):
    """Sparse embedding stored as two parallel lists."""

    # Vocabulary ids of the dimensions that carry a weight.
    indices: List[int]
    # Weight of the id at the same position in `indices`.
    values: List[float]
# LRU cacheable versions
@lru_cache(maxsize=1024)
def _encode_dense_vector(model_name: str, text: str) -> tuple:
    """Run the dense encoder once per distinct (model_name, text) pair.

    The result is returned as a tuple so that the value held by the LRU cache
    cannot be mutated by callers. The cache is bounded because request texts
    form an unbounded key space.
    """
    logger.info("Encoding dense text with model %s: %s", model_name, text)
    # "BM" selects the all-mpnet model; every other key uses BGE small.
    encoder = all_mp_net_model if model_name == "BM" else bge_small_model
    vector = tuple(encoder.encode([text])[0].tolist())
    logger.info("Finished encoding dense text")
    return vector


def encode_dense_cached(model_name: str, text: str):
    """Return the dense embedding of ``text`` as a list of floats.

    Args:
        model_name: Encoder key; ``"BM"`` uses ``all_mp_net_model``, anything
            else uses ``bge_small_model``.
        text: The single text to embed.

    Returns:
        A fresh ``list`` of floats on every call. Repeated
        ``(model_name, text)`` pairs are served from a bounded LRU cache
        instead of re-running the model, so the encoding log lines are only
        emitted on a cache miss.
    """
    return list(_encode_dense_vector(model_name, text))
@lru_cache(maxsize=1024)
def _splade_term_weights(text: str) -> tuple:
    """Compute SPLADE term weights for ``text`` once per distinct input.

    Returns a pair ``(indices, values)`` of equal-length tuples. Tuples are
    used so the value held by the LRU cache is immutable; the cache is bounded
    because request texts form an unbounded key space.
    """
    logger.info("Encoding SPLADE sparse vector: %s", text)
    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        # Drop the batch axis: exactly one text is encoded per call.
        logits = SPLADE_MODEL(**inputs).logits[0]

    # log(1 + relu(x)) activation applied to every logit.
    relu_log = torch.log1p(torch.relu(logits))
    nonzero = relu_log.nonzero(as_tuple=False)
    if nonzero.shape[0] == 0:
        logger.info("No non-zero values found in SPLADE output")
        return (), ()

    # Column 0 of `nonzero` is the row (token position) of each active entry,
    # column 1 is its vocabulary id.
    positions = nonzero[:, 0]
    vocab_ids = nonzero[:, 1]
    weights = relu_log[positions, vocab_ids]

    # Sum the weights of a vocabulary id that is active at several positions.
    # Output order is the order in which each id is first encountered.
    # NOTE(review): SPLADE reference code commonly max-pools over positions
    # rather than summing -- confirm that summation is intended here.
    index_to_value = defaultdict(float)
    for idx, val in zip(vocab_ids.tolist(), weights.tolist()):
        index_to_value[idx] += val

    logger.info("SPLADE encoding complete with %d dimensions", len(index_to_value))
    return tuple(index_to_value.keys()), tuple(index_to_value.values())


def encode_splade_cached(text: str) -> SparseVector:
    """Return the SPLADE sparse embedding of ``text``.

    Args:
        text: The single text to embed; the tokenizer is called with
            ``truncation=True``.

    Returns:
        A new ``SparseVector`` on every call, with ``indices`` and ``values``
        as parallel lists (both empty when no logit is positive). Repeated
        texts are served from a bounded LRU cache instead of re-running the
        model, so the encoding log lines are only emitted on a cache miss.
    """
    indices, values = _splade_term_weights(text)
    return SparseVector(indices=list(indices), values=list(values))
# Main endpoint
async def get_embedding(input: TextInput):
    """Embed every text in the request with the encoder named by ``model_name``.

    The model name is compared case-insensitively:

    * ``"BM"`` / ``"BG"`` -> ``{"type": "dense", "embeddings": [...]}``
    * ``"SPLADE"``        -> ``{"type": "sparse", "embeddings": [...]}``
    * anything else       -> ``{"embeddings": [...]}`` produced by a single
      batched call to ``bge_small_model`` (no ``"type"`` key).

    NOTE(review): no route decorator is visible on this function in this
    view -- confirm that it is registered on ``app``.
    """
    logger.info(f"Received request with model: {input.model_name}, texts: {input.text}")
    model_key = input.model_name.upper()

    # Dense encoders: one helper call per text.
    if model_key in ("BM", "BG"):
        embeddings = [encode_dense_cached(model_key, item) for item in input.text]
        logger.info(f"Returning dense embeddings for {len(embeddings)} texts")
        return {"type": "dense", "embeddings": embeddings}

    # Sparse encoder: each SparseVector is dumped to a plain dict.
    if model_key == "SPLADE":
        sparse_vecs = [encode_splade_cached(item).model_dump() for item in input.text]
        logger.info(f"Returning sparse embeddings for {len(sparse_vecs)} texts")
        return {"type": "sparse", "embeddings": sparse_vecs}

    # Unrecognised model name: encode the whole batch with BGE small.
    fallback = bge_small_model.encode(input.text)
    return {"embeddings": fallback.tolist()}
async def status():
    """Log a timestamped heartbeat line and report that the server is up.

    NOTE(review): no route decorator is visible on this function in this
    view -- confirm that it is registered on ``app``.
    """
    payload = {"status": "Server is up and running"}
    logger.info(f"Status API: Server is up and running at {datetime.now()}")
    return payload