import logging
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Checkpoint paths
model_checkpoint_path = "model_checkpoint.pth"
tokenizer_checkpoint_path = "tokenizer_checkpoint.pth"
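
# Note: google/gemma-2-9b is a gated model on the Hugging Face Hub, so the
# first download requires accepting the license and authenticating (e.g. with
# `huggingface-cli login`). Loading the 9B model in full float32 precision
# needs on the order of 36 GB of memory; passing torch_dtype=torch.bfloat16
# to from_pretrained roughly halves that.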
# Load model and tokenizer from checkpoint if they exist
if os.path.exists(model_checkpoint_path) and os.path.exists(tokenizer_checkpoint_path):
    try:
        # The checkpoints hold fully pickled objects, so weights_only=False is
        # required on PyTorch >= 2.6, where torch.load defaults to weights_only=True
        model = torch.load(model_checkpoint_path, weights_only=False)
        tokenizer = torch.load(tokenizer_checkpoint_path, weights_only=False)
        logger.info("Model and tokenizer loaded from checkpoint.")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer from checkpoint: {e}")
        raise
else:
    # Otherwise download from the Hugging Face Hub
    try:
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
        model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
        logger.info("Model and tokenizer loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {e}")
        raise
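
# Nothing in this script writes the checkpoint files the branch above looks
# for. A minimal sketch of how they could be created after the first download
# (save_checkpoints is a hypothetical helper, not part of the original script):
def save_checkpoints(model, tokenizer):
    # torch.save pickles the whole objects, matching the torch.load calls above
    torch.save(model, model_checkpoint_path)
    torch.save(tokenizer, tokenizer_checkpoint_path)
    logger.info("Checkpoints saved to %s and %s",
                model_checkpoint_path, tokenizer_checkpoint_path)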
def respond(user_input, history, system_message, max_tokens=20, temperature=0.9, top_p=0.9):
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": user_input})
    # Flatten the conversation into a single prompt string
    input_text = " ".join([msg["content"] for msg in messages])
    # Tokenize the input text
    inputs = tokenizer(input_text, return_tensors="pt")
    # Generate text using the model; max_new_tokens bounds only the generated
    # tokens (max_length would count the prompt as well, and a 20-token budget
    # could be exhausted by the prompt alone)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
    )
    return response
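
# Joining message contents with spaces discards the role structure the chat
# format encodes. With an instruction-tuned checkpoint such as
# google/gemma-2-9b-it, prompting through the tokenizer's chat template is
# usually more reliable. A minimal sketch under that assumption (build_prompt
# is a hypothetical helper; gemma's template rejects a separate "system" role
# and expects the first turn to come from the user, so the system message is
# folded into the user turn):
def build_prompt(user_input, system_message):
    messages = [{"role": "user", "content": f"{system_message}\n{user_input}"}]
    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
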
if __name__ == "__main__":
    print("Welcome to the Chatbot! Type 'quit' to exit.")
    system_message = "You are a helpful chatbot."
    # Keep the conversation outside the loop so context accumulates across turns
    history = [{"role": "assistant", "content": "Hello, how can I assist you today?"}]
    while True:
        user_input = input("You: ")
        if user_input.strip().lower() == "quit":
            break
        response = respond(user_input, history, system_message)
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": response})
        print(f"Chatbot: {response}")