from datetime import datetime
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from mew_log import log_info, log_error  # Import your custom logging methods


class HFModel:
    def __init__(self, model_name):
        # Use the last path segment so plain names without an "org/" prefix also work.
        self.friendly_name = model_name.split("/")[-1]
        self.chat_history = []
        self.log_file = f"chat_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
        try:
            log_info(f"=== Loading Model: {self.friendly_name} ===")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            )  # .cuda()
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            log_info(f"=== Model Loaded Successfully: {self.friendly_name} ===")
        except Exception as e:
            log_error(f"ERROR Loading Model: {e}", e)
            raise

    def generate_response(self, input_text, max_length=100, skip_special_tokens=True):
        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,                 # Cap newly generated tokens (prompt not counted)
                pad_token_id=self.tokenizer.eos_token_id,  # Ensure proper padding
                do_sample=True,                            # Enable sampling for more diverse outputs
                top_k=50,                                  # Limit sampling to top-k tokens
                top_p=0.95,                                # Use nucleus sampling
                temperature=0.7,                           # Control randomness
            )
            # Decode only the newly generated tokens so the prompt is not echoed back.
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=skip_special_tokens).strip()
            log_info(f"Generated Response: {response}")
            return response
        except Exception as e:
            log_error(f"ERROR Generating Response: {e}", e)
            raise

    def stream_response(self, input_text, max_length=100, skip_special_tokens=True):
        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            # `generate` has no `do_stream` argument; stream via TextIteratorStreamer,
            # running generation in a background thread and yielding text as it arrives.
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=skip_special_tokens
            )
            generation_kwargs = dict(
                **inputs,
                max_new_tokens=max_length,
                pad_token_id=self.tokenizer.eos_token_id,
                streamer=streamer,
            )
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            for chunk in streamer:
                yield chunk
            thread.join()
        except Exception as e:
            log_error(f"ERROR Streaming Response: {e}", e)
            raise

    def chat(self, user_input, max_length=100, skip_special_tokens=True):
        try:
            # Add user input to chat history
            self.chat_history.append({"role": "user", "content": user_input})
            log_info(f"User Input: {user_input}")

            # Generate model response
            model_response = self.generate_response(
                user_input, max_length=max_length, skip_special_tokens=skip_special_tokens
            )

            # Add model response to chat history
            self.chat_history.append({"role": "assistant", "content": model_response})
            log_info(f"Assistant Response: {model_response}")

            # Save chat log
            self.save_chat_log()

            return model_response
        except Exception as e:
            log_error(f"ERROR in Chat: {e}", e)
            raise

    def save_chat_log(self):
        try:
            with open(self.log_file, "a", encoding="utf-8") as f:
                for entry in self.chat_history[-2:]:  # Save only the latest interaction
                    role = entry["role"]
                    content = entry["content"]
                    f.write(f"**{role.capitalize()}:**\n\n{content}\n\n---\n\n")
            log_info(f"Chat log saved to {self.log_file}")
        except Exception as e:
            log_error(f"ERROR Saving Chat Log: {e}", e)
            raise

    def clear_chat_history(self):
        try:
            self.chat_history = []
            log_info("Chat history cleared.")
        except Exception as e:
            log_error(f"ERROR Clearing Chat History: {e}", e)
            raise

    def print_chat_history(self):
        try:
            for entry in self.chat_history:
                role = entry["role"]
                content = entry["content"]
                print(f"{role.capitalize()}: {content}\n")
            log_info("Printed chat history to console.")
        except Exception as e:
            log_error(f"ERROR Printing Chat History: {e}", e)
            raise
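

# --- Usage sketch (not part of the original module; added as an illustration) ---
# Minimal interactive loop driving HFModel.chat(). The checkpoint id below is a
# hypothetical placeholder; substitute any causal-LM repo id you have access to.
if __name__ == "__main__":
    bot = HFModel("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical example checkpoint
    print(f"Chatting with {bot.friendly_name} -- type 'quit' to exit.")
    while True:
        user_input = input("You: ").strip()
        if not user_input or user_input.lower() in {"quit", "exit"}:
            break
        reply = bot.chat(user_input, max_length=200)
        print(f"Assistant: {reply}\n")
    bot.print_chat_history()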