---
library_name: transformers
license: apache-2.0
---

## INFERENCE

```python
# Load the model and tokenizer directly from the Hugging Face Hub
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("AquilaX-AI/QnA")
model = AutoModelForCausalLM.from_pretrained("AquilaX-AI/QnA")

# Move the model to the GPU if one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# System prompt in ChatML format
prompt = """
<|im_start|>system\nYou are a helpful AI assistant named Securitron<|im_end|>
"""

# Rolling window of the most recent conversation turns
conversation_history = []

while True:
    user_prompt = input("\nUser Question: ")
    if user_prompt.lower() == 'break':
        break

    # Format the user's input as a ChatML turn
    user = f"""<|im_start|>user
{user_prompt}<|im_end|>
<|im_start|>assistant"""

    # Add the user's question to the conversation history
    conversation_history.append(user)

    # Keep only the last 5 turns: the two most recent exchanges plus the current question
    conversation_history = conversation_history[-5:]

    # Build the full prompt from the system prompt and recent history
    current_prompt = prompt + "\n".join(conversation_history)

    # Tokenize the prompt and move it to the same device as the model
    inputs = tokenizer(current_prompt, return_tensors="pt", truncation=True).input_ids.to(device)

    # generated_ids holds the prompt plus every token generated so far
    generated_ids = inputs

    # Generate tokens one at a time so the reply can be streamed to the console
    assistant_response = ""
    for _ in range(512):  # maximum number of new tokens per reply
        next_token = model.generate(
            generated_ids,
            max_new_tokens=1,
            pad_token_id=151644,  # <|im_start|>
            eos_token_id=151645,  # <|im_end|>
            num_return_sequences=1,
            do_sample=False,
            # top_k=5,
            # temperature=0.2,
            # top_p=0.90
        )

        # Append the newly generated token to the running sequence
        generated_ids = torch.cat([generated_ids, next_token[:, -1:]], dim=1)
        token_id = next_token[0, -1].item()
        token = tokenizer.decode([token_id], skip_special_tokens=True)

        assistant_response += token
        print(token, end="", flush=True)

        if token_id == 151645:  # stop once <|im_end|> is emitted
            break

    # Add the assistant's reply to the history so the next turn has context
    conversation_history.append(f"{assistant_response.strip()}<|im_end|>")
```
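
Calling `model.generate` one token at a time works, but it re-runs generation for the whole sequence on every step. If you only need streamed console output, `transformers` also provides a `TextStreamer` that prints tokens as they are produced from a single `generate` call. Below is a minimal, single-turn sketch under the same assumptions as the loop above (ChatML prompt layout and the `<|im_start|>`/`<|im_end|>` token IDs); the user question is just a placeholder.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch

tokenizer = AutoTokenizer.from_pretrained("AquilaX-AI/QnA")
model = AutoModelForCausalLM.from_pretrained("AquilaX-AI/QnA")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Same ChatML prompt layout as above, single turn for brevity (placeholder question)
prompt = (
    "<|im_start|>system\nYou are a helpful AI assistant named Securitron<|im_end|>\n"
    "<|im_start|>user\nWhat is a SQL injection?<|im_end|>\n"
    "<|im_start|>assistant"
)

inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# TextStreamer decodes and prints tokens to stdout as they are generated
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

model.generate(
    inputs,
    max_new_tokens=512,
    do_sample=False,
    pad_token_id=151644,  # <|im_start|>
    eos_token_id=151645,  # <|im_end|>
    streamer=streamer,
)
```

To keep multi-turn context with this variant, you can build `prompt` from the same `conversation_history` window used in the loop above before each `generate` call.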