Upload 47 files
- .gitattributes +3 -0
- README.md +20 -0
- __pycache__/app.cpython-313.pyc +0 -0
- __pycache__/utils.cpython-313.pyc +0 -0
- app.py +136 -0
- config.yaml +25 -0
- knowledge/vectorstore_1/config.json +1 -0
- knowledge/vectorstore_1/docs.pkl +3 -0
- knowledge/vectorstore_1/index.faiss +3 -0
- knowledge/vectorstore_1/index.pkl +3 -0
- rag_pipeline/__init__.py +8 -0
- rag_pipeline/__pycache__/__init__.cpython-313.pyc +0 -0
- rag_pipeline/data_ingest/__pycache__/loader.cpython-313.pyc +0 -0
- rag_pipeline/data_ingest/loader.py +40 -0
- rag_pipeline/data_ingest/parser.py +0 -0
- rag_pipeline/generation/__pycache__/llm_wrapper.cpython-313.pyc +0 -0
- rag_pipeline/generation/__pycache__/prompt_template.cpython-313.pyc +0 -0
- rag_pipeline/generation/llm_wrapper.py +59 -0
- rag_pipeline/generation/prompt_template.py +115 -0
- rag_pipeline/indexing/chunking/__pycache__/markdown.cpython-313.pyc +0 -0
- rag_pipeline/indexing/chunking/__pycache__/recursive.cpython-313.pyc +0 -0
- rag_pipeline/indexing/chunking/markdown.py +54 -0
- rag_pipeline/indexing/chunking/recursive.py +30 -0
- rag_pipeline/indexing/embedding/__pycache__/embedding.cpython-313.pyc +0 -0
- rag_pipeline/indexing/embedding/embedding.py +23 -0
- rag_pipeline/retrieval/__pycache__/reranker.cpython-313.pyc +0 -0
- rag_pipeline/retrieval/__pycache__/vector_retriever.cpython-313.pyc +0 -0
- rag_pipeline/retrieval/graph_retriever.py +4 -0
- rag_pipeline/retrieval/hybrid_retriever.py +0 -0
- rag_pipeline/retrieval/reranker.py +8 -0
- rag_pipeline/retrieval/vector_retriever.py +38 -0
- requirements.txt +0 -0
- test/__pycache__/_normalize_qa.cpython-313.pyc +0 -0
- test/__pycache__/data_ingest.cpython-313.pyc +0 -0
- test/__pycache__/eval_lm.cpython-313.pyc +0 -0
- test/__pycache__/eval_qa.cpython-313.pyc +0 -0
- test/__pycache__/prepare_retrieve.cpython-313.pyc +0 -0
- test/__pycache__/test_llm.cpython-313.pyc +0 -0
- test/__pycache__/test_retrieve.cpython-313.pyc +0 -0
- test/_normalize_qa.py +43 -0
- test/chatbot_inference.py +23 -0
- test/data_ingest.py +78 -0
- test/eval_lm.py +87 -0
- test/eval_qa.py +106 -0
- test/prepare_retrieve.py +50 -0
- test/test_llm.py +9 -0
- test/test_retrieve.py +39 -0
- utils.py +211 -0
.gitattributes
ADDED
@@ -0,0 +1,3 @@
knowledge/vectorstore_1/docs.pkl filter=lfs diff=lfs merge=lfs -text
knowledge/vectorstore_1/index.faiss filter=lfs diff=lfs merge=lfs -text
knowledge/vectorstore_1/index.pkl filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,20 @@
```
python -m notebook.An.master.test.data_ingest \
    --data_paths notebook/An/master/data \
    --vectorstore_dir notebook/An/master/knowledge/vectorstore_1 \
    --embed_model_name alibaba-nlp/gte-multilingual-base \
    --chunk_method recursive \
    --chunk_size 2048 \
    --chunk_overlap 512 \
    --vectorstore faiss
```

```
python -m notebook.An.master.test.test_retrieve \
    --query "Heart definition and heart disease" \
    --vectorstore_dir notebook/An/master/knowledge/vectorstore_1 \
    --embed_model_name alibaba-nlp/gte-multilingual-base \
    --retriever_k 4 \
    --metric cosine \
    --threshold 0.5
```
__pycache__/app.cpython-313.pyc
ADDED
Binary file (6.45 kB)
__pycache__/utils.cpython-313.pyc
ADDED
Binary file (10.9 kB)
app.py
ADDED
@@ -0,0 +1,136 @@
import gradio as gr
from datetime import datetime

# Assuming these are in your project structure
from .rag_pipeline import ChatAssistant, get_embeddings, vretrieve, retrieve_chatbot_prompt, request_retrieve_prompt
from .utils import load_local

# --- Constants and System Prompt ---

# DEVELOPER: Add or remove models here.
# The key is the display name in the dropdown.
# The value is a tuple of (model_id, model_provider).
AVAILABLE_MODELS = {
    "mistral large (mistral)": ("mistral", "mistral"),
    "mistral medium (mistral)": ("mistral-medium", "mistral"),
    "mistral small (mistral)": ("mistral-small", "mistral"),
    "llama3 8B": ("llama3:8b", "ollama"),
    "llama3.1 8B": ("llama3.1:8b", "ollama"),
    "gpt-oss 20B": ("gpt-oss-20b", "ollama"),
    "gemma3 12B": ("gemma3:12b", "ollama"),
    "gpt 4o mini": ("gpt-4o-mini", "openai"),
    "gpt 4o": ("gpt-4o", "openai"),
}
DEFAULT_MODEL_KEY = "mistral medium (mistral)"

EMBEDDING_MODEL_ID = "alibaba-nlp/gte-multilingual-base"
VECTORSTORE_PATH = "notebook/An/master/knowledge/vectorstore_full"
LOG_FILE_PATH = "log.txt"
MAX_HISTORY_CONVERSATION = 50

# System prompt for the medical assistant
sys = """
You are a Medical Assistant specialized in providing information and answering questions related to healthcare and medicine.
You must answer professionally and empathetically, taking into account the user's feelings and concerns.
"""

# --- Initial Setup (runs once) ---
print("Initializing models and data...")
embedding_model = get_embeddings(EMBEDDING_MODEL_ID, show_progress=False)
vectorstore, docs = load_local(VECTORSTORE_PATH, embedding_model)
print("Initialization complete.")


# --- Helper Functions ---
def log(log_txt: str):
    """Appends a log entry to the log file."""
    with open(LOG_FILE_PATH, "a", encoding="utf-8") as log_file:
        log_file.write(log_txt + "\n")


# --- Core Chatbot Logic ---
def chatbot_logic(message: str, history: list, selected_model_key: str):
    """
    Handles the main logic for receiving a message, performing RAG, and generating a response.
    """
    # 1. Look up the model_id and model_provider from the selected key
    model_id, model_provider = AVAILABLE_MODELS[selected_model_key]

    log(f"** Current time **: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    log(f"** User message **: {message}")
    log(f"** Using Model **: {model_id} ({model_provider})")

    # Initialize the assistant with the specified model for this request
    try:
        chat_assistant = ChatAssistant(model_id, model_provider)
    except Exception as e:
        yield f"Error: Could not initialize the model. Please check the ID and provider. Details: {e}"
        return

    # --- RAG Pipeline ---
    # 2. Format conversation history for context
    history = history[-MAX_HISTORY_CONVERSATION:]
    conversation = "".join(f"User: {user_msg}\nBot: {bot_msg}\n" for user_msg, bot_msg in history)
    query_for_rag = conversation + f"User: {message}\nBot:"

    # 3. Generate a search query from the conversation
    rag_query = chat_assistant.get_response(request_retrieve_prompt.format(role="user", conversation=query_for_rag))
    rag_query = rag_query[rag_query.lower().rfind("[") + 1: rag_query.rfind("]")]

    # 4. Retrieve relevant documents if necessary
    if "NO" not in rag_query:
        retrieve_results = vretrieve(rag_query, vectorstore, docs, k=4, metric="mmr", threshold=0.7)
    else:
        retrieve_results = []

    retrieved_docs = "\n".join([f"Document {i+1}:\n" + doc.page_content for i, doc in enumerate(retrieve_results)])
    log(f"** RAG query **: {rag_query}")
    log(f"** Retrieved documents **:\n{retrieved_docs}")

    # --- Final Response Generation ---
    # 5. Create the final prompt with retrieved context
    final_prompt = retrieve_chatbot_prompt.format(role="user", documents=retrieved_docs, conversation=query_for_rag)

    # 6. Stream the response from the LLM
    response = ""
    for token in chat_assistant.get_streaming_response(final_prompt, sys):
        response += token
        yield response

    log(f"** Bot response **: {response}")
    log("=" * 50 + "\n\n")


# --- Gradio UI ---
with gr.Blocks(theme="soft") as chatbot_ui:
    gr.Markdown("# MedLLM")

    model_selector = gr.Dropdown(
        label="Select Model",
        choices=list(AVAILABLE_MODELS.keys()),
        value=DEFAULT_MODEL_KEY,
    )

    chatbot = gr.Chatbot(label="Chat Window", height=500, bubble_full_width=False)
    msg_input = gr.Textbox(label="Your Message", placeholder="Type your question here and press Enter...", scale=7)

    def respond(message, chat_history, selected_model_key):
        """Wrapper function to connect chatbot_logic with Gradio's state."""
        bot_message_stream = chatbot_logic(message, chat_history, selected_model_key)
        chat_history.append([message, ""])
        for token in bot_message_stream:
            chat_history[-1][1] = token
            yield chat_history

    msg_input.submit(
        respond,
        [msg_input, chatbot, model_selector],
        [chatbot]
    ).then(
        lambda: gr.update(value=""), None, [msg_input], queue=False
    )


# --- Launch the App ---
if __name__ == "__main__":
    chatbot_ui.launch(debug=True, share=False)
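One operational note on app.py: it never checks provider credentials itself; llm_wrapper.py resolves MISTRAL_API_KEY and OPENAI_API_KEY lazily with os.getenv, so a missing key only surfaces when the first request fails. A minimal pre-flight sketch, not part of this commit, using only the variable names that llm_wrapper.py actually reads:

```python
import os

# Hypothetical pre-flight check before launching app.py. The LLM wrapper reads
# these variables with os.getenv, so an unset key otherwise fails only at the
# first chat request.
for var in ("MISTRAL_API_KEY", "OPENAI_API_KEY"):
    if not os.getenv(var):
        print(f"Warning: {var} is not set; that provider's models will fail at request time.")
```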
config.yaml
ADDED
@@ -0,0 +1,25 @@
version: 0.1

model:
  name: "llama2:7b"
  temperature: 0.3
  max_tokens: 100000
  provider: "ollama"
  base_url: "http://localhost:11434/v1"

rag_config:
  k: 4
  rerank:
    name: "bge-reranker-large"
    model: "BAAI/bge-reranker-large"
    top_n: 100
  embed_model:
    name: "gte-multilingual-base"
    model: "alibaba-nlp/gte-multilingual-base"
  chunk_size: 2048
  chunk_overlap: 512
  similarity_threshold: 0.7
  similarity_metric: "cosine"

knowledge:
  vectorstore: "faiss"
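None of the modules in this commit load config.yaml; the same values are hard-coded as constants in app.py and as argparse defaults in the test scripts. A minimal sketch of how the file could be consumed, assuming PyYAML is installed (requirements.txt is only shown as a binary here) and using the nesting reconstructed above:

```python
import yaml

# Hypothetical consumer of config.yaml; the committed modules duplicate these
# values as constants instead of reading the file.
with open("config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

rag_cfg = cfg["rag_config"]
print(cfg["model"]["name"])              # "llama2:7b"
print(rag_cfg["embed_model"]["model"])   # "alibaba-nlp/gte-multilingual-base"
print(rag_cfg["k"], rag_cfg["similarity_metric"])
```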
knowledge/vectorstore_1/config.json
ADDED
@@ -0,0 +1 @@
{"data_paths": ["dataset/RAG_Data/wiki_vi", "dataset/RAG_Data/youmed"], "vectorstore_dir": "notebook/An/master/knowledge/vectorstore_1", "file_type": "txt", "embed_model_name": "alibaba-nlp/gte-multilingual-base", "chunk_size": 2048, "chunk_overlap": 512, "chunk_method": "markdown", "vectorstore": "faiss", "clear_vectorstore": true}
knowledge/vectorstore_1/docs.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e969c1a0beb575363fc3cd0e252b9751f9ad79fc605ec6ab4a2c4ee68845e43
size 7568017
knowledge/vectorstore_1/index.faiss
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10ffba3c9fc6846d51de37463833eecf8b42b036a78e93e90ff779fbd47268f6
size 9440301
knowledge/vectorstore_1/index.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9d81a73aa18b621f660a69e7ce3bba1b8b1875e983752a1e504f1f2922a7fdc
size 7730542
rag_pipeline/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .generation.llm_wrapper import ChatAssistant
from .indexing.chunking.recursive import split_document as recursive_chunking
from .indexing.chunking.markdown import split_document as markdown_chunking
from .indexing.embedding.embedding import get_embeddings
from .data_ingest.loader import load_data
from .generation.prompt_template import *
from .retrieval.vector_retriever import retrieve as vretrieve
from .retrieval.reranker import rerank
rag_pipeline/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (704 Bytes)
rag_pipeline/data_ingest/__pycache__/loader.cpython-313.pyc
ADDED
Binary file (2.09 kB)
rag_pipeline/data_ingest/loader.py
ADDED
@@ -0,0 +1,40 @@
import os
from typing import List
from langchain.schema import Document

def load_data(data_path: str, file_type: str) -> List[Document]:
    """
    Load knowledge data from a specified path and file type.
    Args:
        data_path: The path to the data.
        file_type: The type of the data.
    Returns:
        A list of documents.
    """
    if file_type == "pdf":
        raise NotImplementedError("PDF loading is not yet implemented.")
    elif file_type == "txt":
        return _load_txt(data_path)

def _load_txt(data_path: str) -> List[Document]:
    splits = []

    if not os.path.isdir(data_path):
        raise FileNotFoundError(f"Error: Directory not found at {data_path}")

    for file_name in os.listdir(data_path):
        if file_name.endswith('.txt'):
            file_path = os.path.join(data_path, file_name)

            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                metadata = {"source": file_name}
                doc = Document(page_content=content, metadata=metadata)

                splits.append(doc)

            except Exception as e:
                print(f"Error reading file {file_path}: {e}")

    return splits
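A short usage sketch for load_data: the directory below is one of the defaults from test/data_ingest.py, and only .txt files directly in that directory are picked up; other extensions and subdirectories are skipped silently.

```python
from rag_pipeline.data_ingest.loader import load_data

# Load every .txt file in the directory into LangChain Documents; the path is
# one of the defaults used by test/data_ingest.py.
docs = load_data("dataset/RAG_Data/wiki_vi", "txt")
print(len(docs))
print(docs[0].metadata if docs else "no documents")  # metadata holds {"source": file_name}
```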
rag_pipeline/data_ingest/parser.py
ADDED
File without changes
rag_pipeline/generation/__pycache__/llm_wrapper.cpython-313.pyc
ADDED
Binary file (2.9 kB)
rag_pipeline/generation/__pycache__/prompt_template.cpython-313.pyc
ADDED
Binary file (3.99 kB)
rag_pipeline/generation/llm_wrapper.py
ADDED
@@ -0,0 +1,59 @@
from openai import OpenAI
import backoff

import os

_base_url_ = {
    "ollama": "http://localhost:11434/v1",
    "mistral": "https://api.mistral.ai/v1",
    "openai": "https://api.openai.com/v1",
}

_api_key_ = {
    "ollama": "ollama",
    "mistral": os.getenv("MISTRAL_API_KEY"),
    "openai": os.getenv("OPENAI_API_KEY"),
}

class ChatAssistant:
    def __init__(self, model_name: str, provider: str = "ollama"):
        """
        Args:
            model_name: The name of the model to use.
            provider: The provider of the model. Can be "ollama", "mistral", or "openai".
        """
        self.model_name = model_name
        self.client = OpenAI(
            base_url=_base_url_[provider],
            api_key=_api_key_[provider],
        )

    @backoff.on_exception(backoff.expo, Exception)
    def get_response(self, user: str, sys: str = ""):
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": sys},
                {"role": "user", "content": user},
            ]
        )
        return response.choices[0].message.content

    @backoff.on_exception(backoff.expo, Exception)
    def get_streaming_response(self, user: str, sys: str = ""):
        """Yields the response token by token (streaming)."""
        response_stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": sys},
                {"role": "user", "content": user},
            ],
            stream=True
        )

        # Iterate over the stream of chunks
        for chunk in response_stream:
            # The actual token is in chunk.choices[0].delta.content
            token = chunk.choices[0].delta.content
            if token is not None:
                yield token
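A usage sketch for ChatAssistant against a local Ollama server, the default provider. The model tag below is an assumption; any model already pulled into Ollama at http://localhost:11434 works, and the mistral/openai providers additionally need their API keys in the environment.

```python
from rag_pipeline.generation.llm_wrapper import ChatAssistant

# Assumes an Ollama server on http://localhost:11434 with this model pulled.
assistant = ChatAssistant("llama3.1:8b", provider="ollama")

# Blocking call:
print(assistant.get_response("What is hypertension?", sys="You are a concise medical assistant."))

# Streaming call, printing tokens as they arrive:
for token in assistant.get_streaming_response("What is hypertension?"):
    print(token, end="", flush=True)
```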
rag_pipeline/generation/prompt_template.py
ADDED
@@ -0,0 +1,115 @@
multichoice_qa_prompt = """
-- DOCUMENT --
{document}
-- END OF DOCUMENT --

-- INSTRUCTION --
You are a medical expert.
Given the documents, you must answer the question following these steps.
First, you must read the question and the options, and draft an answer for it based on your knowledge.
Second, you must read the documents and check if they can help answer the question.
Third, you cross check the document with your knowledge and the draft answer.
Finally, you answer the question based on your knowledge and the true documents.
Your response must end with the letter of the most correct option like: "the answer is A".
The entire thought must be under 500 words long.
-- END OF INSTRUCTION --

-- QUESTION --
{question}
{options}
-- END OF QUESTION --
"""

qa_prompt = """
-- DOCUMENT --
{document}
-- END OF DOCUMENT --

-- INSTRUCTION --
You are a medical expert.
Given the documents, you must answer the question following these steps.
First, you must read the question and draft an answer for it based on your knowledge.
Second, you must read the documents and check if they can help answer the question.
Third, you cross check the document with your knowledge and the draft answer.
Finally, you answer the question based on your knowledge and the true documents concisely.
Your response must be as short as possible, in Vietnamese and between brackets like: "[...]".
-- END OF INSTRUCTION --

-- QUESTION --
{question}
-- END OF QUESTION --
"""

retrieve_chatbot_prompt = """
You are a medical expert.
You are having a conversation with a {role} and you have external documents to help you.
Continue the conversation based on the chat history, the context information, and not prior knowledge.
Before using a retrieved chunk, you must check if it is relevant to the user query. If it is not relevant, you must ignore it.
You use the relevant chunk to answer the question and cite the source inside <<<>>>.
If you don't know the answer, you must say "I don't know".
---------------------
{documents}
---------------------
Given the documents and not prior knowledge, continue the conversation.
---------------------
{conversation}
---------------------
"""

request_retrieve_prompt = """
--- INSTRUCTION ---
You are having a conversation with a {role}.
You have to provide a short query to retrieve the documents that you need inside the brackets like: "[...]".
If it is something not related to the medical field, or something you do not need external knowledge to answer, you must write "[NO]".
--- END OF INSTRUCTION ---

--- CONVERSATION ---
{conversation}
--- END OF CONVERSATION ---
"""

answer_prompt = """
-- INSTRUCTION --
You are a medical expert.
Given the documents below, you must answer the question step by step.
First, you must read the question.
Second, you must read the documents and check their reliability.
Third, you cross check with your knowledge.
Finally, you answer the question based on your knowledge and the true documents.

Your answer must be UNDER 50 words, written on 1 line and in Vietnamese.
-- END OF INSTRUCTION --

-- QUESTION --
{question}
-- END OF QUESTION --

-- DOCUMENT --
{document}
-- END OF DOCUMENT --

"""

translate_prompt = """
[ INSTRUCTION ]
You are a Medical translator expert.
Your task is to translate this English question into Vietnamese with EXACTLY the same format and write in 1 line.
[ END OF INSTRUCTION ]

[ QUERY TO TRANSLATE ]
{query}
[ END OF QUERY TO TRANSLATE ]
"""

pdf2txt_prompt = """
Rewrite this plain text from a PDF file following the right reading order and these instructions:
- Use markdown format.
- Use the same language.
- Keep the content intact.
- Beautify the table.
- No talk.

[ QUERY ]
{query}
[ END OF QUERY ]
"""
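app.py extracts the retrieval query by slicing between the last "[" and "]" of the model's reply to request_retrieve_prompt. A small sketch of that round trip; the sample reply string is invented, in practice it comes from ChatAssistant.get_response(prompt):

```python
from rag_pipeline.generation.prompt_template import request_retrieve_prompt

conversation = "User: What are the side effects of beta blockers?\nBot:"
prompt = request_retrieve_prompt.format(role="user", conversation=conversation)

# Hypothetical model reply for illustration only.
reply = "Sure, I would retrieve: [beta blocker side effects]"

# The same bracket slicing that app.py applies to the reply.
query = reply[reply.lower().rfind("[") + 1 : reply.rfind("]")]
print(query)  # "beta blocker side effects"; a reply of "[NO]" means skip retrieval.
```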
rag_pipeline/indexing/chunking/__pycache__/markdown.cpython-313.pyc
ADDED
Binary file (2.55 kB)
rag_pipeline/indexing/chunking/__pycache__/recursive.cpython-313.pyc
ADDED
Binary file (1.52 kB)
rag_pipeline/indexing/chunking/markdown.py
ADDED
@@ -0,0 +1,54 @@
from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from langchain.schema import Document
from typing import List

def __split_1_document__(document: Document, chunk_size: int, chunk_overlap: int) -> List[Document]:
    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
        ("###", "Header 3"),
    ]

    markdown_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=headers_to_split_on,
        strip_headers=False,
        return_each_line=False
    )

    md_header_splits = markdown_splitter.split_text(document.page_content)

    for doc in md_header_splits:
        doc.metadata.update(document.metadata)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    final_splits = text_splitter.split_documents(md_header_splits)

    # Iterate through the final chunks to prepend metadata to the page_content
    for i, doc in enumerate(final_splits):
        header_lines = []
        source_line = f"-- source: {doc.metadata.get('source', 'N/A')}"

        if 'Header 1' in doc.metadata:
            header_lines.append(doc.metadata['Header 1'])
        if 'Header 2' in doc.metadata:
            header_lines.append(doc.metadata['Header 2'])
        if 'Header 3' in doc.metadata:
            header_lines.append(doc.metadata['Header 3'])

        header_content = "\n".join(header_lines)
        chunk_header = f"Chunk {i+1}:"

        # Combine everything into the new page content
        original_content = doc.page_content
        doc.page_content = f"{source_line}\n{header_content}\n{chunk_header}\n{original_content}"

    return final_splits

def split_document(documents: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
    split_documents = []
    for doc in documents:
        split_documents.extend(__split_1_document__(doc, chunk_size, chunk_overlap))
    return split_documents
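A usage sketch for the markdown chunker: headers up to "###" become metadata, and each final chunk's page_content is prefixed with the source file name, the matched headers, and a chunk number. The sample document below is invented for illustration; real documents come from loader.load_data.

```python
from langchain.schema import Document
from rag_pipeline.indexing.chunking.markdown import split_document

# Invented sample document for illustration.
doc = Document(
    page_content="# Tim mạch\n## Tăng huyết áp\nHuyết áp cao kéo dài làm tổn thương thành mạch...",
    metadata={"source": "wiki_vi_001.txt"},
)

chunks = split_document([doc], chunk_size=2048, chunk_overlap=512)
print(len(chunks))
print(chunks[0].page_content[:120])  # starts with "-- source: wiki_vi_001.txt" plus the header lines
```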
rag_pipeline/indexing/chunking/recursive.py
ADDED
@@ -0,0 +1,30 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from typing import List

def __split_1_document__(document: Document, chunk_size: int, chunk_overlap: int) -> List[Document]:
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    text_content = document.page_content
    text_chunks = text_splitter.split_text(text_content)
    split_documents = []

    for i, chunk in enumerate(text_chunks):
        new_metadata = document.metadata.copy()

        # new_metadata['chunk_number'] = i + 1

        new_doc = Document(page_content=chunk, metadata=new_metadata)
        split_documents.append(new_doc)

    return split_documents


def split_document(documents: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
    split_documents = []
    for doc in documents:
        split_documents.extend(__split_1_document__(doc, chunk_size, chunk_overlap))
    return split_documents
rag_pipeline/indexing/embedding/__pycache__/embedding.cpython-313.pyc
ADDED
Binary file (1.06 kB)
rag_pipeline/indexing/embedding/embedding.py
ADDED
@@ -0,0 +1,23 @@
from langchain_huggingface import HuggingFaceEmbeddings

import torch

_model_cache = {}

def get_embeddings(model_name: str, show_progress: bool = True) -> HuggingFaceEmbeddings:
    """
    Get the embeddings model. Cache available.
    Args:
        model_name: The name of the model.
    Returns:
        The embeddings model.
    """
    if model_name not in _model_cache:
        embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            show_progress=show_progress,
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'trust_remote_code': True},
            encode_kwargs={'batch_size': 15}
        )
        _model_cache[model_name] = embeddings
    return _model_cache[model_name]
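get_embeddings memoizes one HuggingFaceEmbeddings instance per model name, so repeated calls from app.py and the test scripts reuse the same loaded weights. A short sketch; the printed dimension depends on the chosen model, gte-multilingual-base being the repo default:

```python
from rag_pipeline.indexing.embedding.embedding import get_embeddings

emb = get_embeddings("alibaba-nlp/gte-multilingual-base", show_progress=False)
emb_again = get_embeddings("alibaba-nlp/gte-multilingual-base")
print(emb is emb_again)  # True: the module-level cache returns the same object

vector = emb.embed_query("Định nghĩa tăng huyết áp")
print(len(vector))       # embedding dimension of the chosen model
```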
rag_pipeline/retrieval/__pycache__/reranker.cpython-313.pyc
ADDED
Binary file (504 Bytes)
rag_pipeline/retrieval/__pycache__/vector_retriever.cpython-313.pyc
ADDED
Binary file (2.25 kB)
rag_pipeline/retrieval/graph_retriever.py
ADDED
@@ -0,0 +1,4 @@
from typing import List, Any

def retrieve(query: str, graphstore: Any = None) -> List[str]:
    pass
rag_pipeline/retrieval/hybrid_retriever.py
ADDED
File without changes
rag_pipeline/retrieval/reranker.py
ADDED
@@ -0,0 +1,8 @@
import os
import pickle
from typing import List

from langchain.schema import Document

def rerank(docs: List[Document]) -> List[Document]:
    return docs
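rerank() is currently a pass-through stub. config.yaml points at BAAI/bge-reranker-large, so one plausible implementation, not part of this commit, is a cross-encoder from sentence-transformers; note it needs the query in addition to the documents, which the current rerank(docs) signature does not take.

```python
# Hypothetical cross-encoder reranker sketch; not part of this commit. Assumes
# sentence-transformers is installed and adds a `query` argument that the
# committed rerank(docs) signature does not have.
from typing import List
from langchain.schema import Document
from sentence_transformers import CrossEncoder

_reranker = CrossEncoder("BAAI/bge-reranker-large")  # model named in config.yaml

def rerank_with_query(query: str, docs: List[Document], top_n: int = 4) -> List[Document]:
    if not docs:
        return docs
    # Score each (query, passage) pair and keep the top_n highest-scoring docs.
    scores = _reranker.predict([(query, d.page_content) for d in docs])
    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
    return [d for d, _ in ranked[:top_n]]
```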
rag_pipeline/retrieval/vector_retriever.py
ADDED
@@ -0,0 +1,38 @@
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings

from .reranker import rerank

from typing import List, Any

def retrieve(query: str, vectorstore: FAISS, docs: List[Document] = None, k: int = 4, metric: str = "cosine", threshold: float = 0.5, reranker: Any = None) -> List[Document]:
    """
    Retrieve documents from the vectorstore based on the query and metric.
    Args:
        query: The query to search for.
        metric: The metric to use for retrieval.
        vectorstore: The vectorstore to search in.
        k: The number of documents to retrieve.
        threshold: The threshold for the metric to use for retrieval.
        reranker: The reranker to use for reranking the retrieved documents.
    Returns:
        A list of documents.
    """
    if metric == "cosine":
        docs = vectorstore.similarity_search_with_score(query, k=k)
        docs = [doc for doc, score in docs if score > threshold]
    elif metric == "mmr":
        docs = vectorstore.max_marginal_relevance_search(query, k=k)
    elif metric == "bm25":
        from langchain_community.retrievers import BM25Retriever
        if docs is None:
            raise ValueError("Documents not available. BM25 requires ingested or loaded documents.")
        bm25_retriever = BM25Retriever.from_documents(docs)
        docs = bm25_retriever.get_relevant_documents(query, k=k)
    else:
        raise ValueError(f"Unsupported metric: '{metric}'. Supported metrics are 'cosine', 'mmr', and 'bm25'.")

    if reranker is not None:
        return rerank(docs)
    return docs
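A usage sketch for retrieve(), loading the store the same way app.py does. One caveat worth flagging: on the "cosine" branch, similarity_search_with_score returns whatever score the FAISS index was built with (for a default index that is a distance where lower is better), so the `score > threshold` filter only behaves as a similarity cutoff if the store was created with a similarity scoring strategy. The import of load_local assumes the script runs from the Space root, where utils.py lives.

```python
from rag_pipeline.indexing.embedding.embedding import get_embeddings
from rag_pipeline.retrieval.vector_retriever import retrieve
from utils import load_local  # assumes the working directory is the Space root

emb = get_embeddings("alibaba-nlp/gte-multilingual-base", show_progress=False)
vectorstore, docs = load_local("notebook/An/master/knowledge/vectorstore_1", emb)

# MMR, as app.py uses it (threshold is ignored on this branch):
hits = retrieve("Heart definition and heart disease", vectorstore, docs, k=4, metric="mmr")

# BM25 needs the raw docs list loaded from docs.pkl:
lexical_hits = retrieve("heart disease", vectorstore, docs, k=4, metric="bm25")
print(len(hits), len(lexical_hits))
```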
requirements.txt
ADDED
Binary file (11.9 kB)
test/__pycache__/_normalize_qa.cpython-313.pyc
ADDED
Binary file (2.2 kB)
test/__pycache__/data_ingest.cpython-313.pyc
ADDED
Binary file (3.98 kB)
test/__pycache__/eval_lm.cpython-313.pyc
ADDED
Binary file (5.68 kB)
test/__pycache__/eval_qa.cpython-313.pyc
ADDED
Binary file (6.52 kB)
test/__pycache__/prepare_retrieve.cpython-313.pyc
ADDED
Binary file (3.71 kB)
test/__pycache__/test_llm.cpython-313.pyc
ADDED
Binary file (603 Bytes)
test/__pycache__/test_retrieve.cpython-313.pyc
ADDED
Binary file (2.32 kB)
test/_normalize_qa.py
ADDED
@@ -0,0 +1,43 @@
# import json
# import uuid

# origin_qa_data_path = 'dataset/QA Data/MedMCQA/hard_questions.jsonl'
# target_qa_data_path = 'dataset/QA Data/MedMCQA/translated_hard_questions.jsonl'

# def transform_id(origin_id):
#     # Add 'T' prefix and remove last character
#     return ' T' + origin_id[:-1]

# def update_answers():
#     # Read origin data
#     with open(origin_qa_data_path, 'r', encoding='utf-8') as f:
#         origin_data = [json.loads(line) for line in f]

#     # Read target data
#     with open(target_qa_data_path, 'r', encoding='utf-8') as f:
#         target_data = [json.loads(line) for line in f]

#     c = []
#     for item in origin_data:
#         for target_item in target_data:
#             if transform_id(item['id']) == target_item['uuid']:
#                 if item['cop'] == 0:
#                     target_item['answer'] = 'A'
#                 elif item['cop'] == 1:
#                     target_item['answer'] = 'B'
#                 elif item['cop'] == 2:
#                     target_item['answer'] = 'C'
#                 elif item['cop'] == 3:
#                     target_item['answer'] = 'D'
#                 c.extend([target_item['uuid']])
#     # print(c)
#     for item in target_data:
#         if item['uuid'] not in c:
#             print(item['uuid'])
#     # Write updated target data back to file
#     with open(target_qa_data_path, 'w', encoding='utf-8') as f:
#         for item in target_data:
#             f.write(json.dumps(item, ensure_ascii=False) + '\n')

# # Call the function to update answers
# update_answers()
test/chatbot_inference.py
ADDED
@@ -0,0 +1,23 @@
from rag_pipeline import get_embeddings, vretrieve, rerank
from utils import load_local

import argparse

def inference():
    embed_model = get_embeddings(args.embed_model_name)
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)
    retrieve_results = vretrieve(args.query, vectorstore, docs, args.retriever_k, args.metric, args.threshold)

    retrieve_results = rerank(retrieve_results)

    print(retrieve_results)

def conversation():
    while True:
        query = input("User: ")
        if query == "exit":
            break
        inference(query)

if __name__ == '__main__':
    conversation()
test/data_ingest.py
ADDED
@@ -0,0 +1,78 @@
import argparse
import os
from typing import List

from ..rag_pipeline import get_embeddings, load_data
from ..utils import load_local, save_local

def main(args):
    print(f"Log: {args}")

    if args.clear_vectorstore:
        import shutil
        if os.path.isdir(args.vectorstore_dir):
            shutil.rmtree(args.vectorstore_dir)

    embed_model = get_embeddings(args.embed_model_name)
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)

    new_docs = []
    for data_path in args.data_paths:
        new_docs.extend(load_data(data_path, args.file_type))
    print(f"Got {len(new_docs)} documents.")

    if args.chunk_method == "recursive":
        from ..rag_pipeline import recursive_chunking
        new_docs = recursive_chunking(new_docs, args.chunk_size, args.chunk_overlap)
    elif args.chunk_method == "markdown":
        from ..rag_pipeline import markdown_chunking
        new_docs = markdown_chunking(new_docs, args.chunk_size, args.chunk_overlap)
    print(f"Got {len(new_docs)} chunks.")

    from langchain_community.vectorstores import FAISS
    if vectorstore is None:
        vectorstore = FAISS.from_documents(new_docs, embed_model)
        docs = new_docs
        print(f"Successfully consumed {len(new_docs)} documents.")
    else:
        docs.extend(new_docs)
        vectorstore.add_documents(new_docs)

    save_local(args.vectorstore_dir, vectorstore, docs)

    import json
    with open(os.path.join(args.vectorstore_dir, "config.json"), "a") as f:
        json.dump(vars(args), f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    data_paths = [
        'dataset/RAG_Data/wiki_vi',
        'dataset/RAG_Data/youmed',
        'dataset/RAG_Data/mimic_ex_report',
        'dataset/RAG_Data/Download sach y/OCR',
    ]

    # Dataset params
    parser.add_argument("--data_paths", type=List[str], required=False, default=data_paths)
    parser.add_argument("--vectorstore_dir", type=str, required=False, default="notebook/An/master/knowledge/vectorstore_full")
    parser.add_argument("--file_type", type=str, choices=["pdf", "txt"], default="txt")

    # Model params
    parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")

    # Index params
    parser.add_argument("--chunk_size", type=int, default=2048)
    parser.add_argument("--chunk_overlap", type=int, default=512)
    parser.add_argument("--chunk_method", type=str, choices=["recursive", "markdown"], default="markdown")

    # Vectorstore params
    parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
    parser.add_argument("--clear_vectorstore", action="store_true", default=True)


    args = parser.parse_args()

    main(args)
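The README invokes this script as a module; it can also be driven programmatically by passing an argparse.Namespace with the same field names, which sidesteps the questionable `type=List[str]` on --data_paths (a typing alias is not usable as an argparse type callable, so that flag realistically only works through its default). A minimal sketch; the import path is assumed from the README's notebook.An.master layout:

```python
import argparse
from notebook.An.master.test.data_ingest import main  # module path assumed from the README

# Programmatic invocation with the same fields the script's parser defines.
args = argparse.Namespace(
    data_paths=["dataset/RAG_Data/wiki_vi", "dataset/RAG_Data/youmed"],
    vectorstore_dir="notebook/An/master/knowledge/vectorstore_1",
    file_type="txt",
    embed_model_name="alibaba-nlp/gte-multilingual-base",
    chunk_size=2048,
    chunk_overlap=512,
    chunk_method="markdown",
    vectorstore="faiss",
    clear_vectorstore=True,
)
main(args)
```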
test/eval_lm.py
ADDED
@@ -0,0 +1,87 @@
import argparse
from ..rag_pipeline import qa_prompt
from ..rag_pipeline import ChatAssistant
from ..utils import load_qa_dataset, load_prepared_retrieve_docs

from typing import List, Optional
from langchain.schema import Document

def get_answer_from_response(llm_response: str) -> str:
    return llm_response.strip()

def build_qa_prompt(question: str, document: Optional[List[Document]]) -> str:
    if document is not None:
        document = '\n'.join([f"Document {i+1}:\n" + doc.page_content for i, doc in enumerate(document)])

    return qa_prompt.format(question=question, document=document)

def process_question(question, prompt, answer, id, args, llm):
    llm_response = llm.get_response("", prompt)
    # ans = get_answer_from_response(llm_response)
    with open("log.txt", "a", encoding="utf-8") as f:
        f.write(f"ID: {id}\n")
        f.write(prompt)
        f.write(f"LLM Response:\n{llm_response}\n")
        f.write(f"Answer: {answer} \n\n")

    # with open("log_score.txt", "a", encoding="utf-8") as f:
    #     f.write("1" if ans == answer else "0")
    # return 1 if ans == answer else 0
    return llm_response

def evaluate_qa(questions, prompts, answers, ids, args, llm):
    import concurrent.futures
    from tqdm import tqdm
    ans = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = [executor.submit(process_question, questions[i], prompts[i], answers[i], ids[i], args, llm) for i in range(len(questions))]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(questions)):
            ans.append(future.result())
    return ans

def main(args):
    ids, questions, options, answers = load_qa_dataset(args.qa_file)

    if ids is None:
        raise ValueError(f"No id field in {args.qa_file}.")

    if args.num_docs > 0:
        if args.prepared_retrieve_docs_path is not None:
            documents = load_prepared_retrieve_docs(args.prepared_retrieve_docs_path)
            docs = [d[:args.num_docs] for i, d in enumerate(documents)]
        else:
            raise ValueError(f"No prepared retrieve docs found.")
    else:
        docs = [None]*len(questions)

    prompts = [build_qa_prompt(questions[i], docs[i]) for i in range(len(questions))]

    llm = ChatAssistant(args.model_name, args.provider)

    with open("log_score.txt", "a", encoding="utf-8") as f:
        f.write("\n")

    qa_results = evaluate_qa(questions, prompts, answers, ids, args, llm)
    qa_results = [qa_results[i][qa_results[i].rfind("[")+1:qa_results[i].rfind("]")] for i in range(len(qa_results))]
    # print(f"{qa_results}")
    import pyperclip
    pyperclip.copy('\n'.join(qa_results))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--qa_file", type=str, default="dataset/QA Data/random.jsonl")
    parser.add_argument("--prepared_retrieve_docs_path", type=str, default="prepared_retrieve_docs.pkl")

    parser.add_argument("--model_name", type=str, default="mistral-medium")
    parser.add_argument("--provider", type=str, default="mistral")
    parser.add_argument("--max_workers", type=int, default=4)
    parser.add_argument("--num_docs", type=int, default=0)

    parser.add_argument("--dataset_path", type=str)

    args = parser.parse_args()

    print(args)

    main(args)
test/eval_qa.py
ADDED
@@ -0,0 +1,106 @@
import argparse
from ..rag_pipeline import multichoice_qa_prompt
from ..rag_pipeline import ChatAssistant
from ..utils import paralelize, load_qa_dataset, load_prepared_retrieve_docs

from datetime import datetime
from typing import List, Optional
from langchain.schema import Document

def get_answer_from_response(llm_response: str) -> chr:
    """
    Get the answer from the LLM response.
    """
    return llm_response[llm_response.lower().rfind("the answer is ") + 14]

def build_multichoice_qa_prompt(question: str, options: str, document: Optional[List[Document]]) -> str:
    """
    Build the prompt for the multichoice QA task.
    """
    if document is not None:
        document = '\n'.join([f"Document {i+1}:\n" + doc.page_content for i, doc in enumerate(document)])

    return multichoice_qa_prompt.format(question=question, options=options, document=document)

def process_question(question, prompt, answer, id, args, llm):
    llm_response = ""
    for j in range(args.retries):
        try:
            llm_response = llm.get_response("", prompt)
            ans = get_answer_from_response(llm_response)
            if ans in ["A", "B", "C", "D", "E"]:
                with open("log.txt", "a", encoding="utf-8") as f:
                    f.write(f"ID: {id}\n")
                    f.write(prompt)
                    f.write(f"LLM Response:\n{llm_response}\n")
                    f.write(f"Answer: {answer} {ans}\n\n")
                break
        except Exception as e:
            print(f"Error: {e}")
            ans = "#"
    with open("log_score.txt", "a", encoding="utf-8") as f:
        f.write("1" if ans == answer else "0")
    return 1 if ans == answer else 0

def evaluate_qa(questions, prompts, answers, ids, args, llm):
    import concurrent.futures
    from tqdm import tqdm
    correct = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = [executor.submit(process_question, questions[i], prompts[i], answers[i], ids[i], args, llm) for i in range(len(questions))]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(questions)):
            correct += future.result()
    return correct / len(questions)


def main(args):
    ids, questions, options, answers = load_qa_dataset(args.qa_file)

    if ids is None:
        raise ValueError(f"No id field in {args.qa_file}.")

    if args.num_docs > 0:
        if args.prepared_retrieve_docs_path is not None:
            documents = load_prepared_retrieve_docs(args.prepared_retrieve_docs_path)
            docs = [d[:args.num_docs] for i, d in enumerate(documents)]
        else:
            raise ValueError(f"No prepared retrieve docs found.")
    else:
        docs = [None]*len(questions)

    prompts = [build_multichoice_qa_prompt(questions[i], options[i], docs[i]) for i in range(len(questions))]

    # print(prompts[0])
    llm = ChatAssistant(args.model_name, args.provider)

    with open("log_score.txt", "a", encoding="utf-8") as f:
        f.write(f"\n{datetime.now()} {args}\n")

    acc = evaluate_qa(questions, prompts, answers, ids, args, llm)
    print(f"Accuracy: {acc}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # parser.add_argument("--qa_file", type=str, default="dataset/QA Data/MedAB/MedABv2.jsonl")
    # parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedAB/prepared_retrieve_docs_full.pkl")

    parser.add_argument("--qa_file", type=str, default="dataset/QA Data/MedMCQA/translated_hard_questions.jsonl")
    parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedMCQA/prepared_retrieve_docs_full.pkl")

    # Eval params
    parser.add_argument("--model_name", type=str, default="mistral-medium")
    parser.add_argument("--provider", type=str, default="mistral")
    parser.add_argument("--max_workers", type=int, default=4)
    parser.add_argument("--num_docs", type=int, default=0)
    parser.add_argument("--retries", type=int, default=4)


    # Dataset params
    parser.add_argument("--dataset_path", type=str)

    args = parser.parse_args()
    print(f"Log:{args}")

    main(args)
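get_answer_from_response relies on the multichoice prompt's instruction that replies end with "the answer is X": it takes the character 14 positions after the last occurrence of that phrase. A tiny illustration; the reply text is invented and the import path is assumed from the notebook.An.master layout:

```python
from notebook.An.master.test.eval_qa import get_answer_from_response  # module path assumed

# "the answer is " is 14 characters long, so the character right after it is returned.
print(get_answer_from_response("Based on the documents, the answer is B."))  # "B"

# If the phrase is missing, rfind returns -1 and character index 13 is grabbed instead;
# process_question rejects anything outside A-E and retries up to --retries times.
```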
test/prepare_retrieve.py
ADDED
@@ -0,0 +1,50 @@
import argparse
import os

from ..rag_pipeline import get_embeddings, vretrieve
from ..utils import load_local, load_qa_dataset, safe_save_langchain_docs

def main(args):
    embed_model = get_embeddings(args.embed_model_name, show_progress=False)
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)

    ids, questions, options, answers = load_qa_dataset(args.qa_data_path)

    rag_queries = [f"Question: {questions[i]}\n{options[i]}" for i in range(len(questions))]
    if (args.rag_queries_path is not None) and os.path.exists(args.rag_queries_path):
        import json
        with open(args.rag_queries_path, "r", encoding="utf-8") as f:
            rag_queries = [json.loads(line)["query"] for line in f]

    from tqdm import tqdm
    retrieve_results = [vretrieve(rag_queries[i], vectorstore, docs, args.retriever_k, args.metric, args.threshold) for i in tqdm(range(len(rag_queries)), desc="Retrieving documents")]

    safe_save_langchain_docs(retrieve_results, args.prepared_retrieve_docs_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Dataset params
    parser.add_argument("--qa_data_path", type=str, default="dataset/QA Data/MedMCQA/translated_hard_questions.jsonl")

    # Vectorstore params
    parser.add_argument("--vectorstore_dir", type=str, default="notebook/An/master/knowledge/vectorstore_full")
    parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedMCQA/prepared_retrieve_docs_full.pkl")
    parser.add_argument("--rag_queries_path", type=str, default=None)

    # Model params
    parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")

    # Vectorstore retriever params
    parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
    parser.add_argument("--metric", type=str, choices=["cosine", "mmr", "bm25"], default="mmr")
    parser.add_argument("--retriever_k", type=int, default=20, help="Number of documents to retrieve")
    parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for cosine similarity")
    parser.add_argument("--reranker_model_name", type=str, default=None)
    parser.add_argument("--reranker_k", type=int, default=50, help="Number of documents to rerank")

    args = parser.parse_args()
    print(args)

    main(args)
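When --rag_queries_path is given and the file exists, the script replaces the default "Question + options" queries with one JSON object per line carrying a "query" field. A sketch of producing such a file; the file name and query strings are illustrative:

```python
import json

# Illustrative queries; the file name is arbitrary and passed via --rag_queries_path.
queries = ["Định nghĩa tăng huyết áp", "Beta blocker side effects"]
with open("rag_queries.jsonl", "w", encoding="utf-8") as f:
    for q in queries:
        f.write(json.dumps({"query": q}, ensure_ascii=False) + "\n")
```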
test/test_llm.py
ADDED
@@ -0,0 +1,9 @@
from ..rag_pipeline import ChatAssistant
from ..rag_pipeline import request_retrieve_prompt

cb = ChatAssistant("mistral-medium", "mistral")

query = "Beta blocker for hypertension"
query = request_retrieve_prompt.format(conversation=query, role="customer")
response = cb.get_response(user=query)
print(response)
test/test_retrieve.py
ADDED
@@ -0,0 +1,39 @@
import argparse
import os

from ..rag_pipeline import get_embeddings, rerank
from ..utils import load_local

from ..rag_pipeline import vretrieve

def main(args):
    embed_model = get_embeddings(args.embed_model_name)
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)
    retrieve_results = vretrieve(args.query, vectorstore, docs, args.retriever_k, args.metric, args.threshold)

    retrieve_results = rerank(retrieve_results)

    print(retrieve_results)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--query", type=str, required=False, default="What are the applications of beta blockers in the treatment of hypertension?")

    # Vectorstore params
    parser.add_argument("--vectorstore_dir", type=str, required=False, default="notebook/An/master/knowledge/vectorstore_full")

    # Model params
    parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")

    # Vectorstore retriever params
    parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
    parser.add_argument("--metric", type=str, choices=["cosine", "mmr", "bm25"], default="cosine")
    parser.add_argument("--retriever_k", type=int, default=4, help="Number of documents to retrieve")
    parser.add_argument("--threshold", type=float, default=0.7, help="Threshold for cosine similarity")
    parser.add_argument("--reranker_model_name", type=str, default=None)
    parser.add_argument("--reranker_k", type=int, default=20, help="Number of documents to rerank")

    args = parser.parse_args()

    main(args)
utils.py
ADDED
@@ -0,0 +1,211 @@
import os
import pickle
from typing import List, Optional

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document


def load_local(vectorstore_dir: str, embed_model: HuggingFaceEmbeddings) -> tuple[Optional[FAISS], Optional[List[Document]]]:
    """
    Load the vectorstore and documents from disk.

    Args:
        vectorstore_dir: The directory to load the vectorstore from.
        embed_model: The embedding model to use.
    Returns:
        vector_store: The vectorstore.
        docs: The documents backing BM25 search, or None if docs.pkl is missing.
    """
    if not os.path.isdir(vectorstore_dir):
        print(f"Vectorstore directory not found at {vectorstore_dir}. Creating a new one.")
        os.makedirs(vectorstore_dir, exist_ok=True)

    try:
        vector_store = FAISS.load_local(vectorstore_dir, embed_model, allow_dangerous_deserialization=True)

        docs_path = os.path.join(vectorstore_dir, "docs.pkl")
        if os.path.exists(docs_path):
            with open(docs_path, "rb") as f:
                docs = pickle.load(f)
        else:
            docs = None
            print("Warning: docs.pkl not found. BM25 search will not be available.")

        print(f"Successfully loaded RAG state from {vectorstore_dir}")
        return vector_store, docs
    except Exception as e:
        print(f"Could not load from {vectorstore_dir}. It might be empty or corrupted. Error: {e}")
        return None, None

def save_local(vectorstore_dir: str, vectorstore: FAISS, docs: Optional[List[Document]]) -> None:
    """
    Save the vectorstore and documents to disk.

    Args:
        vectorstore_dir: The directory to save the vectorstore to.
        vectorstore: The vectorstore to save.
        docs: The documents to save.
    """
    if vectorstore is None:
        raise ValueError("Nothing to save.")
    if docs is None:
        print("Warning: No documents to save. BM25 search will not be available.")

    os.makedirs(vectorstore_dir, exist_ok=True)
    vectorstore.save_local(vectorstore_dir)

    if docs is not None:
        with open(os.path.join(vectorstore_dir, "docs.pkl"), "wb") as f:
            pickle.dump(docs, f)

    print(f"Successfully saved RAG state to {vectorstore_dir}")

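As a usage sketch for the two helpers above (illustration only, not part of utils.py): the directory name, sample document, and embedding model below are placeholders, not taken from this repo.

```
# Hypothetical round trip: build a tiny FAISS index, persist it, reload it.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document

from utils import load_local, save_local

embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")  # example model
docs = [Document(page_content="Beta blockers lower heart rate.", metadata={"source": "demo"})]
vectorstore = FAISS.from_documents(docs, embed_model)

save_local("knowledge/vectorstore_demo", vectorstore, docs)      # writes the FAISS index + docs.pkl
vectorstore, docs = load_local("knowledge/vectorstore_demo", embed_model)
```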
def load_qa_dataset(qa_dataset_path: str) -> tuple[List[str], List[str], List[str], List[str]]:
    """
    Load the QA dataset (JSONL).

    Args:
        qa_dataset_path: The path to the QA dataset.
    Returns:
        Tuple: (ids, questions, options, answers)
            ids: The ids of the questions.
            questions: The questions.
            options: The options for each question.
            answers: The answers for each question.
    """
    import json

    if not os.path.exists(qa_dataset_path):
        raise FileNotFoundError(f"Error: File not found at {qa_dataset_path}")

    with open(qa_dataset_path, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f]
    questions = [item["question"] for item in data]
    try:
        options = [
            (f"A. {item['A']} \n" if item['A'] not in [" ", "", None] else "") +
            (f"B. {item['B']} \n" if item['B'] not in [" ", "", None] else "") +
            (f"C. {item['C']} \n" if item['C'] not in [" ", "", None] else "") +
            (f"D. {item['D']} \n" if item['D'] not in [" ", "", None] else "") +
            (f"E. {item['E']} \n" if item['E'] not in [" ", "", None] else "")
            for item in data]
    except KeyError:
        options = [" " for item in data]
    answers = [item["answer"] for item in data]
    uuids = [item["uuid"] for item in data]
    return uuids, questions, options, answers

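For reference, load_qa_dataset expects one JSON object per line with the keys it reads above (question, A–E, answer, uuid); a made-up record would look like this (illustration only, values invented):

```
# Hypothetical record for the QA JSONL file read by load_qa_dataset.
example_record = {
    "uuid": "q-0001",
    "question": "Which drug class is first-line for rate control in atrial fibrillation?",
    "A": "Beta blockers",
    "B": "Loop diuretics",
    "C": "Macrolide antibiotics",
    "D": "",
    "E": "",
    "answer": "A",
}
# Each line of the dataset file is json.dumps(example_record); empty options are skipped when formatting.
```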
def load_prepared_retrieve_docs(prepared_retrieve_docs_path: str) -> List[List[Document]]:
    """
    Load the prepared retrieve docs from a file.

    Args:
        prepared_retrieve_docs_path: The path to the prepared retrieve docs.
    Returns:
        A list of lists of documents.
    """
    return safe_load_langchain_docs(prepared_retrieve_docs_path)

def paralelize(func, max_workers: int = 4, **kwargs) -> List:
    """
    Parallelizes a function call over multiple keyword argument iterables.

    Args:
        func: The function to execute in parallel.
        max_workers: The maximum number of threads to use.
        **kwargs: Keyword arguments where each value is an iterable (e.g., a list).
            All iterables must be of the same length.
            The keyword names do not matter, but their order does:
            the values are zipped and passed to func positionally, in the order given.
    Returns:
        A list of the results of the function calls.
    """
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    if not kwargs:
        return []

    arg_lists = list(kwargs.values())
    if len(set(len(lst) for lst in arg_lists)) > 1:
        raise ValueError("All iterable arguments must have the same length.")

    total_items = len(arg_lists[0])
    iterable = zip(*arg_lists)
    unpacker_func = lambda args_tuple: func(*args_tuple)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(unpacker_func, iterable), total=total_items))
    return results

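A small usage sketch for paralelize (illustration only; the worker function and inputs are invented, and the calling convention follows the implementation above, which zips the keyword values and passes them to func positionally):

```
# Hypothetical worker: stands in for an expensive retrieval or API call.
def score(query, k):
    return f"{query} -> top {k}"

results = paralelize(
    score,
    max_workers=2,
    queries=["beta blockers", "ace inhibitors"],  # first positional argument of each call
    ks=[4, 8],                                    # second positional argument of each call
)
# results == ["beta blockers -> top 4", "ace inhibitors -> top 8"]
```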
def safe_save_langchain_docs(documents: List[List[Document]], filepath: str):
    """
    Converts LangChain Document objects into a serializable list of dictionaries
    and saves them to a file using pickle.

    Args:
        documents (List[List[Document]]): The nested list of LangChain Documents.
        filepath (str): The path to the file where the data will be saved.
    """
    serializable_data = []
    print(f"Preparing to save {len(documents)} lists of documents...")

    # Convert each Document object into a dictionary
    for doc_list in documents:
        serializable_doc_list = []
        for doc in doc_list:
            serializable_doc_list.append({
                "page_content": doc.page_content,
                "metadata": doc.metadata,
            })
        serializable_data.append(serializable_doc_list)

    print(f"Conversion complete. Saving to {filepath}...")
    try:
        # Use 'with' to ensure the file is closed properly, even if errors occur
        with open(filepath, "wb") as f:
            pickle.dump(serializable_data, f)
        print("File saved successfully.")
    except Exception as e:
        print(f"An error occurred while saving the file: {e}")

def safe_load_langchain_docs(filepath: str) -> List[List[Document]]:
    """
    Loads data from a pickle file and reconstructs the LangChain Document objects.

    Args:
        filepath (str): The path to the file to load.

    Returns:
        List[List[Document]]: The reconstructed nested list of LangChain Documents.
    """
    reconstructed_documents = []

    print(f"Loading data from {filepath}...")
    try:
        with open(filepath, "rb") as f:
            loaded_data = pickle.load(f)
        print("File loaded successfully. Reconstructing Document objects...")

        # Reconstruct the Document objects from the dictionaries
        for doc_list_data in loaded_data:
            reconstructed_doc_list = []
            for doc_data in doc_list_data:
                reconstructed_doc_list.append(
                    Document(
                        page_content=doc_data["page_content"],
                        metadata=doc_data["metadata"]
                    )
                )
            reconstructed_documents.append(reconstructed_doc_list)

        print("Document objects reconstructed successfully.")
        return reconstructed_documents

    except FileNotFoundError:
        print(f"Error: The file at {filepath} was not found.")
        return []
    except EOFError:
        print(f"Error: The file at {filepath} is corrupted or incomplete (EOFError).")
        print("Please re-run the script that generates this file.")
        return []
    except Exception as e:
        print(f"An unexpected error occurred while loading the file: {e}")
        return []
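And a quick round trip through the two pickle helpers (illustration only; the path and documents are placeholders):

```
# Hypothetical round trip through safe_save_langchain_docs / safe_load_langchain_docs.
from langchain.schema import Document

from utils import safe_load_langchain_docs, safe_save_langchain_docs

nested_docs = [
    [Document(page_content="Beta blockers reduce heart rate.", metadata={"rank": 1})],
    [Document(page_content="ACE inhibitors lower blood pressure.", metadata={"rank": 1})],
]

safe_save_langchain_docs(nested_docs, "prepared_retrieve_docs_demo.pkl")
restored = safe_load_langchain_docs("prepared_retrieve_docs_demo.pkl")
assert restored[0][0].page_content == nested_docs[0][0].page_content
```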