Spaces:
Running
Running
| # ========== Standard Library ========== | |
| import os | |
| import tempfile | |
| import zipfile | |
| from typing import List, Optional, Tuple, Union | |
| import collections | |
| # ========== Third-Party Libraries ========== | |
| import gradio as gr | |
| from groq import Groq | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.vectorstores import InMemoryVectorStore | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| # ========== Configs ========== | |
# HTML heading rendered at the top of the page via gr.HTML.
TITLE = '<h1 align="center">🗨️🦙 Llama 4 Docx Chatter</h1>'

# Passed to gr.Chatbot(avatar_images=...); the first slot is left unset.
AVATAR_IMAGES = (None, "./logo.png")

# File extensions the upload button accepts.
TEXT_EXTENSIONS = [".docx", ".zip"]
| # ========== Models & Clients ========== | |
# Read once at import time; this is None when GROQ_API_KEY is not set.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Groq SDK client, created with the same key as the chat model below.
client = Groq(api_key=GROQ_API_KEY)

# Chat model that the RAG chain pipes the filled-in prompt into.
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="meta-llama/llama-4-scout-17b-16e-instruct",
)

# Embedding model used when indexing document chunks in the vector store.
embed_model = HuggingFaceEmbeddings(
    model_name="mixedbread-ai/mxbai-embed-large-v1",
)

# ========== Core Components ==========
# Splitter applied to loaded documents before indexing: chunk_size=1000 with
# chunk_overlap=100, breaking on blank lines first and single newlines second.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n"],
    chunk_overlap=100,
    chunk_size=1000,
)
# Prompt text for the RAG chain. The {context} and {question} placeholders are
# supplied by the chain assembled in upload_files.
rag_template = "\n".join(
    (
        "You are an expert assistant tasked with answering questions based on the provided documents.",
        "Use only the given context to generate your answer.",
        "If the answer cannot be found in the context, clearly state that you do not know.",
        "Be detailed and precise in your response, but avoid mentioning or referencing the context itself.",
        "Context:",
        "{context}",
        "Question:",
        "{question}",
        "Answer:",
    )
)

# Template object built from the text above; it is the prompt stage of the chain.
rag_prompt = PromptTemplate.from_template(rag_template)
| # ========== App State ========== | |
class AppState:
    """Holder for the retrieval objects built from the most recent upload."""

    # Chain assigned by upload_files and invoked by process_query; None until
    # an upload succeeds or after reset_app clears it.
    rag_chain = None

    # Vector index built from the uploaded document chunks; None when empty.
    vectorstore: Optional[InMemoryVectorStore] = None


# Single module-level instance that every handler in this file reads and writes.
state = AppState()
| # ========== Utility Functions ========== | |
def load_documents_from_files(files: List[str]) -> List:
    """Load documents from uploaded .docx files and .zip archives of .docx files.

    Paths with any other extension are ignored.

    Args:
        files: Paths of the uploaded files. ``None`` or an empty list yields
            an empty result.

    Returns:
        The loaded documents, in upload order.

    Raises:
        zipfile.BadZipFile: If a ``.zip`` path is not a valid archive.
    """
    all_documents = []
    for file_path in files or []:
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".zip":
            # Each archive gets its own extraction directory. Sharing one
            # directory made the directory scan below pick up the .docx files
            # of every archive processed earlier (duplicate documents) and let
            # same-named members from different archives overwrite each other.
            with tempfile.TemporaryDirectory() as extract_dir:
                with zipfile.ZipFile(file_path, "r") as zip_ref:
                    zip_ref.extractall(extract_dir)
                loader = DirectoryLoader(
                    path=extract_dir,
                    glob="**/*.docx",
                    use_multithreading=True,
                )
                # load() is called while the extracted files still exist.
                all_documents.extend(loader.load())
        elif ext == ".docx":
            # A single .docx is read from its uploaded location.
            all_documents.extend(UnstructuredFileLoader(file_path).load())
    return all_documents
