import os
import json  # needed for json.loads on the ticket names file
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
import faiss
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from datetime import datetime
import gradio as gr
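
# DocumentRetrievalAndGeneration wires the RAG components together:
# documents are embedded with a SentenceTransformer, their nearest neighbours are
# retrieved from a FAISS index moved to GPU, and the retrieved passages are fed to a
# 4-bit quantized instruction-tuned LLM whose answers are served through a Gradio UI.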
class DocumentRetrievalAndGeneration:
    def __init__(self, embedding_model_name, lm_model_id, data_folder, faiss_index_path):
        self.documents = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        self.gpu_index = self.load_faiss_index(faiss_index_path)
        self.llm = self.initialize_llm(lm_model_id)
    def load_documents(self, folder_path):
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        print('Length of documents:', len(documents))
        return documents
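    # Load a FAISS index from disk and move it onto GPU 0 for faster similarity search.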
    def load_faiss_index(self, faiss_index_path):
        cpu_index = faiss.read_index(faiss_index_path)
        gpu_resource = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, cpu_index)
        return gpu_index
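    # Load the causal LM with 4-bit NF4 quantization (bitsandbytes) so it fits on a
    # single GPU, then wrap model and tokenizer in a transformers text-generation pipeline.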
    def initialize_llm(self, model_id):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        generate_text = pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            task='text-generation',
            do_sample=True,  # temperature is ignored unless sampling is enabled
            temperature=0.6,
            max_new_tokens=2048,
        )
        return generate_text
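    # Retrieval + generation: embed the query, fetch the top-k nearest document chunks
    # from the FAISS index, build a prompt from them, and let the LLM generate an answer.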
    def query_and_generate_response(self, query):
        query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
        distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)

        content = ""
        for idx in indices[0]:
            content += "-" * 50 + "\n"
            content += self.documents[idx].page_content + "\n"
            print(self.documents[idx].page_content)
            print("############################")

        prompt = f"Query: {query}\nSolution: {content}\n"

        # Encode and prepare inputs
        messages = [{"role": "user", "content": prompt}]
        encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.llm.device)

        # Perform inference and measure time
        start_time = datetime.now()
        generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
        elapsed_time = datetime.now() - start_time

        # Decode and return output
        decoded = self.llm.tokenizer.batch_decode(generated_ids)
        generated_response = decoded[0]
        print("Generated response:", generated_response)
        print("Time elapsed:", elapsed_time)
        print("Device in use:", self.llm.device)
        return generated_response
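    # Thin wrapper used as the Gradio callback.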
    def qa_infer_gradio(self, query):
        response = self.query_and_generate_response(query)
        return response
if __name__ == "__main__":
    # Example usage
    embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
    lm_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    data_folder = 'sample_embedding_folder'
    faiss_index_path = 'faiss_index_new_model3.index'

    doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder, faiss_index_path)

    # Define Gradio interface function
    def launch_interface():
        css_code = """
            .gradio-container {
                background-color: #daccdb;
            }
            /* Button styling for all buttons */
            button {
                background-color: #927fc7; /* Default color for all other buttons */
                color: black;
                border: 1px solid black;
                padding: 10px;
                margin-right: 10px;
                font-size: 16px;   /* Increase font size */
                font-weight: bold; /* Make text bold */
            }
        """
        EXAMPLES = [
            "TDA4 product planning and datasheet release progress?",
            "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for Cortex-M4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
            "Master core in TDA2XX is a15 and in TDA3XX it is m4, so we have to shift all modules that are being used by a15 in TDA2XX to m4 in TDA3xx."
        ]

        # Read the sample ticket names used to populate the dropdown
        file_path = "ticketNames.txt"
        with open(file_path, "r") as file:
            content = file.read()
        ticket_names = json.loads(content)
        # Note: this dropdown is built but not passed to the interface below
        dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)

        # Define Gradio interface
        interface = gr.Interface(
            fn=doc_retrieval_gen.qa_infer_gradio,
            inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
            outputs=gr.Textbox(label="SOLUTION"),
            examples=EXAMPLES,
            cache_examples=False,
            allow_flagging='never',
            css=css_code
        )

        # Launch Gradio interface
        interface.launch(debug=True)

    # Launch the interface
    launch_interface()