|
import logging |
|
import os |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
# Configure root logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Local checkpoint files. Both must exist for the checkpoint branch to be taken;
# otherwise the pretrained weights are fetched with `from_pretrained`.
model_checkpoint_path = "model_checkpoint.pth"
tokenizer_checkpoint_path = "tokenizer_checkpoint.pth"

if os.path.exists(model_checkpoint_path) and os.path.exists(tokenizer_checkpoint_path):
    try:
        # The checkpoints hold whole pickled objects (the loaded `model` is used
        # directly for `.generate`), not state dicts. Recent PyTorch releases
        # default to `weights_only=True`, which refuses such pickles, so the
        # flag is set explicitly.
        # SECURITY: unpickling can execute arbitrary code -- only load
        # checkpoint files that come from a trusted source.
        model = torch.load(model_checkpoint_path, weights_only=False)
        tokenizer = torch.load(tokenizer_checkpoint_path, weights_only=False)
    except Exception as exc:
        # `logger.exception` records the traceback alongside the message.
        logger.exception("Failed to load model or tokenizer from checkpoint: %s", exc)
        raise
    else:
        logger.info("Model and tokenizer loaded from checkpoint.")
else:
    try:
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
        model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
    except Exception as exc:
        logger.exception("Failed to load model or tokenizer: %s", exc)
        raise
    else:
        logger.info("Model and tokenizer loaded successfully.")
|
|
|
def respond(user_input, history, system_message, max_tokens=20, temperature=0.9, top_p=0.9):
    """Generate a reply to ``user_input`` with the module-level model.

    The prompt is built by joining, with single spaces, the ``content`` of the
    system message, of every entry in ``history`` and of the user input, in
    that order. Roles are not encoded into the prompt text.

    Args:
        user_input: Text of the latest user message.
        history: Iterable of message dicts; each must have a ``"content"`` key.
            It is read but not modified.
        system_message: Text placed at the start of the prompt.
        max_tokens: Maximum number of tokens to generate for the reply
            (the prompt length does not count against this budget).
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        The decoded text of the newly generated tokens only, with special
        tokens removed.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": user_input})

    prompt = " ".join(message["content"] for message in messages)

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    # Remember where the prompt ends so the echo can be stripped from the output.
    prompt_length = input_ids.shape[-1]

    output_ids = model.generate(
        input_ids,
        attention_mask=encoded["attention_mask"],
        # `max_length` would count the prompt tokens too, leaving little or no
        # room for the reply; `max_new_tokens` budgets only the generated part.
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    # The returned sequence starts with the prompt itself; decode only what
    # follows it.
    reply_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
|
|
|
def main():
    """Run the interactive console chat loop.

    Each turn is answered independently: the same fixed greeting is passed as
    history on every call, and earlier turns are not carried over.

    The loop ends cleanly on end-of-input (Ctrl-D), on Ctrl-C at the prompt,
    or when the user types ``exit`` or ``quit``. Blank input is ignored.
    """
    print("Welcome to the Chatbot!")

    # Loop-invariant values, built once instead of on every turn.
    system_message = "Chatbot: "
    history = [{"role": "assistant", "content": "Hello, how can I assist you today?"}]

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line and leave without a traceback.
            print()
            break

        command = user_input.strip().lower()
        if not command:
            continue
        if command in {"exit", "quit"}:
            break

        # Pass a copy so the seed history can never be altered by the callee.
        response = respond(user_input, list(history), system_message)
        print(response)


if __name__ == "__main__":
    main()