import logging

import torch
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available

from google import Document, GoogleSearch, SemanticSearch

logger = logging.getLogger(__name__)


def load_configs(config_file: str) -> dict:
    """Load and return the YAML configuration stored at *config_file*.

    ``yaml.safe_load`` is used, so the file is parsed as plain data only.
    """
    with open(config_file, "r", encoding="utf-8") as f:
        configs = yaml.safe_load(f)
    return configs


class RAGModel:
    """Answer user queries with a causal LM grounded on retrieved context items."""

    def __init__(self, configs: dict) -> None:
        """Load the generation model and its tokenizer as described by *configs*.

        Reads ``configs["model"]["device"]`` and the model id/path from
        ``configs["model"]["generation_model"]``. If that key is absent, the
        legacy (misspelled) key ``"genration_model"`` is used, so existing
        config files keep working; a KeyError is raised if neither exists.
        """
        self.configs = configs
        model_configs = configs["model"]
        self.device = model_configs["device"]
        if "generation_model" in model_configs:
            model_url = model_configs["generation_model"]
        else:
            # Legacy spelling kept for backward compatibility with existing YAML.
            model_url = model_configs["genration_model"]

        # quantization_config = BitsAndBytesConfig(
        #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
        # )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_url,
            torch_dtype=torch.float16,
            # quantization_config=quantization_config,
            low_cpu_mem_usage=False,
            attn_implementation="sdpa",
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_url)

    def create_prompt(self, query: str, topk_items: list[str]) -> str:
        """Build the chat-formatted prompt string for *query*.

        Args:
            query: The user's question.
            topk_items: Retrieved context passages; each one is rendered as a
                ``- `` bullet line in the prompt.

        Returns:
            The prompt as text (not token ids), formatted with the tokenizer's
            chat template and ending with the assistant generation prompt.
        """
        context = "\n".join(f"- {item}" for item in topk_items)
        base_prompt = (
            "Based on the following context items, please answer the query.\n"
            "Give time for yourself to read the context and then answer the query.\n"
            "Do not return thinking process, just return the answer.\n"
            "If you do not find the answer, or if the query is offensive or in any "
            "other way harmful just return \"I'm not aware of it\"\n"
            "Now use the following context items to answer the user query.\n"
            f"{context}\n"
            "\n"
            f"user query : {query}\n"
        )
        dialog_template = [{"role": "user", "content": base_prompt}]
        # add_generation_prompt appends the assistant-turn header so the model
        # answers instead of continuing the user's message.
        prompt = self.tokenizer.apply_chat_template(
            conversation=dialog_template, tokenize=False, add_generation_prompt=True
        )
        return prompt

    def answer_query(
        self, query: str, topk_items: list[str], *, max_new_tokens: int = 512
    ) -> str:
        """Generate an answer to *query* grounded on *topk_items*.

        Args:
            query: The user's question.
            topk_items: Retrieved context passages passed to ``create_prompt``.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            Only the newly generated text, with special tokens removed (the
            echoed prompt is not included).
        """
        prompt = self.create_prompt(query, topk_items)
        logger.debug("RAG prompt:\n%s", prompt)
        # The chat template already inserted the model's special tokens (e.g.
        # BOS); don't let the tokenizer add them a second time.
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.device)
        output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # For decoder-only models generate() returns prompt + continuation;
        # slice off the prompt so only the answer is decoded.
        prompt_length = inputs["input_ids"].shape[1]
        text = self.tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
        return text


if __name__ == "__main__":
    configs = load_configs(config_file="rag.configs.yml")
    query = "what is LLM"
    # Retrieval pipeline (currently disabled):
    # g = GoogleSearch(query)
    # data = g.all_page_data
    # d = Document(data, 512)
    # s = SemanticSearch("all-mpnet-base-v2", "mps")
    # topk = s.semantic_search(query=query, k=32)
    r = RAGModel(configs)
    output = r.answer_query(query=query, topk_items=[""])
    print(output)