def get_last_user_message(chatbot: List[Union[gr.ChatMessage, dict]]) -> Optional[str]:
    """Return the content of the most recent user message, or None if absent.

    History entries may be plain dicts (read with ``.get``) or objects that
    expose ``role`` and ``content`` attributes.
    """
    # Walk the history newest-first so the first match is the latest one.
    for entry in reversed(chatbot):
        if isinstance(entry, dict):
            content, role = entry.get("content"), entry.get("role")
        else:
            content, role = entry.content, entry.role
        if role == "user":
            return content
    return None
| # ========== Main Logic ========== | |
def _summarize_zip_entries(entry_names: List[str]) -> List[str]:
    """Return display lines listing archive members grouped by their folder."""
    folder_map = collections.defaultdict(list)
    for entry in entry_names:
        if entry.endswith("/"):
            continue  # directory entries carry no file of their own
        folder_map[os.path.dirname(entry)].append(os.path.basename(entry))

    lines = []
    for folder, names in folder_map.items():
        lines.append(f"📂 {folder}/" if folder else "📄 (root)")
        lines.extend(f" - {name}" for name in names)
    return lines


def upload_files(
    files: Optional[List[str]], chatbot: List[Union[gr.ChatMessage, dict]]
):
    """Index uploaded .docx files (or .zip archives of them) and rebuild the RAG chain.

    On success ``state.vectorstore`` and ``state.rag_chain`` are replaced with
    objects built from this upload only, and a summary message is appended to
    the chat. If nothing could be loaded or chunked, an explanatory message is
    appended instead and the existing state is left untouched.

    Args:
        files: Paths of the uploaded files; ``None`` or empty is a no-op.
        chatbot: Current chat history; appended to in place.

    Returns:
        The chat history, including any message appended here.
    """
    if not files:
        return chatbot

    file_summaries = []  # display lines describing what was uploaded
    documents = []
    for file_path in files:
        filename = os.path.basename(file_path)
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".zip":
            file_summaries.append(f"📦 **{filename}** (ZIP file) contains:")
            try:
                # One extraction directory per archive. With a shared
                # directory, the scan below re-loaded the .docx files of every
                # archive handled earlier in the same upload (duplicate chunks
                # in the index) and same-named members overwrote each other.
                with tempfile.TemporaryDirectory() as extract_dir:
                    with zipfile.ZipFile(file_path, "r") as zip_ref:
                        zip_ref.extractall(extract_dir)
                        zip_contents = zip_ref.namelist()
                    file_summaries.extend(_summarize_zip_entries(zip_contents))
                    loader = DirectoryLoader(
                        path=extract_dir,
                        glob="**/*.docx",
                        use_multithreading=True,
                    )
                    # load() is called while the extracted files still exist.
                    documents.extend(loader.load())
            except zipfile.BadZipFile:
                chatbot.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=f"❌ Failed to open ZIP file: {filename}",
                    )
                )
        elif ext == ".docx":
            file_summaries.append(f"📄 **{filename}**")
            documents.extend(UnstructuredFileLoader(file_path).load())
        else:
            file_summaries.append(f"❌ Unsupported file type: {filename}")

    if not documents:
        chatbot.append(
            gr.ChatMessage(
                role="assistant", content="No valid .docx files found in upload."
            )
        )
        return chatbot

    # Split documents
    chunks = text_splitter.split_documents(documents)
    if not chunks:
        chatbot.append(
            gr.ChatMessage(
                role="assistant", content="Failed to split documents into chunks."
            )
        )
        return chatbot

    # Create Vectorstore
    state.vectorstore = InMemoryVectorStore.from_documents(
        documents=chunks,
        embedding=embed_model,
    )
    retriever = state.vectorstore.as_retriever()

    # Build RAG Chain: the retriever fills {context}, and the raw question is
    # passed through unchanged as {question}.
    state.rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    # Final display
    chatbot.append(
        gr.ChatMessage(
            role="assistant",
            content="**Uploaded Files:**\n"
            + "\n".join(file_summaries)
            + "\n\n✅ Ready to chat!",
        )
    )
    return chatbot
