from src.retrieval.retriever_chain import get_base_retriever, load_hf_llm, create_qa_chain

HF_MODEL = "HuggingFaceH4/zephyr-7b-beta"


def get_qa_chain():
    """
    Instantiates the QA chain.

    Builds a base retriever, loads the Hugging Face LLM, and wires the two
    together into a QA chain.

    Returns:
        Runnable: An instance of the QA chain.
    """
    # Retrieve the top 4 chunks using maximal marginal relevance (MMR),
    # which trades off relevance against diversity.
    retriever = get_base_retriever(k=4, search_type="mmr")

    # Zephyr-7B served via the Hugging Face Hub; a moderate temperature
    # keeps answers grounded while allowing some variation.
    llm = load_hf_llm(repo_id=HF_MODEL, max_new_tokens=512, temperature=0.4)

    qa_chain = create_qa_chain(retriever, llm)
    return qa_chain


def set_global_qa_chain(local_qa_chain):
    """Stores the QA chain in a module-level global so that callbacks can reach it."""
    global global_qa_chain
    global_qa_chain = local_qa_chain


def generate_response(message, history):
    """
    Generates a response to the question being asked.

    Args:
        message (str): Question asked by the user.
        history (list): Chat history. Not used for now.

    Returns:
        str: The generated response.
    """
    response = global_qa_chain.invoke(message)
    # Log the raw response for debugging/inspection.
    print(response)
    return response
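

# --- Usage sketch ---------------------------------------------------------
# A minimal, illustrative example of how these pieces can be wired into a
# Gradio chat app. Assumptions: `gradio` is installed, and the
# (message, history) signature of `generate_response` is meant for
# gr.ChatInterface, which passes the chat history on every user turn.
if __name__ == "__main__":
    import gradio as gr

    # Build the chain once at startup and expose it to the callback.
    set_global_qa_chain(get_qa_chain())

    demo = gr.ChatInterface(fn=generate_response)
    demo.launch()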