import torch


def encode(sentences, tokenizer, model, device="mps"):
    # Tokenize with padding/truncation and move the batch to the target device
    inputs = tokenizer(sentences, padding=True, truncation=True,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs.last_hidden_state has shape [batch, tokens, hidden_dim];
    # mean pooling over the token dimension gives one vector per sentence
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings
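
# A minimal usage sketch, assuming a Hugging Face transformers encoder.
# The checkpoint name below is illustrative, and device="mps" presumes
# Apple Silicon (swap in "cuda" or "cpu" as needed):
from transformers import AutoModel, AutoTokenizer

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to("mps").eval()

embeddings = encode(["first sentence", "second sentence"], tokenizer, model)
print(embeddings.shape)  # torch.Size([2, 384]) for this checkpoint

# Note: .mean(dim=1) averages over every token position, padding included;
# weighting the mean by inputs["attention_mask"] is the usual refinement
# when a batch mixes sentence lengths.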