def user_message(
    text_prompt: str, chatbot: List[Union[gr.ChatMessage, dict]]
) -> Tuple[str, List[Union[gr.ChatMessage, dict]]]:
    """Append the typed text to the chat as a user message and clear the box.

    Whitespace-only input appends nothing. The first element of the returned
    pair is always the empty string, which the UI writes back to the textbox.
    """
    has_text = bool(text_prompt.strip())
    if has_text:
        # The text is stored as typed, not the stripped version.
        new_entry = gr.ChatMessage(role="user", content=text_prompt)
        chatbot.append(new_entry)
    return "", chatbot
def process_query(
    chatbot: List[Union[gr.ChatMessage, dict]],
) -> List[Union[gr.ChatMessage, dict]]:
    """Answer the pending user question through the RAG chain.

    Args:
        chatbot: Current chat history; appended to in place.

    Returns:
        The chat history with one assistant message added: the answer, an
        error description, or a hint that a question or documents are missing.
    """
    # Only a user message at the very end of the history is a pending question.
    # user_message appends nothing for blank input, so searching the whole
    # history here would silently re-send the previous question to the model.
    prompt = get_last_user_message(chatbot[-1:])
    if not prompt:
        chatbot.append(
            gr.ChatMessage(role="assistant", content="Please type a question first.")
        )
        return chatbot

    if state.rag_chain is None:
        chatbot.append(
            gr.ChatMessage(role="assistant", content="Please upload documents first.")
        )
        return chatbot

    # Placeholder whose content is overwritten with the answer or the error.
    reply = gr.ChatMessage(role="assistant", content="Thinking...")
    chatbot.append(reply)
    try:
        reply.content = state.rag_chain.invoke(prompt)
    except Exception as e:
        # UI boundary: show the failure in the chat instead of raising.
        reply.content = f"Error: {str(e)}"
    return chatbot
def reset_app(
    chatbot: List[Union[gr.ChatMessage, dict]],
) -> List[Union[gr.ChatMessage, dict]]:
    """Clear the indexed documents and chain, and start a fresh conversation.

    The incoming history is not used; a new one-message history is returned.
    """
    state.rag_chain = None
    state.vectorstore = None
    notice = gr.ChatMessage(
        role="assistant", content="App reset! Upload new documents to start."
    )
    return [notice]
| # ========== UI Layout ========== | |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML(TITLE)

    # Conversation pane; history is exchanged as role/content messages.
    chatbot = gr.Chatbot(
        label="Llama 4 RAG",
        type="messages",
        bubble_full_width=False,
        avatar_images=AVATAR_IMAGES,
        scale=2,
        height=350,
    )

    # Input row: a wide textbox followed by three narrow buttons.
    with gr.Row(equal_height=True):
        text_prompt = gr.Textbox(
            placeholder="Ask a question...",
            show_label=False,
            autofocus=True,
            scale=28,
        )
        send_button = gr.Button(value="Send", variant="primary", scale=1, min_width=80)
        upload_button = gr.UploadButton(
            label="Upload",
            file_count="multiple",
            file_types=TEXT_EXTENSIONS,
            scale=1,
            min_width=80,
        )
        reset_button = gr.Button(value="Reset", variant="stop", scale=1, min_width=80)

    # Clicking Send and pressing Enter in the textbox run the same two steps:
    # record the user's message and clear the box, then generate the answer.
    for register in (send_button.click, text_prompt.submit):
        register(
            fn=user_message,
            inputs=[text_prompt, chatbot],
            outputs=[text_prompt, chatbot],
            queue=False,
        ).then(fn=process_query, inputs=[chatbot], outputs=[chatbot])

    upload_button.upload(
        fn=upload_files,
        inputs=[upload_button, chatbot],
        outputs=[chatbot],
        queue=False,
    )
    reset_button.click(
        fn=reset_app,
        inputs=[chatbot],
        outputs=[chatbot],
        queue=False,
    )

demo.queue().launch()