import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Load the base tokenizer and rebuild the classifier, then restore the
# fine-tuned weights from the local checkpoint. from_pretrained() cannot
# read a .pth file directly, so the state dict is loaded explicitly.
# (Assumes Shanks.pth holds a state dict for a 3-label classifier,
# matching the response mapping below.)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
model.load_state_dict(torch.load("Shanks.pth", map_location="cpu"))
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def chat():
    while True:
        user_input = input("You: ")
        if user_input == "quit":
            break
        inputs = tokenizer(user_input, return_tensors='pt')
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predicted_class_id = torch.argmax(logits, dim=-1).item()
        # Map the predicted class ID to a response
        # Replace this with your actual mapping logic
        response = {0: "Response A", 1: "Response B", 2: "Response C"}[predicted_class_id]
        print(f"Bot: {response}")


if __name__ == "__main__":
    chat()