import stat
import gradio as gr
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core import StorageContext
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
import zipfile
import requests
import torch
from llama_index.core import Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
import sys
import logging
import os
enable_rerank = True
# sentence_window,naive,recursive_retrieval
retrieval_strategy = "sentence_window"
base_embedding_source = "hf"  # local,openai,hf
# intfloat/multilingual-e5-small local:BAAI/bge-small-en-v1.5 text-embedding-3-small nvidia/NV-Embed-v2 Alibaba-NLP/gte-large-en-v1.5
base_embedding_model = "Alibaba-NLP/gte-large-en-v1.5"
# meta-llama/Llama-3.1-8B meta-llama/Llama-3.2-3B-Instruct meta-llama/Llama-2-7b-chat-hf google/gemma-2-9b CohereForAI/c4ai-command-r-plus CohereForAI/aya-23-8B
base_llm_model = "mistralai/Mistral-7B-Instruct-v0.3"
# AdaptLLM/finance-chat
base_llm_source = "hf"  # cohere,hf,anthropic
base_similarity_top_k = 20
# ChromaDB
env_extension = "_large"  # _large _dev_window _large_window
db_collection = f"gte{env_extension}"  # intfloat gte
read_db = True
active_chroma = True
root_path = "."
chroma_db_path = f"{root_path}/chroma_db"  # ./chroma_db
# ./processed_files.json
processed_files_log = f"{root_path}/processed_files{env_extension}.json"
# check hyperparameters
if retrieval_strategy not in ["sentence_window", "naive"]:  # recursive_retrieval
    raise Exception(f"{retrieval_strategy} retrieval_strategy is not supported")
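# Note on the two supported strategies:
# - "sentence_window": retrieval matches individual sentences, then MetadataReplacementPostProcessor
#   swaps each hit for the surrounding sentence window stored in its metadata, so the LLM sees more
#   context around every match.
# - "naive": plain chunk retrieval with no window replacement.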
os.environ["OPENAI_API_KEY"] = 'sk-xxxxxxxxxx'
hf_api_key = os.getenv("HF_API_KEY")
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
torch.cuda.empty_cache()
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
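# expandable_segments lets the CUDA caching allocator grow memory segments instead of reserving
# fixed-size blocks, which helps avoid fragmentation-related OOMs when the LLM and the embedding
# model share one GPU.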
| print(f"loading embedding ..{base_embedding_model}") | |
| if base_embedding_source == 'hf': | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| Settings.embed_model = HuggingFaceEmbedding( | |
| model_name=base_embedding_model, trust_remote_code=True) # , | |
| else: | |
| raise Exception("embedding model is invalid") | |
# setup prompts - specific to StableLM
if base_llm_source == 'hf':
    from llama_index.core import PromptTemplate
    # This will wrap the default prompts that are internal to llama-index
    # taken from https://huggingface.co/Writer/camel-5b-hf
    query_wrapper_prompt = PromptTemplate(
        "Below is an instruction that describes a task. "
        "Make sure the user's question and the retrieved context mention the same stock symbol; "
        "if they do not, give no answer to the user. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query_str}\n\n### Response:"
    )
if base_llm_source == 'hf':
    llm = HuggingFaceLLM(
        context_window=2048,
        max_new_tokens=512,  # 256
        generate_kwargs={"temperature": 0.1, "do_sample": False},  # 0.25
        query_wrapper_prompt=query_wrapper_prompt,
        tokenizer_name=base_llm_model,
        model_name=base_llm_model,
        device_map="auto",
        tokenizer_kwargs={"max_length": 2048},
        # load the weights in float16 on CUDA to reduce memory usage
        model_kwargs={"torch_dtype": torch.float16}
    )
Settings.chunk_size = 512
Settings.llm = llm
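# Note: with do_sample=False the HuggingFace generate call is greedy, so the temperature value
# above has no effect on the output; it only matters if sampling is enabled.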
| """#### Load documents, build the VectorStoreIndex""" | |
| def download_and_extract_chroma_db(url, destination): | |
| """Download and extract ChromaDB from Hugging Face Datasets.""" | |
| # Create destination folder if it doesn't exist | |
| if not os.path.exists(destination): | |
| os.makedirs(destination) | |
| else: | |
| # If the folder exists, remove it to ensure a fresh extract | |
| print("Destination folder exists. Removing it...") | |
| for root, dirs, files in os.walk(destination, topdown=False): | |
| for file in files: | |
| os.remove(os.path.join(root, file)) | |
| for dir in dirs: | |
| os.rmdir(os.path.join(root, dir)) | |
| print("Destination folder cleared.") | |
| db_zip_path = os.path.join(destination, "chroma_db.zip") | |
| if not os.path.exists(db_zip_path): | |
| # Download the ChromaDB zip file | |
| print("Downloading ChromaDB from Hugging Face Datasets...") | |
| headers = { | |
| "Authorization": f"Bearer {hf_api_key}" | |
| } | |
| response = requests.get(url, headers=headers, stream=True) | |
| response.raise_for_status() | |
| with open(db_zip_path, "wb") as f: | |
| for chunk in response.iter_content(chunk_size=8192): | |
| f.write(chunk) | |
| print("Download completed.") | |
| else: | |
| print("Zip file already exists, skipping download.") | |
| # Extract the zip file | |
| print("Extracting ChromaDB...") | |
| with zipfile.ZipFile(db_zip_path, 'r') as zip_ref: | |
| zip_ref.extractall(destination) | |
| print("Extraction completed. Zip file retained.") | |
# URL to your dataset hosted on Hugging Face
chroma_db_url = "https://huggingface.co/datasets/iamboolean/set50-db/resolve/main/chroma_db.zip"
# Local destination for the ChromaDB
chroma_db_path_extract = "./"  # You can change this to your desired path
# Download and extract the ChromaDB
download_and_extract_chroma_db(chroma_db_url, chroma_db_path_extract)
# Define ChromaDB client (persistent mode)
db = chromadb.PersistentClient(path=chroma_db_path)
print(f"db path:{chroma_db_path}")
chroma_collection = db.get_or_create_collection(db_collection)
print(f"db collection:{db_collection}")
# Set up ChromaVectorStore and embeddings
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
document_count = chroma_collection.count()
print(f"Total documents in the collection: {document_count}")
index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    # embed_model=embed_model,
)
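# The collection loaded above was built offline (the build script is not part of this app).
# A minimal sketch of how a sentence-window index like this is typically built, assuming the
# source PDFs live in a local "reports/" folder:
#
# from llama_index.core.node_parser import SentenceWindowNodeParser
# node_parser = SentenceWindowNodeParser.from_defaults(
#     window_size=3,
#     window_metadata_key="window",  # matches target_metadata_key used below
#     original_text_metadata_key="original_text",
# )
# documents = SimpleDirectoryReader("reports/").load_data()
# VectorStoreIndex.from_documents(
#     documents, storage_context=storage_context, transformations=[node_parser]
# )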
| """#### Query Index""" | |
| rerank = SentenceTransformerRerank( | |
| model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=10 | |
| ) | |
| node_postprocessors = [] | |
| # node_postprocessors.append(SimilarityPostprocessor(similarity_cutoff=0.6)) | |
| if retrieval_strategy == 'sentence_window': | |
| node_postprocessors.append( | |
| MetadataReplacementPostProcessor(target_metadata_key="window")) | |
| if enable_rerank: | |
| node_postprocessors.append(rerank) | |
| query_engine = index.as_query_engine( | |
| similarity_top_k=base_similarity_top_k, | |
| # the target key defaults to `window` to match the node_parser's default | |
| node_postprocessors=node_postprocessors, | |
| ) | |
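# Postprocessors run in list order: the window replacement first expands each retrieved sentence
# to its surrounding window, then the cross-encoder reranks those expanded windows and keeps the
# top 10 of the 20 retrieved candidates.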


def metadata_formatter(metadata):
    company_symbol = metadata['file_name'].split('-')[0]  # part before '-'
    year = metadata['file_name'].split('-')[1].split('.')[0]  # part between '-' and '.'
    page_number = metadata['page_label']
    return f"Company File: {company_symbol}, Year: {year}, Page Number: {page_number}"


def query_journal(question):
    response = query_engine.query(question)  # Query the index
    matched_nodes = response.source_nodes  # Extract matched nodes
    # Prepare the matched nodes details
    retrieved_context = "\n".join([
        # f"Node ID: {node.node_id}\n"
        # f"Matched Content: {node.node.text}\n"
        # f"Metadata: {node.node.metadata if node.node.metadata else 'None'}"
        f"Metadata: {metadata_formatter(node.node.metadata) if node.node.metadata else 'None'}"
        for node in matched_nodes
    ])
    generated_answer = str(response)
    # Return both retrieved context and detailed matched nodes
    return retrieved_context, generated_answer


# Define the Gradio interface
with gr.Blocks() as app:
    # Title
    gr.Markdown(
        """
<div style="text-align: center;">
<h1>SET50RAG: Retrieval-Augmented Generation for Thai Public Companies Question Answering</h1>
</div>
"""
    )
    # Description
    gr.Markdown(
        """
The **SET50RAG** tool provides an interactive way to analyze and extract insights from **243 annual reports** of Thai public companies spanning **5 years**.
By leveraging advanced **Retrieval-Augmented Generation**, including **GTE-Large embedding models**, **Sentence Window with Reranking**, and powerful **Large Language Models (LLMs)** like **Mistral-7B**, the system efficiently retrieves and answers complex financial queries.
This scalable and cost-effective solution reduces reliance on parametric knowledge, ensuring contextually accurate and relevant responses.
"""
    )
    # How to Use Section
    gr.Markdown(
        """
### How to Use
1. Type your question in the box or select an example question below.
2. Click **Submit** to retrieve the context and get an AI-generated answer.
3. Review the retrieved context and the generated answer to gain insights.
---
"""
    )
    # Example Questions Section
    gr.Markdown(
        """
### Example Questions
- What is the revenue of PTTOR in 2022?
- What is the effect of COVID-19 on BDMS? Show it as a timeline from 2019 to 2023.
- How does CPALL plan for electric vehicles?
"""
    )
    # Interactive Section (RAG Box)
    with gr.Row():
        with gr.Column():
            user_question = gr.Textbox(
                label="Ask a Question",
                placeholder="Type your question here, e.g., 'What is the revenue of PTTOR in 2022?'",
            )
            example_question_button = gr.Button("Use Example Question")
        with gr.Column():
            generated_answer = gr.Textbox(
                label="Generated Answer",
                placeholder="The AI-generated answer will appear here.",
                interactive=False,
            )
            retrieved_context = gr.Textbox(
                label="Retrieved Context",
                placeholder="Relevant context will appear here.",
                interactive=False,
            )
    # Button for user interaction
    submit_button = gr.Button("Submit")

    # Example question logic
    def use_example_question():
        return "What is the revenue of PTTOR in 2022?"

    example_question_button.click(
        use_example_question, inputs=[], outputs=[user_question]
    )
    # Interaction logic for submitting user queries
    submit_button.click(
        query_journal, inputs=[user_question], outputs=[
            retrieved_context, generated_answer]
    )
    # Footer
    gr.Markdown(
        """
---
### Limitations and Bias:
- Optimized for Thai financial reports from SET50 companies. Results may vary for other domains.
- Retrieval and accuracy depend on data quality and embedding models.
"""
    )

# Launch the app
# app.launch()
app.launch(server_name="0.0.0.0")  # , server_port=7860