File size: 4,015 Bytes
fc0c635
 
 
 
5d07573
fc0c635
 
 
 
 
 
 
 
9a8affc
 
 
 
 
 
5d07573
9a8affc
 
fc0c635
9a8affc
 
 
 
 
 
 
5d07573
9a8affc
fc0c635
 
9a8affc
 
 
 
 
 
5d07573
9a8affc
fc0c635
 
9a8affc
 
 
 
fc0c635
9a8affc
 
fc0c635
9a8affc
 
 
fc0c635
9a8affc
 
fc0c635
9a8affc
 
5d07573
9a8affc
fc0c635
 
9a8affc
 
 
 
 
 
 
 
5d07573
9a8affc
fc0c635
 
9a8affc
 
 
 
5d07573
9a8affc
fc0c635
 
9a8affc
 
 
 
 
 
 
5d07573
9a8affc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datetime import datetime
import os
from mew_log import log_info, log_error  # Import your custom logging methods

class HFModel:
    """Thin wrapper around a Hugging Face causal LM with a simple chat log.

    Loads a model and tokenizer via ``from_pretrained``, exposes one-shot and
    streaming generation, and keeps an in-memory ``chat_history`` that
    :meth:`chat` appends to and mirrors into a timestamped Markdown file.
    """

    def __init__(self, model_name, *, device="cuda", torch_dtype=None):
        """Load the model and tokenizer for ``model_name``.

        Args:
            model_name: Hub id (``"org/name"``) or local path accepted by
                ``from_pretrained``.
            device: Device the model is moved to. Defaults to ``"cuda"``,
                matching the original ``.cuda()`` call.
            torch_dtype: Weight dtype passed to ``from_pretrained``; ``None``
                means ``torch.bfloat16`` (the original hard-coded value).

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        self.friendly_name = self._derive_friendly_name(model_name)
        self.chat_history = []
        # One Markdown log per instance, named by creation time (second resolution).
        self.log_file = f"chat_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
        if torch_dtype is None:
            torch_dtype = torch.bfloat16

        try:
            log_info(f"=== Loading Model: {self.friendly_name} ===")
            # NOTE: trust_remote_code=True executes Python code shipped with the
            # checkpoint; only load models from sources you trust.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, trust_remote_code=True, torch_dtype=torch_dtype
            ).to(device)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            log_info(f"=== Model Loaded Successfully: {self.friendly_name} ===")
        except Exception as e:
            log_error(f"ERROR Loading Model: {e}", e)
            raise

    @staticmethod
    def _derive_friendly_name(model_name):
        """Return the last ``/``-separated component of ``model_name``.

        ``"org/name"`` -> ``"name"`` as before. Bare names such as ``"gpt2"``
        and local paths (with or without a trailing slash) are also handled.
        Indexing ``split("/")[1]`` raised IndexError on bare names and picked
        the wrong component on deeper paths.
        """
        return model_name.rstrip("/").split("/")[-1]

    def generate_response(self, input_text, max_length=100, skip_special_tokens=True, include_prompt=True):
        """Generate a completion for ``input_text`` in a single call.

        Args:
            input_text: Raw prompt string (no chat template is applied).
            max_length: Forwarded to ``generate``. Per the transformers docs it
                counts prompt tokens plus new tokens, so a long prompt leaves
                little or no room for output.
            skip_special_tokens: Forwarded to ``tokenizer.decode``.
            include_prompt: If True (the original behaviour), decode the whole
                returned sequence, which for decoder-only models begins with
                the prompt. If False, decode only the tokens after the prompt.

        Returns:
            The decoded text with surrounding whitespace stripped.

        Raises:
            Exception: Any tokenization/generation error is logged and re-raised.
        """
        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(**inputs, max_length=max_length)
            sequence = outputs[0]
            if not include_prompt:
                # generate() returns prompt + continuation; drop the echoed prompt.
                sequence = sequence[inputs["input_ids"].shape[-1]:]
            response = self.tokenizer.decode(sequence, skip_special_tokens=skip_special_tokens).strip()
            log_info(f"Generated Response: {response}")
            return response
        except Exception as e:
            log_error(f"ERROR Generating Response: {e}", e)
            raise

    def stream_response(self, input_text, max_length=100, skip_special_tokens=True, include_prompt=True):
        """Yield generated text incrementally while the model is generating.

        The previous implementation passed ``do_stream=True`` to ``generate``,
        which is not a standard generation option. It then iterated over the
        returned tensor, which yields one full sequence per batch row rather
        than streamed tokens. This version runs ``generate`` on a worker
        thread and reads decoded text from a ``TextIteratorStreamer``.

        Args:
            input_text: Raw prompt string.
            max_length: Forwarded to ``generate`` (prompt + new tokens).
            skip_special_tokens: Forwarded to the streamer's decode step.
            include_prompt: If False, the streamer skips the prompt tokens.

        Yields:
            Non-empty decoded text chunks, in order. Chunks are not stripped,
            because that would remove the spaces between words.

        Raises:
            Exception: Errors from tokenization or from the generation thread
                are logged and re-raised to the consumer.
        """
        from threading import Thread

        from transformers import TextIteratorStreamer

        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=not include_prompt,
                skip_special_tokens=skip_special_tokens,
            )
            errors = []

            def _run_generation():
                # Worker-thread body: capture failures so they reach the caller.
                try:
                    self.model.generate(**inputs, max_length=max_length, streamer=streamer)
                except Exception as exc:
                    errors.append(exc)
                    # Push the stop signal; otherwise the consumer loop blocks forever.
                    streamer.end()

            worker = Thread(target=_run_generation, daemon=True)
            worker.start()
            for chunk in streamer:
                if chunk:
                    yield chunk
            worker.join()
            if errors:
                raise errors[0]
        except Exception as e:
            log_error(f"ERROR Streaming Response: {e}", e)
            raise

    def chat(self, user_input, max_length=100, skip_special_tokens=True):
        """Run one chat turn, record it, and append it to the Markdown log.

        Only ``user_input`` is sent to the model; earlier turns stored in
        ``chat_history`` are not included in the prompt.

        Args:
            user_input: The user's message, used verbatim as the prompt.
            max_length: Forwarded to :meth:`generate_response`.
            skip_special_tokens: Forwarded to :meth:`generate_response`.

        Returns:
            The model's reply, without the echoed prompt.

        Raises:
            Exception: Generation or logging errors are logged and re-raised;
                in that case ``chat_history`` is left unchanged.
        """
        try:
            log_info(f"User Input: {user_input}")

            # include_prompt=False: otherwise the stored "assistant" reply would
            # start with a copy of the user's message.
            model_response = self.generate_response(
                user_input,
                max_length=max_length,
                skip_special_tokens=skip_special_tokens,
                include_prompt=False,
            )
            log_info(f"Assistant Response: {model_response}")

            # Record both turns only after generation succeeded. This keeps a
            # failed turn from leaving an unanswered user entry that would shift
            # the last-two-entries window written by save_chat_log.
            self.chat_history.append({"role": "user", "content": user_input})
            self.chat_history.append({"role": "assistant", "content": model_response})

            self.save_chat_log()

            return model_response
        except Exception as e:
            log_error(f"ERROR in Chat: {e}", e)
            raise

    def save_chat_log(self):
        """Append the two most recent history entries to ``self.log_file``.

        Each entry is written as a bold capitalized role heading, the content,
        and a ``---`` separator. Errors are logged and re-raised.
        """
        try:
            with open(self.log_file, "a", encoding="utf-8") as f:
                for entry in self.chat_history[-2:]:  # Save only the latest interaction
                    role = entry["role"]
                    content = entry["content"]
                    f.write(f"**{role.capitalize()}:**\n\n{content}\n\n---\n\n")
            log_info(f"Chat log saved to {self.log_file}")
        except Exception as e:
            log_error(f"ERROR Saving Chat Log: {e}", e)
            raise

    def clear_chat_history(self):
        """Reset the in-memory chat history (the log file is not touched)."""
        try:
            self.chat_history = []
            log_info("Chat history cleared.")
        except Exception as e:
            log_error(f"ERROR Clearing Chat History: {e}", e)
            raise

    def print_chat_history(self):
        """Print every history entry to stdout as ``Role: content``."""
        try:
            for entry in self.chat_history:
                role = entry["role"]
                content = entry["content"]
                print(f"{role.capitalize()}: {content}\n")
            log_info("Printed chat history to console.")
        except Exception as e:
            log_error(f"ERROR Printing Chat History: {e}", e)
            raise