Small change to chat prompt
Browse files
- modeling_internlm.py +3 -5
modeling_internlm.py
CHANGED
|
@@ -96,7 +96,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
|
|
| 96 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 97 |
super().__init__()
|
| 98 |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
| 99 |
-
self.register_buffer("inv_freq", inv_freq)
|
| 100 |
|
| 101 |
# Build here to make `torch.jit.trace` work.
|
| 102 |
self.max_seq_len_cached = max_position_embeddings
|
|
@@ -769,9 +769,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
| 769 |
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
|
| 770 |
prompt = ""
|
| 771 |
for record in history:
|
| 772 |
-
prompt += f"""<s><|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
|
| 773 |
-
if len(prompt) == 0:
|
| 774 |
-
prompt += "<s>"
|
| 775 |
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
|
| 776 |
return tokenizer([prompt], return_tensors="pt")
|
| 777 |
|
|
@@ -995,4 +993,4 @@ class InternLMForSequenceClassification(InternLMPreTrainedModel):
|
|
| 995 |
past_key_values=transformer_outputs.past_key_values,
|
| 996 |
hidden_states=transformer_outputs.hidden_states,
|
| 997 |
attentions=transformer_outputs.attentions,
|
| 998 |
-
)
|
|
|
|
| 96 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 97 |
super().__init__()
|
| 98 |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
| 99 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 100 |
|
| 101 |
# Build here to make `torch.jit.trace` work.
|
| 102 |
self.max_seq_len_cached = max_position_embeddings
|
|
|
|
| 769 |
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
|
| 770 |
prompt = ""
|
| 771 |
for record in history:
|
| 772 |
+
prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
|
|
|
|
|
|
|
| 773 |
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
|
| 774 |
return tokenizer([prompt], return_tensors="pt")
|
| 775 |
|
|
|
|
| 993 |
past_key_values=transformer_outputs.past_key_values,
|
| 994 |
hidden_states=transformer_outputs.hidden_states,
|
| 995 |
attentions=transformer_outputs.attentions,
|
| 996 |
+
)
|