VuvanAn committed on
Commit 09dc9d3 · verified · 1 Parent(s): 5ba9d4c

Upload 47 files

Files changed (48)
  1. .gitattributes +3 -0
  2. README.md +20 -0
  3. __pycache__/app.cpython-313.pyc +0 -0
  4. __pycache__/utils.cpython-313.pyc +0 -0
  5. app.py +136 -0
  6. config.yaml +25 -0
  7. knowledge/vectorstore_1/config.json +1 -0
  8. knowledge/vectorstore_1/docs.pkl +3 -0
  9. knowledge/vectorstore_1/index.faiss +3 -0
  10. knowledge/vectorstore_1/index.pkl +3 -0
  11. rag_pipeline/__init__.py +8 -0
  12. rag_pipeline/__pycache__/__init__.cpython-313.pyc +0 -0
  13. rag_pipeline/data_ingest/__pycache__/loader.cpython-313.pyc +0 -0
  14. rag_pipeline/data_ingest/loader.py +40 -0
  15. rag_pipeline/data_ingest/parser.py +0 -0
  16. rag_pipeline/generation/__pycache__/llm_wrapper.cpython-313.pyc +0 -0
  17. rag_pipeline/generation/__pycache__/prompt_template.cpython-313.pyc +0 -0
  18. rag_pipeline/generation/llm_wrapper.py +59 -0
  19. rag_pipeline/generation/prompt_template.py +115 -0
  20. rag_pipeline/indexing/chunking/__pycache__/markdown.cpython-313.pyc +0 -0
  21. rag_pipeline/indexing/chunking/__pycache__/recursive.cpython-313.pyc +0 -0
  22. rag_pipeline/indexing/chunking/markdown.py +54 -0
  23. rag_pipeline/indexing/chunking/recursive.py +30 -0
  24. rag_pipeline/indexing/embedding/__pycache__/embedding.cpython-313.pyc +0 -0
  25. rag_pipeline/indexing/embedding/embedding.py +23 -0
  26. rag_pipeline/retrieval/__pycache__/reranker.cpython-313.pyc +0 -0
  27. rag_pipeline/retrieval/__pycache__/vector_retriever.cpython-313.pyc +0 -0
  28. rag_pipeline/retrieval/graph_retriever.py +4 -0
  29. rag_pipeline/retrieval/hybrid_retriever.py +0 -0
  30. rag_pipeline/retrieval/reranker.py +8 -0
  31. rag_pipeline/retrieval/vector_retriever.py +38 -0
  32. requirements.txt +0 -0
  33. test/__pycache__/_normalize_qa.cpython-313.pyc +0 -0
  34. test/__pycache__/data_ingest.cpython-313.pyc +0 -0
  35. test/__pycache__/eval_lm.cpython-313.pyc +0 -0
  36. test/__pycache__/eval_qa.cpython-313.pyc +0 -0
  37. test/__pycache__/prepare_retrieve.cpython-313.pyc +0 -0
  38. test/__pycache__/test_llm.cpython-313.pyc +0 -0
  39. test/__pycache__/test_retrieve.cpython-313.pyc +0 -0
  40. test/_normalize_qa.py +43 -0
  41. test/chatbot_inference.py +23 -0
  42. test/data_ingest.py +78 -0
  43. test/eval_lm.py +87 -0
  44. test/eval_qa.py +106 -0
  45. test/prepare_retrieve.py +50 -0
  46. test/test_llm.py +9 -0
  47. test/test_retrieve.py +39 -0
  48. utils.py +211 -0
.gitattributes ADDED
@@ -0,0 +1,3 @@
+ knowledge/vectorstore_1/docs.pkl filter=lfs diff=lfs merge=lfs -text
+ knowledge/vectorstore_1/index.faiss filter=lfs diff=lfs merge=lfs -text
+ knowledge/vectorstore_1/index.pkl filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,20 @@
+ ```
+ python -m notebook.An.master.test.data_ingest \
+ --data_paths notebook/An/master/data \
+ --vectorstore_dir notebook/An/master/knowledge/vectorstore_1 \
+ --embed_model_name alibaba-nlp/gte-multilingual-base \
+ --chunk_method recursive \
+ --chunk_size 2048 \
+ --chunk_overlap 512 \
+ --vectorstore faiss
+ ```
+
+ ```
+ python -m notebook.An.master.test.test_retrieve \
+ --query "Heart definition and heart disease" \
+ --vectorstore_dir notebook/An/master/knowledge/vectorstore_1 \
+ --embed_model_name alibaba-nlp/gte-multilingual-base \
+ --retriever_k 4 \
+ --metric cosine \
+ --threshold 0.5
+ ```
__pycache__/app.cpython-313.pyc ADDED
Binary file (6.45 kB). View file
 
__pycache__/utils.cpython-313.pyc ADDED
Binary file (10.9 kB). View file
 
app.py ADDED
@@ -0,0 +1,136 @@
+ import gradio as gr
+ from datetime import datetime
+
+ # Assuming these are in your project structure
+ from .rag_pipeline import ChatAssistant, get_embeddings, vretrieve, retrieve_chatbot_prompt, request_retrieve_prompt
+ from .utils import load_local
+
+ # --- Constants and System Prompt ---
+
+ # DEVELOPER: Add or remove models here.
+ # The key is the display name in the dropdown.
+ # The value is a tuple of (model_id, model_provider).
+ AVAILABLE_MODELS = {
+     "mistral large (mistral)": ("mistral", "mistral"),
+     "mistral medium (mistral)": ("mistral-medium", "mistral"),
+     "mistral small (mistral)": ("mistral-small", "mistral"),
+     "llama3 8B": ("llama3:8b", "ollama"),
+     "llama3.1 8B": ("llama3.1:8b", "ollama"),
+     "gpt-oss 20B": ("gpt-oss-20b", "ollama"),
+     "gemma3 12B": ("gemma3:12b", "ollama"),
+     "gpt 4o mini": ("gpt-4o-mini", "openai"),
+     "gpt 4o": ("gpt-4o", "openai"),
+ }
+ DEFAULT_MODEL_KEY = "mistral medium (mistral)"
+
+ EMBEDDING_MODEL_ID = "alibaba-nlp/gte-multilingual-base"
+ VECTORSTORE_PATH = "notebook/An/master/knowledge/vectorstore_full"
+ LOG_FILE_PATH = "log.txt"
+ MAX_HISTORY_CONVERSATION = 50
+
+ # System prompt for the medical assistant
+ sys = """
+ You are a Medical Assistant specialized in providing information and answering questions related to healthcare and medicine.
+ You must answer professionally and empathetically, taking into account the user's feelings and concerns.
+ """
+
+ # --- Initial Setup (runs once) ---
+ print("Initializing models and data...")
+ embedding_model = get_embeddings(EMBEDDING_MODEL_ID, show_progress=False)
+ vectorstore, docs = load_local(VECTORSTORE_PATH, embedding_model)
+ print("Initialization complete.")
+
+
+ # --- Helper Functions ---
+ def log(log_txt: str):
+     """Appends a log entry to the log file."""
+     with open(LOG_FILE_PATH, "a", encoding="utf-8") as log_file:
+         log_file.write(log_txt + "\n")
+
+
+ # --- Core Chatbot Logic ---
+ def chatbot_logic(message: str, history: list, selected_model_key: str):
+     """
+     Handles the main logic for receiving a message, performing RAG, and generating a response.
+     """
+     # 1. Look up the model_id and model_provider from the selected key
+     model_id, model_provider = AVAILABLE_MODELS[selected_model_key]
+
+     log(f"** Current time **: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+     log(f"** User message **: {message}")
+     log(f"** Using Model **: {model_id} ({model_provider})")
+
+     # Initialize the assistant with the specified model for this request
+     try:
+         chat_assistant = ChatAssistant(model_id, model_provider)
+     except Exception as e:
+         yield f"Error: Could not initialize the model. Please check the ID and provider. Details: {e}"
+         return
+
+     # --- RAG Pipeline ---
+     # 2. Format conversation history for context
+     history = history[-MAX_HISTORY_CONVERSATION:]
+     conversation = "".join(f"User: {user_msg}\nBot: {bot_msg}\n" for user_msg, bot_msg in history)
+     query_for_rag = conversation + f"User: {message}\nBot:"
+
+     # 3. Generate a search query from the conversation
+     rag_query = chat_assistant.get_response(request_retrieve_prompt.format(role="user", conversation=query_for_rag))
+     rag_query = rag_query[rag_query.lower().rfind("[") + 1: rag_query.rfind("]")]
+
+     # 4. Retrieve relevant documents if necessary
+     if "NO" not in rag_query:
+         retrieve_results = vretrieve(rag_query, vectorstore, docs, k=4, metric="mmr", threshold=0.7)
+     else:
+         retrieve_results = []
+
+     retrieved_docs = "\n".join([f"Document {i+1}:\n" + doc.page_content for i, doc in enumerate(retrieve_results)])
+     log(f"** RAG query **: {rag_query}")
+     log(f"** Retrieved documents **:\n{retrieved_docs}")
+
+     # --- Final Response Generation ---
+     # 5. Create the final prompt with retrieved context
+     final_prompt = retrieve_chatbot_prompt.format(role="user", documents=retrieved_docs, conversation=query_for_rag)
+
+     # 6. Stream the response from the LLM
+     response = ""
+     for token in chat_assistant.get_streaming_response(final_prompt, sys):
+         response += token
+         yield response
+
+     log(f"** Bot response **: {response}")
+     log("=" * 50 + "\n\n")
+
+
+ # --- Gradio UI ---
+ with gr.Blocks(theme="soft") as chatbot_ui:
+     gr.Markdown("# MedLLM")
+
+     model_selector = gr.Dropdown(
+         label="Select Model",
+         choices=list(AVAILABLE_MODELS.keys()),
+         value=DEFAULT_MODEL_KEY,
+     )
+
+     chatbot = gr.Chatbot(label="Chat Window", height=500, bubble_full_width=False)
+     msg_input = gr.Textbox(label="Your Message", placeholder="Type your question here and press Enter...", scale=7)
+
+     def respond(message, chat_history, selected_model_key):
+         """Wrapper function to connect chatbot_logic with Gradio's state."""
+         bot_message_stream = chatbot_logic(message, chat_history, selected_model_key)
+         chat_history.append([message, ""])
+         for token in bot_message_stream:
+             chat_history[-1][1] = token
+             yield chat_history
+
+     msg_input.submit(
+         respond,
+         [msg_input, chatbot, model_selector],
+         [chatbot]
+     ).then(
+         lambda: gr.update(value=""), None, [msg_input], queue=False
+     )
+
+
+ # --- Launch the App ---
+ if __name__ == "__main__":
+     chatbot_ui.launch(debug=True, share=False)
config.yaml ADDED
@@ -0,0 +1,25 @@
+ version: 0.1
+
+ model:
+   name: "llama2:7b"
+   temperature: 0.3
+   max_tokens: 100000
+   provider: "ollama"
+   base_url: "http://localhost:11434/v1"
+
+ rag_config:
+   k: 4
+   rerank:
+     name: "bge-reranker-large"
+     model: "BAAI/bge-reranker-large"
+     top_n: 100
+   embed_model:
+     name: "gte-multilingual-base"
+     model: "alibaba-nlp/gte-multilingual-base"
+   chunk_size: 2048
+   chunk_overlap: 512
+   similarity_threshold: 0.7
+   similarity_metric: "cosine"
+
+ knowledge:
+   vectorstore: "faiss"
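
config.yaml collects the model, retrieval, and chunking defaults, but none of the scripts in this upload read it yet; the sketch below shows one way it could be consumed. It is a minimal example that assumes PyYAML is installed and that the nesting shown above is the intended structure.

```
import yaml

# Load the pipeline defaults from config.yaml (assumed to be in the working directory).
with open("config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

model_name = cfg["model"]["name"]                       # "llama2:7b"
retriever_k = cfg["rag_config"]["k"]                    # 4
embed_model_id = cfg["rag_config"]["embed_model"]["model"]  # "alibaba-nlp/gte-multilingual-base"
print(model_name, retriever_k, embed_model_id)
```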
knowledge/vectorstore_1/config.json ADDED
@@ -0,0 +1 @@
+ {"data_paths": ["dataset/RAG_Data/wiki_vi", "dataset/RAG_Data/youmed"], "vectorstore_dir": "notebook/An/master/knowledge/vectorstore_1", "file_type": "txt", "embed_model_name": "alibaba-nlp/gte-multilingual-base", "chunk_size": 2048, "chunk_overlap": 512, "chunk_method": "markdown", "vectorstore": "faiss", "clear_vectorstore": true}
knowledge/vectorstore_1/docs.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e969c1a0beb575363fc3cd0e252b9751f9ad79fc605ec6ab4a2c4ee68845e43
+ size 7568017
knowledge/vectorstore_1/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:10ffba3c9fc6846d51de37463833eecf8b42b036a78e93e90ff779fbd47268f6
+ size 9440301
knowledge/vectorstore_1/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c9d81a73aa18b621f660a69e7ce3bba1b8b1875e983752a1e504f1f2922a7fdc
+ size 7730542
rag_pipeline/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .generation.llm_wrapper import ChatAssistant
+ from .indexing.chunking.recursive import split_document as recursive_chunking
+ from .indexing.chunking.markdown import split_document as markdown_chunking
+ from .indexing.embedding.embedding import get_embeddings
+ from .data_ingest.loader import load_data
+ from .generation.prompt_template import *
+ from .retrieval.vector_retriever import retrieve as vretrieve
+ from .retrieval.reranker import rerank
rag_pipeline/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (704 Bytes). View file
 
rag_pipeline/data_ingest/__pycache__/loader.cpython-313.pyc ADDED
Binary file (2.09 kB). View file
 
rag_pipeline/data_ingest/loader.py ADDED
@@ -0,0 +1,40 @@
+ import os
+ from typing import List
+ from langchain.schema import Document
+
+ def load_data(data_path: str, file_type: str) -> List[Document]:
+     """
+     Load knowledge data from a specified path and file type.
+     Args:
+         data_path: The path to the data.
+         file_type: The type of the data.
+     Returns:
+         A list of documents.
+     """
+     if file_type == "pdf":
+         raise NotImplementedError("PDF loading is not yet implemented.")
+     elif file_type == "txt":
+         return _load_txt(data_path)
+
+ def _load_txt(data_path: str) -> List[Document]:
+     splits = []
+
+     if not os.path.isdir(data_path):
+         raise FileNotFoundError(f"Error: Directory not found at {data_path}")
+
+     for file_name in os.listdir(data_path):
+         if file_name.endswith('.txt'):
+             file_path = os.path.join(data_path, file_name)
+
+             try:
+                 with open(file_path, 'r', encoding='utf-8') as f:
+                     content = f.read()
+                 metadata = {"source": file_name}
+                 doc = Document(page_content=content, metadata=metadata)
+
+                 splits.append(doc)
+
+             except Exception as e:
+                 print(f"Error reading file {file_path}: {e}")
+
+     return splits
rag_pipeline/data_ingest/parser.py ADDED
File without changes
rag_pipeline/generation/__pycache__/llm_wrapper.cpython-313.pyc ADDED
Binary file (2.9 kB). View file
 
rag_pipeline/generation/__pycache__/prompt_template.cpython-313.pyc ADDED
Binary file (3.99 kB). View file
 
rag_pipeline/generation/llm_wrapper.py ADDED
@@ -0,0 +1,59 @@
+ from openai import OpenAI
+ import backoff
+
+ import os
+
+ _base_url_ = {
+     "ollama": "http://localhost:11434/v1",
+     "mistral": "https://api.mistral.ai/v1",
+     "openai": "https://api.openai.com/v1",
+ }
+
+ _api_key_ = {
+     "ollama": "ollama",
+     "mistral": os.getenv("MISTRAL_API_KEY"),
+     "openai": os.getenv("OPENAI_API_KEY"),
+ }
+
+ class ChatAssistant:
+     def __init__(self, model_name: str, provider: str = "ollama"):
+         """
+         Args:
+             model_name: The name of the model to use.
+             provider: The provider of the model. Can be "ollama", "mistral", or "openai".
+         """
+         self.model_name = model_name
+         self.client = OpenAI(
+             base_url=_base_url_[provider],
+             api_key=_api_key_[provider],
+         )
+
+     @backoff.on_exception(backoff.expo, Exception)
+     def get_response(self, user: str, sys: str = ""):
+         response = self.client.chat.completions.create(
+             model=self.model_name,
+             messages=[
+                 {"role": "system", "content": sys},
+                 {"role": "user", "content": user},
+             ]
+         )
+         return response.choices[0].message.content
+
+     @backoff.on_exception(backoff.expo, Exception)
+     def get_streaming_response(self, user: str, sys: str = ""):
+         """Yields the response token by token (streaming)."""
+         response_stream = self.client.chat.completions.create(
+             model=self.model_name,
+             messages=[
+                 {"role": "system", "content": sys},
+                 {"role": "user", "content": user},
+             ],
+             stream=True
+         )
+
+         # Iterate over the stream of chunks
+         for chunk in response_stream:
+             # The actual token is in chunk.choices[0].delta.content
+             token = chunk.choices[0].delta.content
+             if token is not None:
+                 yield token
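
ChatAssistant is a thin wrapper around an OpenAI-compatible chat-completions endpoint. A minimal usage sketch follows; it assumes MISTRAL_API_KEY is exported for the "mistral" provider (or a local Ollama server for "ollama"), and that the package is importable as rag_pipeline, as in test/chatbot_inference.py.

```
from rag_pipeline import ChatAssistant

assistant = ChatAssistant("mistral-medium", provider="mistral")

# Blocking call: returns the full completion as a single string.
print(assistant.get_response("What is hypertension?", sys="You are a medical assistant."))

# Streaming call: tokens are yielded as they arrive.
for token in assistant.get_streaming_response("What is hypertension?"):
    print(token, end="", flush=True)
```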
rag_pipeline/generation/prompt_template.py ADDED
@@ -0,0 +1,115 @@
+ multichoice_qa_prompt = """
+ -- DOCUMENT --
+ {document}
+ -- END OF DOCUMENT --
+
+ -- INSTRUCTION --
+ You are a medical expert.
+ Given the documents, you must answer the question by following these steps.
+ First, you must read the question and the options, and draft an answer based on your knowledge.
+ Second, you must read the documents and check if they can help answer the question.
+ Third, you cross-check the documents with your knowledge and the draft answer.
+ Finally, you answer the question based on your knowledge and the true documents.
+ Your response must end with the letter of the most correct option, like: "the answer is A".
+ The entire thought must be under 500 words long.
+ -- END OF INSTRUCTION --
+
+ -- QUESTION --
+ {question}
+ {options}
+ -- END OF QUESTION --
+ """
+
+ qa_prompt = """
+ -- DOCUMENT --
+ {document}
+ -- END OF DOCUMENT --
+
+ -- INSTRUCTION --
+ You are a medical expert.
+ Given the documents, you must answer the question by following these steps.
+ First, you must read the question and draft an answer based on your knowledge.
+ Second, you must read the documents and check if they can help answer the question.
+ Third, you cross-check the documents with your knowledge and the draft answer.
+ Finally, you answer the question concisely based on your knowledge and the true documents.
+ Your response must be as short as possible, in Vietnamese, and between brackets like: "[...]".
+ -- END OF INSTRUCTION --
+
+ -- QUESTION --
+ {question}
+ -- END OF QUESTION --
+ """
+
+ retrieve_chatbot_prompt = """
+ You are a medical expert.
+ You are having a conversation with a {role} and you have external documents to help you.
+ Continue the conversation based on the chat history and the context information, not prior knowledge.
+ Before using a retrieved chunk, you must check whether it is relevant to the user query. If it is not relevant, you must ignore it.
+ You use the relevant chunks to answer the question and cite the source inside <<<>>>.
+ If you don't know the answer, you must say "I don't know".
+ ---------------------
+ {documents}
+ ---------------------
+ Given the documents and not prior knowledge, continue the conversation.
+ ---------------------
+ {conversation}
+ ---------------------
+ """
+
+ request_retrieve_prompt = """
+ --- INSTRUCTION ---
+ You are having a conversation with a {role}.
+ You have to provide a short query to retrieve the documents that you need, inside brackets like: "[...]".
+ If it is not related to the medical field, or you do not need external knowledge to answer, you must write "[NO]".
+ --- END OF INSTRUCTION ---
+
+ --- CONVERSATION ---
+ {conversation}
+ --- END OF CONVERSATION ---
+ """
+
+ answer_prompt = """
+ -- INSTRUCTION --
+ You are a medical expert.
+ Given the documents below, you must answer the question step by step.
+ First, you must read the question.
+ Second, you must read the documents and check their reliability.
+ Third, you cross-check with your knowledge.
+ Finally, you answer the question based on your knowledge and the true documents.
+
+ Your answer must be UNDER 50 words, written on 1 line, and in Vietnamese.
+ -- END OF INSTRUCTION --
+
+ -- QUESTION --
+ {question}
+ -- END OF QUESTION --
+
+ -- DOCUMENT --
+ {document}
+ -- END OF DOCUMENT --
+
+ """
+
+ translate_prompt = """
+ [ INSTRUCTION ]
+ You are a Medical translator expert.
+ Your task is to translate this English question into Vietnamese with EXACTLY the same format, written on 1 line.
+ [ END OF INSTRUCTION ]
+
+ [ QUERY TO TRANSLATE ]
+ {query}
+ [ END OF QUERY TO TRANSLATE ]
+ """
+
+ pdf2txt_prompt = """
+ Rewrite this plain text from a PDF file following the correct reading order and these instructions:
+ - Use markdown format.
+ - Use the same language.
+ - Keep the content intact.
+ - Beautify the table.
+ - No talk.
+
+ [ QUERY ]
+ {query}
+ [ END OF QUERY ]
+ """
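
The chatbot flow in app.py first asks the model for a bracketed retrieval query via request_retrieve_prompt and then slices the reply between the last "[" and "]". A small sketch of that round trip is below; the reply string is a placeholder standing in for a real LLM response, not actual model output.

```
from rag_pipeline import request_retrieve_prompt

conversation = "User: What are the side effects of beta blockers?\nBot:"
prompt = request_retrieve_prompt.format(role="user", conversation=conversation)

# The model is expected to answer either "[<short search query>]" or "[NO]".
reply = "[beta blocker side effects]"  # placeholder reply for illustration
rag_query = reply[reply.rfind("[") + 1 : reply.rfind("]")]
needs_retrieval = "NO" not in rag_query
print(rag_query, needs_retrieval)
```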
rag_pipeline/indexing/chunking/__pycache__/markdown.cpython-313.pyc ADDED
Binary file (2.55 kB). View file
 
rag_pipeline/indexing/chunking/__pycache__/recursive.cpython-313.pyc ADDED
Binary file (1.52 kB). View file
 
rag_pipeline/indexing/chunking/markdown.py ADDED
@@ -0,0 +1,54 @@
+ from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
+ from langchain.schema import Document
+ from typing import List
+
+ def __split_1_document__(document: Document, chunk_size: int, chunk_overlap: int) -> List[Document]:
+     headers_to_split_on = [
+         ("#", "Header 1"),
+         ("##", "Header 2"),
+         ("###", "Header 3"),
+     ]
+
+     markdown_splitter = MarkdownHeaderTextSplitter(
+         headers_to_split_on=headers_to_split_on,
+         strip_headers=False,
+         return_each_line=False
+     )
+
+     md_header_splits = markdown_splitter.split_text(document.page_content)
+
+     for doc in md_header_splits:
+         doc.metadata.update(document.metadata)
+
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap
+     )
+
+     final_splits = text_splitter.split_documents(md_header_splits)
+
+     # Iterate through the final chunks to prepend metadata to the page_content
+     for i, doc in enumerate(final_splits):
+         header_lines = []
+         source_line = f"-- source: {doc.metadata.get('source', 'N/A')}"
+
+         if 'Header 1' in doc.metadata:
+             header_lines.append(doc.metadata['Header 1'])
+         if 'Header 2' in doc.metadata:
+             header_lines.append(doc.metadata['Header 2'])
+         if 'Header 3' in doc.metadata:
+             header_lines.append(doc.metadata['Header 3'])
+
+         header_content = "\n".join(header_lines)
+         chunk_header = f"Chunk {i+1}:"
+
+         # Combine everything into the new page content
+         original_content = doc.page_content
+         doc.page_content = f"{source_line}\n{header_content}\n{chunk_header}\n{original_content}"
+
+     return final_splits
+
+ def split_document(documents: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
+     split_documents = []
+     for doc in documents:
+         split_documents.extend(__split_1_document__(doc, chunk_size, chunk_overlap))
+     return split_documents
rag_pipeline/indexing/chunking/recursive.py ADDED
@@ -0,0 +1,30 @@
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import Document
+ from typing import List
+
+ def __split_1_document__(document: Document, chunk_size: int, chunk_overlap: int) -> List[Document]:
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap,
+     )
+
+     text_content = document.page_content
+     text_chunks = text_splitter.split_text(text_content)
+     split_documents = []
+
+     for i, chunk in enumerate(text_chunks):
+         new_metadata = document.metadata.copy()
+
+         # new_metadata['chunk_number'] = i + 1
+
+         new_doc = Document(page_content=chunk, metadata=new_metadata)
+         split_documents.append(new_doc)
+
+     return split_documents
+
+
+ def split_document(documents: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
+     split_documents = []
+     for doc in documents:
+         split_documents.extend(__split_1_document__(doc, chunk_size, chunk_overlap))
+     return split_documents
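
Both chunkers expose the same split_document(documents, chunk_size, chunk_overlap) signature and are re-exported from rag_pipeline as recursive_chunking and markdown_chunking. A short sketch of chunking a loaded corpus follows; the data path is taken from config.json and is only illustrative.

```
from rag_pipeline import load_data, recursive_chunking

# Load plain-text documents, then split them into overlapping chunks.
docs = load_data("dataset/RAG_Data/wiki_vi", "txt")
chunks = recursive_chunking(docs, chunk_size=2048, chunk_overlap=512)
print(f"{len(docs)} documents -> {len(chunks)} chunks")

# markdown_chunking has the same call shape, but splits on '#', '##', '###' headers first
# and prepends the source and header metadata to each chunk's page_content.
```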
rag_pipeline/indexing/embedding/__pycache__/embedding.cpython-313.pyc ADDED
Binary file (1.06 kB). View file
 
rag_pipeline/indexing/embedding/embedding.py ADDED
@@ -0,0 +1,23 @@
+ from langchain_huggingface import HuggingFaceEmbeddings
+
+ import torch
+
+ _model_cache = {}
+
+ def get_embeddings(model_name: str, show_progress: bool = True) -> HuggingFaceEmbeddings:
+     """
+     Get the embeddings model. Cache available.
+     Args:
+         model_name: The name of the model.
+     Returns:
+         The embeddings model.
+     """
+     if model_name not in _model_cache:
+         embeddings = HuggingFaceEmbeddings(
+             model_name=model_name,
+             show_progress=show_progress,
+             model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'trust_remote_code': True},
+             encode_kwargs={'batch_size': 15}
+         )
+         _model_cache[model_name] = embeddings
+     return _model_cache[model_name]
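
get_embeddings caches one HuggingFaceEmbeddings instance per model name, so repeated calls reuse the already-loaded model. A quick sketch:

```
from rag_pipeline import get_embeddings

embed_model = get_embeddings("alibaba-nlp/gte-multilingual-base", show_progress=False)
vector = embed_model.embed_query("beta blockers for hypertension")
print(len(vector))  # embedding dimensionality

# A second call with the same model name returns the cached instance.
assert get_embeddings("alibaba-nlp/gte-multilingual-base") is embed_model
```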
rag_pipeline/retrieval/__pycache__/reranker.cpython-313.pyc ADDED
Binary file (504 Bytes). View file
 
rag_pipeline/retrieval/__pycache__/vector_retriever.cpython-313.pyc ADDED
Binary file (2.25 kB). View file
 
rag_pipeline/retrieval/graph_retriever.py ADDED
@@ -0,0 +1,4 @@
+ from typing import List, Any
+
+ def retrieve(query: str, graphstore: Any = None) -> List[str]:
+     pass
rag_pipeline/retrieval/hybrid_retriever.py ADDED
File without changes
rag_pipeline/retrieval/reranker.py ADDED
@@ -0,0 +1,8 @@
+ import os
+ import pickle
+ from typing import List
+
+ from langchain.schema import Document
+
+ def rerank(docs: List[Document]) -> List[Document]:
+     return docs
rag_pipeline/retrieval/vector_retriever.py ADDED
@@ -0,0 +1,38 @@
+ from langchain_community.vectorstores import FAISS
+ from langchain.schema import Document
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+
+ from .reranker import rerank
+
+ from typing import List, Any
+
+ def retrieve(query: str, vectorstore: FAISS, docs: List[Document] = None, k: int = 4, metric: str = "cosine", threshold: float = 0.5, reranker: Any = None) -> List[Document]:
+     """
+     Retrieve documents from the vectorstore based on the query and metric.
+     Args:
+         query: The query to search for.
+         metric: The metric to use for retrieval.
+         vectorstore: The vectorstore to search in.
+         k: The number of documents to retrieve.
+         threshold: The threshold for the metric to use for retrieval.
+         reranker: The reranker to use for reranking the retrieved documents.
+     Returns:
+         A list of documents.
+     """
+     if metric == "cosine":
+         docs = vectorstore.similarity_search_with_score(query, k=k)
+         docs = [doc for doc, score in docs if score > threshold]
+     elif metric == "mmr":
+         docs = vectorstore.max_marginal_relevance_search(query, k=k)
+     elif metric == "bm25":
+         from langchain_community.retrievers import BM25Retriever
+         if docs is None:
+             raise ValueError("Documents not available. BM25 requires ingested or loaded documents.")
+         bm25_retriever = BM25Retriever.from_documents(docs)
+         docs = bm25_retriever.get_relevant_documents(query, k=k)
+     else:
+         raise ValueError(f"Unsupported metric: '{metric}'. Supported metrics are 'cosine', 'mmr', and 'bm25'.")
+
+     if reranker is not None:
+         return rerank(docs)
+     return docs
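
retrieve (re-exported as vretrieve) switches on the metric argument. A usage sketch against a locally saved store follows; the vectorstore path is illustrative, and note that FAISS's similarity_search_with_score returns raw scores whose direction depends on the index's distance strategy, so verify that the `score > threshold` filter above keeps the documents you expect.

```
from rag_pipeline import get_embeddings, vretrieve
from utils import load_local

embed_model = get_embeddings("alibaba-nlp/gte-multilingual-base", show_progress=False)
vectorstore, docs = load_local("knowledge/vectorstore_1", embed_model)  # illustrative path

# Score-thresholded similarity search.
hits = vretrieve("heart disease symptoms", vectorstore, docs, k=4, metric="cosine", threshold=0.5)

# Diversity-oriented retrieval; the threshold is not used for mmr.
hits_mmr = vretrieve("heart disease symptoms", vectorstore, docs, k=4, metric="mmr")
print(len(hits), len(hits_mmr))
```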
requirements.txt ADDED
Binary file (11.9 kB). View file
 
test/__pycache__/_normalize_qa.cpython-313.pyc ADDED
Binary file (2.2 kB). View file
 
test/__pycache__/data_ingest.cpython-313.pyc ADDED
Binary file (3.98 kB). View file
 
test/__pycache__/eval_lm.cpython-313.pyc ADDED
Binary file (5.68 kB). View file
 
test/__pycache__/eval_qa.cpython-313.pyc ADDED
Binary file (6.52 kB). View file
 
test/__pycache__/prepare_retrieve.cpython-313.pyc ADDED
Binary file (3.71 kB). View file
 
test/__pycache__/test_llm.cpython-313.pyc ADDED
Binary file (603 Bytes). View file
 
test/__pycache__/test_retrieve.cpython-313.pyc ADDED
Binary file (2.32 kB). View file
 
test/_normalize_qa.py ADDED
@@ -0,0 +1,43 @@
+ # import json
+ # import uuid
+
+ # origin_qa_data_path = 'dataset/QA Data/MedMCQA/hard_questions.jsonl'
+ # target_qa_data_path = 'dataset/QA Data/MedMCQA/translated_hard_questions.jsonl'
+
+ # def transform_id(origin_id):
+ #     # Add 'T' prefix and remove last character
+ #     return ' T' + origin_id[:-1]
+
+ # def update_answers():
+ #     # Read origin data
+ #     with open(origin_qa_data_path, 'r', encoding='utf-8') as f:
+ #         origin_data = [json.loads(line) for line in f]
+
+ #     # Read target data
+ #     with open(target_qa_data_path, 'r', encoding='utf-8') as f:
+ #         target_data = [json.loads(line) for line in f]
+
+ #     c = []
+ #     for item in origin_data:
+ #         for target_item in target_data:
+ #             if transform_id(item['id']) == target_item['uuid']:
+ #                 if item['cop'] == 0:
+ #                     target_item['answer'] = 'A'
+ #                 elif item['cop'] == 1:
+ #                     target_item['answer'] = 'B'
+ #                 elif item['cop'] == 2:
+ #                     target_item['answer'] = 'C'
+ #                 elif item['cop'] == 3:
+ #                     target_item['answer'] = 'D'
+ #                 c.extend([target_item['uuid']])
+ #     # print(c)
+ #     for item in target_data:
+ #         if item['uuid'] not in c:
+ #             print(item['uuid'])
+ #     # Write updated target data back to file
+ #     with open(target_qa_data_path, 'w', encoding='utf-8') as f:
+ #         for item in target_data:
+ #             f.write(json.dumps(item, ensure_ascii=False) + '\n')
+
+ # # Call the function to update answers
+ # update_answers()
test/chatbot_inference.py ADDED
@@ -0,0 +1,30 @@
+ from rag_pipeline import get_embeddings, vretrieve, rerank
+ from utils import load_local
+
+ import argparse
+
+ def inference(query: str, args):
+     embed_model = get_embeddings(args.embed_model_name)
+     vectorstore, docs = load_local(args.vectorstore_dir, embed_model)
+     retrieve_results = vretrieve(query, vectorstore, docs, args.retriever_k, args.metric, args.threshold)
+
+     retrieve_results = rerank(retrieve_results)
+
+     print(retrieve_results)
+
+ def conversation(args):
+     while True:
+         query = input("User: ")
+         if query == "exit":
+             break
+         inference(query, args)
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--vectorstore_dir", type=str, default="notebook/An/master/knowledge/vectorstore_full")
+     parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")
+     parser.add_argument("--retriever_k", type=int, default=4)
+     parser.add_argument("--metric", type=str, choices=["cosine", "mmr", "bm25"], default="cosine")
+     parser.add_argument("--threshold", type=float, default=0.7)
+     args = parser.parse_args()
+     conversation(args)
test/data_ingest.py ADDED
@@ -0,0 +1,78 @@
+ import argparse
+ import os
+ from typing import List
+
+ from ..rag_pipeline import get_embeddings, load_data
+ from ..utils import load_local, save_local
+
+ def main(args):
+     print(f"Log: {args}")
+
+     if args.clear_vectorstore:
+         import shutil
+         if os.path.isdir(args.vectorstore_dir):
+             shutil.rmtree(args.vectorstore_dir)
+
+     embed_model = get_embeddings(args.embed_model_name)
+     vectorstore, docs = load_local(args.vectorstore_dir, embed_model)
+
+     new_docs = []
+     for data_path in args.data_paths:
+         new_docs.extend(load_data(data_path, args.file_type))
+     print(f"Got {len(new_docs)} documents.")
+
+     if args.chunk_method == "recursive":
+         from ..rag_pipeline import recursive_chunking
+         new_docs = recursive_chunking(new_docs, args.chunk_size, args.chunk_overlap)
+     elif args.chunk_method == "markdown":
+         from ..rag_pipeline import markdown_chunking
+         new_docs = markdown_chunking(new_docs, args.chunk_size, args.chunk_overlap)
+     print(f"Got {len(new_docs)} chunks.")
+
+     from langchain_community.vectorstores import FAISS
+     if vectorstore is None:
+         vectorstore = FAISS.from_documents(new_docs, embed_model)
+         docs = new_docs
+         print(f"Successfully consumed {len(new_docs)} documents.")
+     else:
+         docs.extend(new_docs)
+         vectorstore.add_documents(new_docs)
+
+     save_local(args.vectorstore_dir, vectorstore, docs)
+
+     import json
+     with open(os.path.join(args.vectorstore_dir, "config.json"), "a") as f:
+         json.dump(vars(args), f)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     data_paths = [
+         'dataset/RAG_Data/wiki_vi',
+         'dataset/RAG_Data/youmed',
+         'dataset/RAG_Data/mimic_ex_report',
+         'dataset/RAG_Data/Download sach y/OCR',
+     ]
+
+     # Dataset params
+     parser.add_argument("--data_paths", type=str, nargs="+", required=False, default=data_paths)
+     parser.add_argument("--vectorstore_dir", type=str, required=False, default="notebook/An/master/knowledge/vectorstore_full")
+     parser.add_argument("--file_type", type=str, choices=["pdf", "txt"], default="txt")
+
+     # Model params
+     parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")
+
+     # Index params
+     parser.add_argument("--chunk_size", type=int, default=2048)
+     parser.add_argument("--chunk_overlap", type=int, default=512)
+     parser.add_argument("--chunk_method", type=str, choices=["recursive", "markdown"], default="markdown")
+
+     # Vectorstore params
+     parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
+     parser.add_argument("--clear_vectorstore", action="store_true", default=True)
+
+
+     args = parser.parse_args()
+
+     main(args)
test/eval_lm.py ADDED
@@ -0,0 +1,87 @@
+ import argparse
+ from ..rag_pipeline import qa_prompt
+ from ..rag_pipeline import ChatAssistant
+ from ..utils import load_qa_dataset, load_prepared_retrieve_docs
+
+ from typing import List, Optional
+ from langchain.schema import Document
+
+ def get_answer_from_response(llm_response: str) -> str:
+     return llm_response.strip()
+
+ def build_qa_prompt(question: str, document: Optional[List[Document]]) -> str:
+     if document is not None:
+         document = '\n'.join([f"Document {i+1}:\n" + doc.page_content for i, doc in enumerate(document)])
+
+     return qa_prompt.format(question=question, document=document)
+
+ def process_question(question, prompt, answer, id, args, llm):
+     llm_response = llm.get_response("", prompt)
+     # ans = get_answer_from_response(llm_response)
+     with open("log.txt", "a", encoding="utf-8") as f:
+         f.write(f"ID: {id}\n")
+         f.write(prompt)
+         f.write(f"LLM Response:\n{llm_response}\n")
+         f.write(f"Answer: {answer} \n\n")
+
+     # with open("log_score.txt", "a", encoding="utf-8") as f:
+     #     f.write("1" if ans == answer else "0")
+     # return 1 if ans == answer else 0
+     return llm_response
+
+ def evaluate_qa(questions, prompts, answers, ids, args, llm):
+     import concurrent.futures
+     from tqdm import tqdm
+     ans = []
+     with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
+         futures = [executor.submit(process_question, questions[i], prompts[i], answers[i], ids[i], args, llm) for i in range(len(questions))]
+         for future in tqdm(concurrent.futures.as_completed(futures), total=len(questions)):
+             ans.append(future.result())
+     return ans
+
+ def main(args):
+     ids, questions, options, answers = load_qa_dataset(args.qa_file)
+
+     if ids is None:
+         raise ValueError(f"No id field in {args.qa_file}.")
+
+     if args.num_docs > 0:
+         if args.prepared_retrieve_docs_path is not None:
+             documents = load_prepared_retrieve_docs(args.prepared_retrieve_docs_path)
+             docs = [d[:args.num_docs] for i, d in enumerate(documents)]
+         else:
+             raise ValueError("No prepared retrieve docs found.")
+     else:
+         docs = [None]*len(questions)
+
+     prompts = [build_qa_prompt(questions[i], docs[i]) for i in range(len(questions))]
+
+     llm = ChatAssistant(args.model_name, args.provider)
+
+     with open("log_score.txt", "a", encoding="utf-8") as f:
+         f.write("\n")
+
+     qa_results = evaluate_qa(questions, prompts, answers, ids, args, llm)
+     qa_results = [qa_results[i][qa_results[i].rfind("[")+1:qa_results[i].rfind("]")] for i in range(len(qa_results))]
+     # print(f"{qa_results}")
+     import pyperclip
+     pyperclip.copy('\n'.join(qa_results))
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--qa_file", type=str, default="dataset/QA Data/random.jsonl")
+     parser.add_argument("--prepared_retrieve_docs_path", type=str, default="prepared_retrieve_docs.pkl")
+
+     parser.add_argument("--model_name", type=str, default="mistral-medium")
+     parser.add_argument("--provider", type=str, default="mistral")
+     parser.add_argument("--max_workers", type=int, default=4)
+     parser.add_argument("--num_docs", type=int, default=0)
+
+     parser.add_argument("--dataset_path", type=str)
+
+     args = parser.parse_args()
+
+     print(args)
+
+     main(args)
test/eval_qa.py ADDED
@@ -0,0 +1,106 @@
+ import argparse
+ from ..rag_pipeline import multichoice_qa_prompt
+ from ..rag_pipeline import ChatAssistant
+ from ..utils import paralelize, load_qa_dataset, load_prepared_retrieve_docs
+
+ from datetime import datetime
+ from typing import List, Optional
+ from langchain.schema import Document
+
+ def get_answer_from_response(llm_response: str) -> str:
+     """
+     Get the answer from the LLM response.
+     """
+     return llm_response[llm_response.lower().rfind("the answer is ") + 14]
+
+ def build_multichoice_qa_prompt(question: str, options: str, document: Optional[List[Document]]) -> str:
+     """
+     Build the prompt for the multichoice QA task.
+     """
+     if document is not None:
+         document = '\n'.join([f"Document {i+1}:\n" + doc.page_content for i, doc in enumerate(document)])
+
+     return multichoice_qa_prompt.format(question=question, options=options, document=document)
+
+ def process_question(question, prompt, answer, id, args, llm):
+     llm_response = ""
+     for j in range(args.retries):
+         try:
+             llm_response = llm.get_response("", prompt)
+             ans = get_answer_from_response(llm_response)
+             if ans in ["A", "B", "C", "D", "E"]:
+                 with open("log.txt", "a", encoding="utf-8") as f:
+                     f.write(f"ID: {id}\n")
+                     f.write(prompt)
+                     f.write(f"LLM Response:\n{llm_response}\n")
+                     f.write(f"Answer: {answer} {ans}\n\n")
+                 break
+         except Exception as e:
+             print(f"Error: {e}")
+             ans = "#"
+     with open("log_score.txt", "a", encoding="utf-8") as f:
+         f.write("1" if ans == answer else "0")
+     return 1 if ans == answer else 0
+
+ def evaluate_qa(questions, prompts, answers, ids, args, llm):
+     import concurrent.futures
+     from tqdm import tqdm
+     correct = 0
+     with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
+         futures = [executor.submit(process_question, questions[i], prompts[i], answers[i], ids[i], args, llm) for i in range(len(questions))]
+         for future in tqdm(concurrent.futures.as_completed(futures), total=len(questions)):
+             correct += future.result()
+     return correct / len(questions)
+
+
+ def main(args):
+     ids, questions, options, answers = load_qa_dataset(args.qa_file)
+
+     if ids is None:
+         raise ValueError(f"No id field in {args.qa_file}.")
+
+     if args.num_docs > 0:
+         if args.prepared_retrieve_docs_path is not None:
+             documents = load_prepared_retrieve_docs(args.prepared_retrieve_docs_path)
+             docs = [d[:args.num_docs] for i, d in enumerate(documents)]
+         else:
+             raise ValueError("No prepared retrieve docs found.")
+     else:
+         docs = [None]*len(questions)
+
+     prompts = [build_multichoice_qa_prompt(questions[i], options[i], docs[i]) for i in range(len(questions))]
+
+     # print(prompts[0])
+     llm = ChatAssistant(args.model_name, args.provider)
+
+     with open("log_score.txt", "a", encoding="utf-8") as f:
+         f.write(f"\n{datetime.now()} {args}\n")
+
+     acc = evaluate_qa(questions, prompts, answers, ids, args, llm)
+     print(f"Accuracy: {acc}")
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     # parser.add_argument("--qa_file", type=str, default="dataset/QA Data/MedAB/MedABv2.jsonl")
+     # parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedAB/prepared_retrieve_docs_full.pkl")
+
+     parser.add_argument("--qa_file", type=str, default="dataset/QA Data/MedMCQA/translated_hard_questions.jsonl")
+     parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedMCQA/prepared_retrieve_docs_full.pkl")
+
+     # Eval params
+     parser.add_argument("--model_name", type=str, default="mistral-medium")
+     parser.add_argument("--provider", type=str, default="mistral")
+     parser.add_argument("--max_workers", type=int, default=4)
+     parser.add_argument("--num_docs", type=int, default=0)
+     parser.add_argument("--retries", type=int, default=4)
+
+
+     # Dataset params
+     parser.add_argument("--dataset_path", type=str)
+
+     args = parser.parse_args()
+     print(f"Log:{args}")
+
+     main(args)
test/prepare_retrieve.py ADDED
@@ -0,0 +1,50 @@
+ import argparse
+ import os
+
+ from ..rag_pipeline import get_embeddings, vretrieve
+ from ..utils import load_local, load_qa_dataset, safe_save_langchain_docs
+
+ def main(args):
+     embed_model = get_embeddings(args.embed_model_name, show_progress=False)
+     vectorstore, docs = load_local(args.vectorstore_dir, embed_model)
+
+     ids, questions, options, answers = load_qa_dataset(args.qa_data_path)
+
+     rag_queries = [f"Question: {questions[i]}\n{options[i]}" for i in range(len(questions))]
+     if (args.rag_queries_path is not None) and os.path.exists(args.rag_queries_path):
+         import json
+         with open(args.rag_queries_path, "r", encoding="utf-8") as f:
+             rag_queries = [json.loads(line)["query"] for line in f]
+
+     from tqdm import tqdm
+     retrieve_results = [vretrieve(rag_queries[i], vectorstore, docs, args.retriever_k, args.metric, args.threshold) for i in tqdm(range(len(rag_queries)), desc="Retrieving documents")]
+
+     safe_save_langchain_docs(retrieve_results, args.prepared_retrieve_docs_path)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     # Dataset params
+     parser.add_argument("--qa_data_path", type=str, default="dataset/QA Data/MedMCQA/translated_hard_questions.jsonl")
+
+     # Vectorstore params
+     parser.add_argument("--vectorstore_dir", type=str, default="notebook/An/master/knowledge/vectorstore_full")
+     parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedMCQA/prepared_retrieve_docs_full.pkl")
+     parser.add_argument("--rag_queries_path", type=str, default=None)
+
+     # Model params
+     parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")
+
+     # Vectorstore retriever params
+     parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
+     parser.add_argument("--metric", type=str, choices=["cosine", "mmr", "bm25"], default="mmr")
+     parser.add_argument("--retriever_k", type=int, default=20, help="Number of documents to retrieve")
+     parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for cosine similarity")
+     parser.add_argument("--reranker_model_name", type=str, default=None)
+     parser.add_argument("--reranker_k", type=int, default=50, help="Number of documents to rerank")
+
+     args = parser.parse_args()
+     print(args)
+
+     main(args)
test/test_llm.py ADDED
@@ -0,0 +1,9 @@
+ from ..rag_pipeline import ChatAssistant
+ from ..rag_pipeline import request_retrieve_prompt
+
+ cb = ChatAssistant("mistral-medium", "mistral")
+
+ query = "Beta blocker for hypertension"
+ query = request_retrieve_prompt.format(conversation=query, role="customer")
+ response = cb.get_response(user=query)
+ print(response)
test/test_retrieve.py ADDED
@@ -0,0 +1,39 @@
+ import argparse
+ import os
+
+ from ..rag_pipeline import get_embeddings, rerank
+ from ..utils import load_local
+
+ from ..rag_pipeline import vretrieve
+
+ def main(args):
+     embed_model = get_embeddings(args.embed_model_name)
+     vectorstore, docs = load_local(args.vectorstore_dir, embed_model)
+     retrieve_results = vretrieve(args.query, vectorstore, docs, args.retriever_k, args.metric, args.threshold)
+
+     retrieve_results = rerank(retrieve_results)
+
+     print(retrieve_results)
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--query", type=str, required=False, default="What are the applications of beta blockers in the treatment of hypertension?")
+
+     # Vectorstore params
+     parser.add_argument("--vectorstore_dir", type=str, required=False, default="notebook/An/master/knowledge/vectorstore_full")
+
+     # Model params
+     parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")
+
+     # Vectorstore retriever params
+     parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
+     parser.add_argument("--metric", type=str, choices=["cosine", "mmr", "bm25"], default="cosine")
+     parser.add_argument("--retriever_k", type=int, default=4, help="Number of documents to retrieve")
+     parser.add_argument("--threshold", type=float, default=0.7, help="Threshold for cosine similarity")
+     parser.add_argument("--reranker_model_name", type=str, default=None)
+     parser.add_argument("--reranker_k", type=int, default=20, help="Number of documents to rerank")
+
+     args = parser.parse_args()
+
+     main(args)
utils.py ADDED
@@ -0,0 +1,211 @@
+ import os
+ import pickle
+ from typing import List, Optional
+
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain.schema import Document
+
+
+ def load_local(vectorstore_dir: str, embed_model: HuggingFaceEmbeddings) -> tuple[Optional[FAISS], Optional[List[Document]]]:
+     """
+     Load the vectorstore and documents from disk.
+     Args:
+         vectorstore_dir: The directory to load the vectorstore from.
+         embed_model: The embedding model to use.
+     Returns:
+         vector_store: The vectorstore.
+     """
+     from langchain_community.vectorstores import FAISS
+
+     if not os.path.isdir(vectorstore_dir):
+         print(f"Vectorstore directory not found at {vectorstore_dir}. Creating a new one.")
+         os.makedirs(vectorstore_dir, exist_ok=True)
+
+     try:
+         vector_store = FAISS.load_local(vectorstore_dir, embed_model, allow_dangerous_deserialization=True)
+
+         docs_path = os.path.join(vectorstore_dir, "docs.pkl")
+         if os.path.exists(docs_path):
+             with open(docs_path, "rb") as f:
+                 docs = pickle.load(f)
+         else:
+             docs = None
+             print("Warning: docs.pkl not found. BM25 search will not be available.")
+
+         print(f"Successfully loaded RAG state from {vectorstore_dir}")
+         return vector_store, docs
+     except Exception as e:
+         print(f"Could not load from {vectorstore_dir}. It might be empty or corrupted. Error: {e}")
+         return None, None
+
+ def save_local(vectorstore_dir: str, vectorstore: FAISS, docs: Optional[List[Document]]) -> None:
+     """
+     Save the vectorstore and documents to disk.
+     Args:
+         vectorstore_dir: The directory to save the vectorstore to.
+         vectorstore: The vectorstore to save.
+         docs: The documents to save.
+     """
+     if vectorstore is None:
+         raise ValueError("Nothing to save.")
+     if docs is None:
+         print("Warning: No documents to save. BM25 search will not be available.")
+
+     os.makedirs(vectorstore_dir, exist_ok=True)
+     vectorstore.save_local(vectorstore_dir)
+
+     if docs is not None:
+         with open(os.path.join(vectorstore_dir, "docs.pkl"), "wb") as f:
+             pickle.dump(docs, f)
+
+     print(f"Successfully saved RAG state to {vectorstore_dir}")
+
+ def load_qa_dataset(qa_dataset_path: str) -> tuple[List[str], List[str], List[str], List[str]]:
+     """
+     Load the QA dataset. (jsonl)
+     Args:
+         qa_dataset_path: The path to the QA dataset.
+     Returns:
+         Tuple: (ids, questions, options, answers)\\
+         ids: The ids of the questions\\
+         questions: The questions\\
+         options: The options for each question\\
+         answers: The answers for each question.
+     """
+     import json
+     if not os.path.exists(qa_dataset_path):
+         raise FileNotFoundError(f"Error: File not found at {qa_dataset_path}")
+
+     with open(qa_dataset_path, "r", encoding="utf-8") as f:
+         data = [json.loads(line) for line in f]
+     questions = [item["question"] for item in data]
+     try:
+         options = [
+             (f"A. {item['A']} \n" if item['A'] not in [" ", "", None] else "") +
+             (f"B. {item['B']} \n" if item['B'] not in [" ", "", None] else "") +
+             (f"C. {item['C']} \n" if item['C'] not in [" ", "", None] else "") +
+             (f"D. {item['D']} \n" if item['D'] not in [" ", "", None] else "") +
+             (f"E. {item['E']} \n" if item['E'] not in [" ", "", None] else "")
+             for item in data]
+     except KeyError:
+         options = [" " for item in data]
+     answers = [item["answer"] for item in data]
+     uuids = [item["uuid"] for item in data]
+     return uuids, questions, options, answers
+
+ def load_prepared_retrieve_docs(prepared_retrieve_docs_path: str) -> List[List[Document]]:
+     """
+     Load the prepared retrieve docs from a file.
+     Args:
+         prepared_retrieve_docs_path: The path to the prepared retrieve docs.
+     Returns:
+         A list of lists of documents.
+     """
+     return safe_load_langchain_docs(prepared_retrieve_docs_path)
+
+ def paralelize(func, max_workers: int = 4, **kwargs) -> List:
+     """
+     Parallelizes a function call over multiple keyword argument iterables.
+
+     Args:
+         func: The function to execute in parallel.
+         max_workers: The maximum number of threads to use.
+         **kwargs: Keyword arguments where each value is an iterable (e.g., a list).
+                   All iterables must be of the same length.
+                   The keyword names do not matter, but their order does.
+     Returns:
+         A list of the results of the function calls.
+     """
+     from concurrent.futures import ThreadPoolExecutor
+     from tqdm import tqdm
+
+     if not kwargs:
+         return []
+
+     arg_lists = list(kwargs.values())
+     if len(set(len(lst) for lst in arg_lists)) > 1:
+         raise ValueError("All iterable arguments must have the same length.")
+
+     total_items = len(arg_lists[0])
+     iterable = zip(*arg_lists)
+     unpacker_func = lambda args_tuple: func(*args_tuple)
+
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         results = list(tqdm(executor.map(unpacker_func, iterable), total=total_items))
+     return results
+
+ def safe_save_langchain_docs(documents: List[List[Document]], filepath: str):
+     """
+     Converts LangChain Document objects into a serializable list of dictionaries
+     and saves them to a file using pickle.
+
+     Args:
+         documents (List[List[Document]]): The nested list of LangChain Documents.
+         filepath (str): The path to the file where the data will be saved.
+     """
+     serializable_data = []
+     print(f"Preparing to save {len(documents)} lists of documents...")
+
+     # Convert each Document object into a dictionary
+     for doc_list in documents:
+         serializable_doc_list = []
+         for doc in doc_list:
+             serializable_doc_list.append({
+                 "page_content": doc.page_content,
+                 "metadata": doc.metadata,
+             })
+         serializable_data.append(serializable_doc_list)
+
+     print(f"Conversion complete. Saving to {filepath}...")
+     try:
+         # Use 'with' to ensure the file is closed properly, even if errors occur
+         with open(filepath, "wb") as f:
+             pickle.dump(serializable_data, f)
+         print("File saved successfully.")
+     except Exception as e:
+         print(f"An error occurred while saving the file: {e}")
+
+ def safe_load_langchain_docs(filepath: str) -> List[List[Document]]:
+     """
+     Loads data from a pickle file and reconstructs the LangChain Document objects.
+
+     Args:
+         filepath (str): The path to the file to load.
+
+     Returns:
+         List[List[Document]]: The reconstructed nested list of LangChain Documents.
+     """
+     reconstructed_documents = []
+
+     print(f"Loading data from {filepath}...")
+     try:
+         with open(filepath, "rb") as f:
+             loaded_data = pickle.load(f)
+         print("File loaded successfully. Reconstructing Document objects...")
+
+         # Reconstruct the Document objects from the dictionaries
+         for doc_list_data in loaded_data:
+             reconstructed_doc_list = []
+             for doc_data in doc_list_data:
+                 reconstructed_doc_list.append(
+                     Document(
+                         page_content=doc_data["page_content"],
+                         metadata=doc_data["metadata"]
+                     )
+                 )
+             reconstructed_documents.append(reconstructed_doc_list)
+
+         print("Document objects reconstructed successfully.")
+         return reconstructed_documents
+
+     except FileNotFoundError:
+         print(f"Error: The file at {filepath} was not found.")
+         return []
+     except EOFError:
+         print(f"Error: The file at {filepath} is corrupted or incomplete (EOFError).")
+         print("Please re-run the script that generates this file.")
+         return []
+     except Exception as e:
+         print(f"An unexpected error occurred while loading the file: {e}")
+         return []
+ return []