duzx16
commited on
Commit
·
debaf00
1
Parent(s):
3ba9437
Fix Chinese punctuation
Browse files- modeling_chatglm.py +18 -4
modeling_chatglm.py
CHANGED
|
@@ -4,6 +4,7 @@ import math
|
|
| 4 |
import copy
|
| 5 |
import os
|
| 6 |
import warnings
|
|
|
|
| 7 |
|
| 8 |
import torch
|
| 9 |
import torch.utils.checkpoint
|
|
@@ -1099,6 +1100,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1099 |
for layer_past in past
|
| 1100 |
)
|
| 1101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1102 |
@torch.no_grad()
|
| 1103 |
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
| 1104 |
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
|
@@ -1121,8 +1137,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1121 |
outputs = self.generate(**input_ids, **gen_kwargs)
|
| 1122 |
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
| 1123 |
response = tokenizer.decode(outputs)
|
| 1124 |
-
response =
|
| 1125 |
-
response = response.replace("[[训练时间]]", "2023年")
|
| 1126 |
history = history + [(query, response)]
|
| 1127 |
return response, history
|
| 1128 |
|
|
@@ -1148,8 +1163,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1148 |
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
| 1149 |
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
| 1150 |
response = tokenizer.decode(outputs)
|
| 1151 |
-
response =
|
| 1152 |
-
response = response.replace("[[训练时间]]", "2023年")
|
| 1153 |
new_history = history + [(query, response)]
|
| 1154 |
yield response, new_history
|
| 1155 |
|
|
|
|
| 4 |
import copy
|
| 5 |
import os
|
| 6 |
import warnings
|
| 7 |
+
import re
|
| 8 |
|
| 9 |
import torch
|
| 10 |
import torch.utils.checkpoint
|
|
|
|
| 1100 |
for layer_past in past
|
| 1101 |
)
|
| 1102 |
|
| 1103 |
+
def process_response(self, response):
|
| 1104 |
+
response = response.strip()
|
| 1105 |
+
response = response.replace("[[训练时间]]", "2023年")
|
| 1106 |
+
punkts = [
|
| 1107 |
+
[",", ","],
|
| 1108 |
+
["!", "!"],
|
| 1109 |
+
[":", ":"],
|
| 1110 |
+
[";", ";"],
|
| 1111 |
+
["\?", "?"],
|
| 1112 |
+
]
|
| 1113 |
+
for item in punkts:
|
| 1114 |
+
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
|
| 1115 |
+
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
|
| 1116 |
+
return response
|
| 1117 |
+
|
| 1118 |
@torch.no_grad()
|
| 1119 |
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
| 1120 |
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
|
|
|
| 1137 |
outputs = self.generate(**input_ids, **gen_kwargs)
|
| 1138 |
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
| 1139 |
response = tokenizer.decode(outputs)
|
| 1140 |
+
response = self.process_response(response)
|
|
|
|
| 1141 |
history = history + [(query, response)]
|
| 1142 |
return response, history
|
| 1143 |
|
|
|
|
| 1163 |
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
| 1164 |
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
| 1165 |
response = tokenizer.decode(outputs)
|
| 1166 |
+
response = self.process_response(response)
|
|
|
|
| 1167 |
new_history = history + [(query, response)]
|
| 1168 |
yield response, new_history
|
| 1169 |
|