---
library_name: transformers
tags: []
---

### Inference

```python
# Load model directly
from transformers import AutoModelForCausalLM, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("suriya7/conversational-gpt-1")
model = AutoModelForCausalLM.from_pretrained("suriya7/conversational-gpt-1")
```
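The block above only loads the weights and tokenizer. For a quick single-turn check before the interactive loop below, a minimal sketch along these lines should work; the chat format, sampling settings, and special-token id mirror the Chatting example, while the question and the `max_new_tokens` budget are placeholders:

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# One system turn plus one user turn in the model's <|im_start|>/<|im_end|> format
single_prompt = (
    "<|im_start|>system\nYou are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>\n"
    "<|im_start|>user\nWhat is SQL injection?<|im_end|>"  # placeholder question
)

inputs = tokenizer(single_prompt, return_tensors="pt").input_ids.to(device)

output = model.generate(
    inputs,
    max_new_tokens=256,  # placeholder budget for the reply
    pad_token_id=50259,
    eos_token_id=50259,
    do_sample=True,
    top_k=50,
    top_p=0.90,
    temperature=0.7,
)

# Print only the newly generated part of the sequence
print(tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True))
```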
### Chatting
```python
import torch

# System prompt that starts every conversation
prompt = """
<|im_start|>system\nYou are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>
"""

# Keep a running list of the most recent conversation turns
conversation_history = []

# Move the model to the appropriate device once, before the chat loop
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

while True:
    user_prompt = input("User Question: ")
    if user_prompt.lower() == 'break':
        break

    # Format the user's input
    user = f"""<|im_start|>user
{user_prompt}<|im_end|>"""

    # Add the user's question to the conversation history
    conversation_history.append(user)

    # Keep only the last two exchanges plus the current question (5 turns),
    # so the prompt stays short and still starts with a user turn
    conversation_history = conversation_history[-5:]

    # Build the full prompt
    current_prompt = prompt + "\n".join(conversation_history)

    # Tokenize the prompt and move it to the same device as the model
    encodeds = tokenizer(current_prompt, return_tensors="pt", truncation=True).input_ids
    inputs = encodeds.to(device)

    # Start from the prompt tokens; newly generated tokens are appended to this tensor
    generated_ids = inputs

    # Start generating tokens one by one
    assistant_response = ""
    # print("Assistant: ", end="", flush=True)  # Print "Assistant:" once before streaming starts
    for _ in range(512):  # Max number of new tokens to stream
        # Generate the next token in the sequence
        next_token = model.generate(
            generated_ids,
            max_new_tokens=1,
            pad_token_id=50259,
            eos_token_id=50259,
            num_return_sequences=1,
            do_sample=True,   # Use sampling for more diverse responses
            top_k=50,         # Limit sampling to the top-k tokens
            top_p=0.90,       # Nucleus sampling threshold
            temperature=0.7   # Adjust temperature for randomness
        )

        # Append the generated token to the running sequence
        generated_ids = torch.cat([generated_ids, next_token[:, -1:]], dim=1)

        # Decode the new token
        token_id = next_token[0, -1].item()  # Extract the last token as an integer
        token = tokenizer.decode([token_id], skip_special_tokens=True)

        # Append the token to the ongoing response and stream it in real time
        assistant_response += token
        print(token, end="", flush=True)

        # Stop once the EOS token is generated
        if token_id == 50259:
            break

    print()  # Newline after streaming is complete

    # Add the assistant's response to the conversation history
    conversation_history.append(f"<|im_start|>{assistant_response.strip()}<|im_end|>")
```
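Note that the loop above calls `model.generate` once per token, so the whole prompt is re-encoded on every step. If efficiency matters, a single `generate` call with a `TextStreamer` gives the same streamed output; this is only a sketch of an alternative to the inner `for` loop (it is not part of the original recipe and reuses `inputs` from the surrounding `while` loop):

```python
from transformers import TextStreamer

# Streams decoded text to stdout as tokens are produced, skipping the prompt
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

output_ids = model.generate(
    inputs,
    max_new_tokens=512,
    pad_token_id=50259,
    eos_token_id=50259,
    do_sample=True,
    top_k=50,
    top_p=0.90,
    temperature=0.7,
    streamer=streamer,
)

# Recover the reply text for the conversation history
assistant_response = tokenizer.decode(output_ids[0][inputs.shape[1]:], skip_special_tokens=True)
```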