---
library_name: transformers
tags: []
---
### Inference
```python
# Load model directly
from transformers import AutoModelForCausalLM, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("suriya7/conversational-gpt-1")
model = AutoModelForCausalLM.from_pretrained("suriya7/conversational-gpt-1")
```
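A minimal single-turn sanity check (a sketch: the `<|im_start|>`/`<|im_end|>` prompt format, the sampling settings, and the `50259` end-of-turn token id are taken from the chat loop below; treat them as assumptions if your checkpoint differs):

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Single-turn prompt in the same ChatML-style format used in the chat loop below
prompt = (
    "<|im_start|>system\nYou are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>\n"
    "<|im_start|>user\nWhat can you help me with?<|im_end|>"
)

inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
outputs = model.generate(
    inputs,
    max_new_tokens=256,
    do_sample=True,
    top_k=50,
    top_p=0.90,
    temperature=0.7,
    pad_token_id=50259,
    eos_token_id=50259,
)

# Decode only the newly generated tokens
print(tokenizer.decode(outputs[0, inputs.shape[1]:], skip_special_tokens=True))
```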
### Chatting
```python
import torch

# System prompt in the model's ChatML-style format
prompt = """
<|im_start|>system\nYou are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>
"""

# Move the model to the appropriate device once, before the chat loop
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Rolling window of recent conversation turns
conversation_history = []

while True:
    user_prompt = input("User Question: ")
    if user_prompt.lower() == 'break':
        break

    # Format the user's input as a user turn
    user = f"""<|im_start|>user
{user_prompt}<|im_end|>"""

    # Add the user's question to the conversation history
    conversation_history.append(user)

    # Keep the last two full exchanges plus the current question (5 turns),
    # so the window always starts with a user turn
    conversation_history = conversation_history[-5:]

    # Build the full prompt from the system prompt and the history
    current_prompt = prompt + "\n".join(conversation_history)

    # Tokenize the prompt and move the inputs to the model's device
    encodeds = tokenizer(current_prompt, return_tensors="pt", truncation=True).input_ids
    inputs = encodeds.to(device)

    # generated_ids holds the prompt plus every token generated so far
    generated_ids = inputs
    assistant_response = ""
    # print("Assistant: ", end="", flush=True)  # Print "Assistant:" once before streaming starts

    # Generate tokens one by one, up to a maximum of 512
    for _ in range(512):
        next_token = model.generate(
            generated_ids,
            max_new_tokens=1,
            pad_token_id=50259,
            eos_token_id=50259,
            num_return_sequences=1,
            do_sample=True,   # Use sampling for more diverse responses
            top_k=50,         # Limit sampling to the top-k tokens
            top_p=0.90,       # Nucleus sampling threshold
            temperature=0.7,  # Adjust temperature for randomness
        )

        # Append the newly generated token to the running sequence
        generated_ids = torch.cat([generated_ids, next_token[:, -1:]], dim=1)

        # Decode the new token and stream it in real time
        token_id = next_token[0, -1].item()
        token = tokenizer.decode([token_id], skip_special_tokens=True)
        assistant_response += token
        print(token, end="", flush=True)

        # Stop when the end-of-turn token is generated
        if token_id == 50259:
            break

    print()  # Newline after streaming is complete

    # Add the assistant's response to the conversation history
    conversation_history.append(f"<|im_start|>{assistant_response.strip()}<|im_end|>")
```
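As an alternative to calling `generate` once per token, transformers' built-in `TextStreamer` streams a whole response from a single `generate` call. A minimal sketch under the same assumptions as above (token id `50259` as pad/end-of-turn, and `current_prompt`, `device` built as in the loop):

```python
from transformers import TextStreamer

# Streams decoded text to stdout as it is generated; skip the prompt and special tokens
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

inputs = tokenizer(current_prompt, return_tensors="pt", truncation=True).input_ids.to(device)
output_ids = model.generate(
    inputs,
    max_new_tokens=512,
    do_sample=True,
    top_k=50,
    top_p=0.90,
    temperature=0.7,
    pad_token_id=50259,
    eos_token_id=50259,
    streamer=streamer,
)

# Recover the assistant's reply for the conversation history
assistant_response = tokenizer.decode(output_ids[0, inputs.shape[1]:], skip_special_tokens=True)
```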