import os

import gradio as gr
from loguru import logger
from transformers import AutoTokenizer, AutoModel
# NB: this script targets the pre-0.1 langchain and pre-3.x duckduckgo_search
# APIs (ddg, VectorDBQA); newer releases renamed or removed these entry points.
from duckduckgo_search import ddg
from langchain import VectorDBQA
from langchain.llms import OpenAI
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
# Cache the ChatGLM weights at module level so they are loaded once,
# not on every chat round.
_TOKENIZER, _MODEL = None, None

def load_model():
    global _TOKENIZER, _MODEL
    if _MODEL is None:
        # The int4 checkpoint shares its tokenizer with THUDM/chatglm-6b.
        _TOKENIZER = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        # CPU inference; for GPU, use .half().cuda() instead of .float()
        _MODEL = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).quantize(
            bits=4, compile_parallel_kernel=True, parallel_num=2).float()
        _MODEL = _MODEL.eval()
    return _TOKENIZER, _MODEL

def chat_glm(input, history=None):
    if history is None:
        history = []
    tokenizer, model = load_model()
    response, history = model.chat(tokenizer, input, history)
    # loguru formats with str.format, so use {} placeholders.
    logger.debug("chatglm: {} -> {}", input, response)
    return history, history

def search_web(query):
    logger.debug("searchweb: {}", query)
    results = ddg(query)
    web_content = ''
    if results:
        # Concatenate the snippet text of every hit into one context string.
        for result in results:
            web_content += result['body']
    return web_content
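
# Note: ddg() is the pre-3.x duckduckgo_search API; it returns a list of dicts
# with 'title', 'href' and 'body' keys. On newer releases the equivalent call
# is roughly the following (a hedged sketch, assuming duckduckgo_search >= 3):
#
#   from duckduckgo_search import DDGS
#   results = list(DDGS().text(query, max_results=5))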

def search_vec(query):
    logger.debug("searchvec: {}", query)
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vec_path = 'cache'  # directory holding the prebuilt FAISS index
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.load_local(vec_path, embeddings)
    qa = VectorDBQA.from_chain_type(llm=OpenAI(),
                                    chain_type="stuff",
                                    vectorstore=vector_store,
                                    return_source_documents=True)
    result = qa({"query": query})
    return result['result']
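
# search_vec() assumes the FAISS index under 'cache' already exists. A minimal
# sketch of how it could be built from local documents; the 'docs/data.txt'
# path, chunk sizes and loader choice are assumptions, not part of this app.
def build_vector_store(doc_path='docs/data.txt', vec_path='cache'):
    from langchain.document_loaders import UnstructuredFileLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    # Load the raw file, then split it into overlapping chunks small enough
    # to be stuffed into the QA prompt.
    docs = UnstructuredFileLoader(doc_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(docs)
    embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese')
    FAISS.from_documents(chunks, embeddings).save_local(vec_path)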

def chat_gpt(input, use_web, use_vec, history=None):
    if history is None:
        history = []
    # History is deliberately left out of the prompt (4097-token limit).
    context = "无"
    # The gr.Radio components deliver "True"/"False" as strings, so compare
    # against the string rather than relying on truthiness ("False" is truthy).
    if use_web == "True":
        context = search_web(input)
    if use_vec == "True":
        context = search_vec(input)
    # Prompt (roughly): "Answer the user's question professionally, based on
    # the known information below; prefix any fabricated part with '据我推测'
    # ('I speculate that'); answer in Chinese."
    prompt_template = f"""基于以下已知信息,请专业地回答用户的问题。
若答案中存在编造成分,请在该部分开头添加“据我推测”。另外,答案请使用中文。
已知内容:
{context}""" + """
问题:
{question}"""
    prompt = PromptTemplate(template=prompt_template, input_variables=["question"])
    llm = OpenAI(temperature=0.2)
    chain = LLMChain(llm=llm, prompt=prompt)
    result = chain.run(input)
    return result

def predict(input,
            large_language_model,
            use_web,
            use_vec,
            openai_key,
            history=None):
    logger.debug("predict: model={}, use_web={}", large_language_model, use_web)
    if history is None:
        history = []
    if openai_key:
        os.environ['OPENAI_API_KEY'] = openai_key
    elif large_language_model in ("GPT-3.5-turbo", "Search VectorStore"):
        # Only these paths call the OpenAI API; ChatGLM and plain web search
        # work without a key. Return the warning in chatbot (pair) format.
        resp = "You forgot the OpenAI API key"
        history.append((input, resp))
        return '', history, history
    if large_language_model == "GPT-3.5-turbo":
        resp = chat_gpt(input, use_web, use_vec, history)
    elif large_language_model == "ChatGLM-6B-int4":
        # chat_glm returns the full history; the reply is the second element
        # of the last (question, answer) pair.
        _, resp = chat_glm(input, history)
        resp = resp[-1][1]
    elif large_language_model == "Search Web":
        resp = search_web(input)
    elif large_language_model == "Search VectorStore":
        resp = search_vec(input)
    history.append((input, resp))
    return '', history, history

def clear_session():
    return '', None

block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>MedKBQA (demo)</center></h1>
<center><font size=3>
本项目基于LangChain、ChatGLM以及OpenAI接口, 提供基于本地医药知识的自动问答应用. <br>
</font></center>
""")
    with gr.Row():
        with gr.Column(scale=1):
            model_choose = gr.Accordion("模型选择")
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4", "GPT-3.5-turbo", "Search Web", "Search VectorStore"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                use_vec = gr.Radio(["True", "False"],
                                   label="VectorStore Search",
                                   value="False")
            openai_key = gr.Textbox(label="请输入OpenAI API key", type="password")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='请输入问题')
            state = gr.State()
            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")
                send = gr.Button("🚀 发送")
            send.click(predict,
                       inputs=[
                           message, large_language_model, use_web, use_vec,
                           openai_key, state
                       ],
                       outputs=[message, chatbot, state])
            clear_history.click(fn=clear_session,
                                inputs=[],
                                outputs=[chatbot, state],
                                queue=False)
            message.submit(predict,
                           inputs=[
                               message, large_language_model, use_web, use_vec,
                               openai_key, state
                           ],
                           outputs=[message, chatbot, state])
    gr.Markdown("""提醒:<br>
1. 使用时请先选择使用chatglm或者chatgpt进行问答. <br>
2. 使用chatgpt时需要输入您的api key.
""")

demo.queue().launch(server_name='0.0.0.0', share=False)
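
# To run locally: `python app.py`, then open http://localhost:7860 (Gradio's
# default port). server_name='0.0.0.0' makes the app reachable from the LAN.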