import spaces
import subprocess
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)
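# Explanatory note (not in the original snippet): FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
# tells the flash-attn installer not to compile its CUDA kernels in the Space's build
# environment, which is the usual way to get this install to finish on Hugging Face Spaces.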
import os
import re
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import List, Tuple

import torch
import gradio as gr
from dotenv import load_dotenv
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
from langchain_community.vectorstores import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface.llms import HuggingFacePipeline
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from qdrant_client import QdrantClient, models
from langchain_openai import ChatOpenAI
from langchain_cerebras import ChatCerebras
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class Message:
    role: str
    content: str
    timestamp: str


class ChatHistory:
    def __init__(self):
        self.messages: List[Message] = []

    def add_message(self, role: str, content: str):
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.messages.append(Message(role=role, content=content, timestamp=timestamp))

    def get_formatted_history(self, max_messages: int = 10) -> str:
        """Returns the most recent conversation history formatted as a string."""
        recent_messages = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
        formatted_history = "\n".join([
            f"{msg.role}: {msg.content}" for msg in recent_messages
        ])
        return formatted_history

    def clear(self):
        self.messages = []
# Load environment variables
load_dotenv()

# HuggingFace API Token
HF_TOKEN = os.getenv("HF_TOKEN")
C_apikey = os.getenv("C_apikey")
OPENAPI_KEY = os.getenv("OPENAPI_KEY")

if not HF_TOKEN:
    logger.error("HF_TOKEN is not set in the environment variables.")
    exit(1)
# HuggingFace Embeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Qdrant Client Setup
try:
    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
        prefer_grpc=False
    )
except Exception as e:
    logger.error(f"Failed to connect to Qdrant ({e}). Ensure QDRANT_URL and QDRANT_API_KEY are correctly set.")
    exit(1)
# Define collection name
collection_name = "mawared"

# Try to create collection
try:
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=384,  # all-MiniLM-L6-v2 embedding size
            distance=models.Distance.COSINE
        )
    )
    logger.info(f"Created new collection: {collection_name}")
except Exception as e:
    if "already exists" in str(e):
        logger.info(f"Collection {collection_name} already exists, continuing...")
    else:
        logger.error(f"Error creating collection: {e}")
        exit(1)
# Create Qdrant vector store
db = Qdrant(
    client=client,
    collection_name=collection_name,
    embeddings=embeddings,
)

# Create retriever
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 5}
)
# retriever = db.as_retriever(
#     search_type="mmr",
#     search_kwargs={"k": 5, "fetch_k": 10, "lambda_mult": 0.5}
# )

# retriever = db.as_retriever(
#     search_type="similarity_score_threshold",
#     search_kwargs={"k": 5, "score_threshold": 0.8}
# )
# Set up the LLM (earlier, commented-out alternatives kept for reference)

# llm = ChatOpenAI(
#     base_url="https://api-inference.huggingface.co/v1/",
#     temperature=0,
#     api_key=HF_TOKEN,
#     model="mistralai/Mistral-Nemo-Instruct-2407",
#     max_tokens=None,
#     timeout=None
# )

# llm = ChatOpenAI(
#     base_url="https://openrouter.ai/api/v1",
#     temperature=0.01,
#     api_key=OPENAPI_KEY,
#     model="google/gemini-2.0-flash-exp:free",
#     max_tokens=None,
#     timeout=None,
#     max_retries=3,
# )

# llm = ChatCerebras(
#     model="llama-3.3-70b",
#     api_key=C_apikey,
#     stream=True
# )
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
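# Note: this quantization config is defined but currently unused -- the
# quantization_config argument is commented out in from_pretrained below,
# so the model is loaded in fp16 rather than 4-bit NF4.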
model_id = "meta-llama/Llama-3.2-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
    # quantization_config=quantization_config
)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=8192)
llm = HuggingFacePipeline(pipeline=pipe)
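# Note (assumption, not from the original code): the transformers text-generation
# pipeline returns the prompt plus the completion by default (return_full_text=True);
# if the echoed prompt shows up in answers, passing return_full_text=False to
# pipeline(...) is the usual fix.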
# Create prompt template with chat history
template = """
You are an expert assistant specializing in the Mawared HR System. Your role is to provide precise and contextually relevant answers based on the retrieved context and chat history.

Key Responsibilities:
- Use the given chat history and retrieved context to craft accurate and detailed responses.
- If necessary, ask specific and targeted clarifying questions to gather more information.
- Present step-by-step instructions in a clear, numbered format when applicable.

Rules for Responses:
- Strictly use the information from the provided context and chat history. Avoid making up or fabricating any details.
- Do not reference the retrieval process, sources, pages, or documents in your responses.
- Maintain a conversational flow by asking relevant follow-up questions to engage the user and enhance the interaction.

Your answers must be expressive, detailed, and fully address the user's needs without deviating from the provided information.

Inputs for Your Response:
Previous Conversation: {chat_history}
Retrieved Context: {context}
Current Question: {question}

Answer:
"""

prompt = ChatPromptTemplate.from_template(template)
# Create the RAG chain with chat history
def create_rag_chain(chat_history: str):
    chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
            "chat_history": lambda x: chat_history
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain
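# Illustrative usage only (the question below is a made-up example):
#   chain = create_rag_chain("user: hello")
#   answer = chain.invoke("How do I submit a leave request in Mawared?")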
# Initialize chat history
chat_history = ChatHistory()
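# Note (assumption): `spaces` is imported above but no function is decorated with
# @spaces.GPU; on ZeroGPU hardware that decorator is normally required on the
# function that runs the model, so it may be worth adding to ask_question_gradio.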
# Gradio Function
def ask_question_gradio(question, history):
    try:
        # Add user question to chat history
        chat_history.add_message("user", question)

        # Get formatted history
        formatted_history = chat_history.get_formatted_history()

        # Create chain with current chat history
        rag_chain = create_rag_chain(formatted_history)

        # Generate response
        response = ""
        for chunk in rag_chain.stream(question):
            response += chunk

        # Add assistant response to chat history
        chat_history.add_message("assistant", response)

        # Update Gradio chat history
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": response})

        return "", history
    except Exception as e:
        logger.error(f"Error during question processing: {e}")
        return "", history + [{"role": "assistant", "content": "An error occurred. Please try again later."}]
def clear_chat():
    chat_history.clear()
    return [], ""
# Gradio Interface
with gr.Blocks(theme='lone17/kotaemon') as iface:
    gr.Image("Image.jpg", width=1200, height=300, show_label=False, show_download_button=False)
    gr.Markdown("# Mawared HR Assistant 2.5.1")
    gr.Markdown('### Instructions')
    gr.Markdown("Ask a question about MawaredHR and get a detailed answer. If you get an error, try again with the same prompt; it's an API issue and we are working on it.")

    chatbot = gr.Chatbot(
        height=750,
        show_label=False,
        type="messages"  # Using the new messages format
    )

    with gr.Row():
        question_input = gr.Textbox(
            label="Ask a question:",
            placeholder="Type your question here...",
            scale=30
        )
        clear_button = gr.Button("Clear Chat", scale=1)

    question_input.submit(
        ask_question_gradio,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )

    clear_button.click(
        clear_chat,
        outputs=[chatbot, question_input]
    )
# Launch the Gradio App
if __name__ == "__main__":
    iface.launch()