# Hugging Face Space: document-based RAG chatbot.
# (Space status at time of capture: runtime error.)
import os

import gradio as gr
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from llama_cpp import Llama
class RAGInterface:
    """Document question-answering chatbot.

    Retrieves supporting chunks from a persisted Chroma vector store and
    answers with a local quantized Llama 3.1 model, served through a Gradio
    chat UI.
    """

    def __init__(self):
        # Sentence-transformer embeddings on CPU; normalized embeddings so
        # similarity search behaves like cosine similarity.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )

        # Load the persisted vector store from the 'mydb' directory that
        # sits next to this file.
        persist_directory = os.path.join(os.path.dirname(__file__), 'mydb')
        self.vectorstore = Chroma(
            persist_directory=persist_directory,
            embedding_function=self.embeddings,
        )

        # Quantized Llama 3.1 8B Instruct (GGUF) with a 2048-token context
        # window (prompt + completion share this budget).
        self.llm = Llama.from_pretrained(
            repo_id="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
            filename="Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
            n_ctx=2048,
        )

        # Prompt template that grounds the answer in the retrieved context.
        self.template = """Answer the question based only on the following context:
{context}
Question: {question}
Answer the question in a clear way. If you cannot find the answer in the context,
just say "I don't have enough information to answer this question."
Make sure to:
1. Only use information from the provided context
2. If you're unsure, acknowledge it
"""
        self.prompt = PromptTemplate.from_template(self.template)

    def respond(self, message, history, system_message, temperature, max_tokens=2048):
        """Answer ``message`` with RAG context plus the running chat history.

        Args:
            message: Current user question.
            history: Prior turns from ``gr.ChatInterface``.
            system_message: System prompt taken from the UI textbox.
            temperature: Sampling temperature from the UI slider.
            max_tokens: Completion cap. NOTE(review): with n_ctx=2048 the
                prompt and completion share the window, so long retrieved
                contexts may leave little room for the answer — confirm.

        Returns:
            The assistant's reply text.
        """
        # Rebuild the chat transcript for the model.
        messages = [{"role": "system", "content": system_message}]
        # NOTE(review): assumes tuple-format history ((user, assistant)
        # pairs); newer Gradio "messages"-format histories are dicts and
        # would fail this unpack — confirm the installed Gradio version.
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

        # Retrieve the 5 most similar chunks and join them into one context.
        retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
        docs = retriever.get_relevant_documents(message)
        context = "\n\n".join(doc.page_content for doc in docs)

        # Wrap the current question in the grounded RAG prompt and send it
        # as the final user turn.
        final_prompt = self.prompt.format(context=context, question=message)
        messages.append({"role": "user", "content": final_prompt})

        response = self.llm.create_chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response['choices'][0]['message']['content']

    def create_interface(self):
        """Build and return the Gradio Blocks app (styled header + chat UI)."""
        # Raw CSS only: gr.Blocks(css=...) injects this inside a <style>
        # element itself. The original wrapped the rules in literal
        # <style></style> tags AND injected the same block a second time via
        # the header HTML, producing invalid markup — both fixed here.
        custom_css = """
        /* Global Styles */
        body, #root {
            font-family: Helvetica, Arial, sans-serif;
            background-color: #1a1a1a;
            color: #fafafa;
        }
        /* Header Styles */
        .app-header {
            background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
            padding: 24px;
            border-radius: 8px;
            margin-bottom: 24px;
            text-align: center;
        }
        .app-title {
            font-size: 36px;
            margin: 0;
            color: #fafafa;
        }
        .app-subtitle {
            font-size: 18px;
            margin: 8px 0;
            color: #fafafa;
            opacity: 0.8;
        }
        /* Chat Container */
        .chat-container {
            background-color: #2a2a2a;
            border-radius: 8px;
            padding: 20px;
            margin-bottom: 20px;
        }
        /* Control Panel */
        .control-panel {
            background-color: #333;
            padding: 16px;
            border-radius: 8px;
            margin-top: 16px;
        }
        /* Gradio Component Overrides */
        .gr-button {
            background-color: #4a4a4a;
            color: #fff;
            border: none;
            border-radius: 4px;
            padding: 8px 16px;
            transition: background-color 0.3s;
        }
        .gr-button:hover {
            background-color: #5a5a5a;
        }
        .gr-input, .gr-dropdown {
            background-color: #3a3a3a;
            color: #fff;
            border: 1px solid #4a4a4a;
            border-radius: 4px;
            padding: 8px;
        }
        """

        # Header markup only; its styling comes from the Blocks-level CSS.
        header_html = """
        <div class="app-header">
            <h1 class="app-title">Document-Based Question Answering</h1>
            <h2 class="app-subtitle">Powered by Llama and RAG</h2>
        </div>
        """

        with gr.Blocks(css=custom_css) as wrapper:
            gr.HTML(header_html)
            gr.ChatInterface(
                fn=self.respond,
                additional_inputs=[
                    gr.Textbox(
                        value="You are a friendly chatbot.",
                        label="System Message",
                        elem_classes="control-panel",
                    ),
                    gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        elem_classes="control-panel",
                    ),
                ],
                title="",  # Title is handled by the custom header above.
                description="Ask questions about Computers and get AI-powered answers.",
                theme=gr.themes.Default(),
            )
        return wrapper
def main():
    """Entry point: build the RAG app and serve it with debug logging."""
    app = RAGInterface().create_interface()
    app.launch(debug=True)


if __name__ == "__main__":
    main()