ETHEREUM / app.py
ndwdgda's picture
Update app.py
08ab954 verified
import logging
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Logging configuration shared by the rest of the script
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Local checkpoint files; they are used only when BOTH exist
model_checkpoint_path = "model_checkpoint.pth"
tokenizer_checkpoint_path = "tokenizer_checkpoint.pth"

_PRETRAINED_ID = "google/gemma-2-9b"
_checkpoints_available = all(
    os.path.exists(path)
    for path in (model_checkpoint_path, tokenizer_checkpoint_path)
)

if not _checkpoints_available:
    # No complete local checkpoint pair: fall back to from_pretrained
    try:
        tokenizer = AutoTokenizer.from_pretrained(_PRETRAINED_ID)
        model = AutoModelForCausalLM.from_pretrained(_PRETRAINED_ID)
        logger.info("Model and tokenizer loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {e}")
        raise
else:
    # Restore the objects stored in the local checkpoint files.
    # NOTE(review): torch.load unpickles arbitrary Python objects, so these
    # files must come from a trusted source.
    try:
        model = torch.load(model_checkpoint_path)
        tokenizer = torch.load(tokenizer_checkpoint_path)
        logger.info("Model and tokenizer loaded from checkpoint.")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer from checkpoint: {e}")
        raise
def respond(user_input, history, system_message, max_tokens=20, temperature=0.9, top_p=0.9):
    """Generate a reply to ``user_input`` using the module-level model.

    The system message, the prior ``history`` and the new user turn are
    flattened into one space-separated prompt, tokenized with the
    module-level ``tokenizer`` and passed to ``model.generate`` with
    sampling enabled.

    Args:
        user_input: Text of the latest user message.
        history: Earlier messages as dicts with a ``"content"`` key.
            ``None`` or an empty list means no history. Not modified.
        system_message: Text placed at the start of the prompt.
        max_tokens: Maximum number of tokens to generate for the reply
            (the prompt does not count against this budget).
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        The decoded reply text, without the prompt and without special
        tokens.
    """
    messages = [
        {"role": "system", "content": system_message},
        *(history or []),
        {"role": "user", "content": user_input},
    ]

    # Roles are dropped here: only the message texts are joined into the prompt.
    input_text = " ".join(msg["content"] for msg in messages)

    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    prompt_length = input_ids.shape[-1]

    outputs = model.generate(
        input_ids,
        attention_mask=inputs["attention_mask"],
        # max_length would cap prompt + reply together, leaving no room to
        # generate once the prompt reaches the limit; max_new_tokens bounds
        # only the reply.
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    # For a causal LM (as loaded via AutoModelForCausalLM) the returned
    # sequence starts with the prompt tokens; decode only what was generated
    # so the reply does not echo the prompt.
    reply_ids = outputs[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
if __name__ == "__main__":
    # Minimal interactive loop: read a line, print the model's reply.
    print("Welcome to the Chatbot!")

    # Both values are the same for every turn, so build them once.
    # The history is never extended, so each turn is answered independently
    # of earlier turns.
    system_message = "Chatbot: "
    history = [{"role": "assistant", "content": "Hello, how can I assist you today?"}]

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
            print()
            break
        response = respond(user_input, history, system_message)
        print(response)