Mr-Vicky-01 committed 4eb3214 · verified · 1 parent: 1daf042

Update README.md

Files changed (1): README.md (+27 −40)
README.md CHANGED
@@ -6,20 +6,25 @@ license: apache-2.0
## INFERENCE

```python
- # Load model directly
- from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("AquilaX-AI/QnA")
model = AutoModelForCausalLM.from_pretrained("AquilaX-AI/QnA")

prompt = """
<|im_start|>system\nYou are a helpful AI assistant named Securitron<|im_end|>
"""

- # Keep a list for the last one conversation exchanges
conversation_history = []

while True:
    user_prompt = input("\nUser Question: ")
    if user_prompt.lower() == 'break':
@@ -33,47 +38,29 @@
    # Add the user's question to the conversation history
    conversation_history.append(user)

-     # Ensure conversation starts with a user's input and keep only the last 2 exchanges (4 turns)
    conversation_history = conversation_history[-5:]

    # Build the full prompt
    current_prompt = prompt + "\n".join(conversation_history)

    # Tokenize the prompt
-     encodeds = tokenizer(current_prompt, return_tensors="pt", truncation=True).input_ids
-
-     # Move model and inputs to the appropriate device
-     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-     model.to(device)
-     inputs = encodeds.to(device)
-
-     # Create an empty list to store generated tokens
-     generated_ids = inputs
-
-     # Start generating tokens one by one
-     assistant_response = ""
-     for _ in range(512):  # Specify a max token limit for streaming
-         next_token = model.generate(
-             generated_ids,
-             max_new_tokens=1,
-             pad_token_id=151644,
-             eos_token_id=151645,
-             num_return_sequences=1,
-             do_sample=False,
-             # top_k=5,
-             # temperature=0.2,
-             # top_p=0.90
-         )
-
-         generated_ids = torch.cat([generated_ids, next_token[:, -1:]], dim=1)
-         token_id = next_token[0, -1].item()
-         token = tokenizer.decode([token_id], skip_special_tokens=True)
-
-         assistant_response += token
-         print(token, end="", flush=True)
-
-         if token_id == 151645:  # EOS token
-             break
-
-     conversation_history.append(f"{assistant_response.strip()}<|im_end|>")
```
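
Note: both versions of the snippet build the ChatML-style prompt by hand with `<|im_start|>`/`<|im_end|>` markers. If the AquilaX-AI/QnA tokenizer ships a chat template (an assumption, not confirmed by this commit), the same prompt string could also be produced with `apply_chat_template`; a minimal sketch:

```python
# Minimal sketch, assuming the AquilaX-AI/QnA tokenizer defines a ChatML chat
# template; if it does not, the manual string building in the snippets applies.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("AquilaX-AI/QnA")

messages = [
    {"role": "system", "content": "You are a helpful AI assistant named Securitron"},
    {"role": "user", "content": "Example question"},  # hypothetical user turn
]

# Returns the formatted prompt string, ending with the assistant header so the
# model continues the conversation as the assistant.
prompt_text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt_text)
```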
 
## INFERENCE

```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch

+ # Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("AquilaX-AI/QnA")
model = AutoModelForCausalLM.from_pretrained("AquilaX-AI/QnA")

+ # Define the system prompt
prompt = """
<|im_start|>system\nYou are a helpful AI assistant named Securitron<|im_end|>
"""

+ # Initialize conversation history
conversation_history = []

+ # Set up device
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
while True:
    user_prompt = input("\nUser Question: ")
    if user_prompt.lower() == 'break':

    # Add the user's question to the conversation history
    conversation_history.append(user)

+     # Keep only the last 2 exchanges (4 turns)
    conversation_history = conversation_history[-5:]

    # Build the full prompt
    current_prompt = prompt + "\n".join(conversation_history)

    # Tokenize the prompt
+     encodeds = tokenizer(current_prompt, return_tensors="pt", truncation=True).input_ids.to(device)
+
+     # Initialize TextStreamer for real-time token generation
+     text_streamer = TextStreamer(tokenizer, skip_prompt=True)
+
+     # Generate response with TextStreamer
+     response = model.generate(
+         input_ids=encodeds,
+         streamer=text_streamer,
+         max_new_tokens=512,
+         use_cache=True,
+         pad_token_id=151645,
+         eos_token_id=151645,
+         num_return_sequences=1
+     )
+
+     # Finalize conversation history with the assistant's response
+     conversation_history.append(tokenizer.decode(response[0]).split('<|im_start|>assistant')[-1].split('<|im_end|>')[0].strip() + "<|im_end|>")
```
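
For a quick single-turn check of the updated snippet, the TextStreamer call can be exercised outside the interactive loop. A minimal sketch, assuming the same checkpoint and ChatML markers as above (the user question is a placeholder, and the `151645` end-of-turn id is taken from the snippet):

```python
# Minimal single-turn sketch of the TextStreamer flow from the updated README.
# Assumptions: same AquilaX-AI/QnA checkpoint, ChatML markers, and the 151645
# end-of-turn id used above; the user question is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("AquilaX-AI/QnA")
model = AutoModelForCausalLM.from_pretrained("AquilaX-AI/QnA")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

prompt = (
    "<|im_start|>system\nYou are a helpful AI assistant named Securitron<|im_end|>\n"
    "<|im_start|>user\nWhat kind of questions can you answer?<|im_end|>\n"
    "<|im_start|>assistant\n"
)
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# TextStreamer prints tokens to stdout as they are generated.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
model.generate(
    input_ids=inputs,
    streamer=streamer,
    max_new_tokens=512,
    use_cache=True,
    eos_token_id=151645,  # end-of-turn id, per the snippet above
    pad_token_id=151645,
)
```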