init commit
Browse files
README.md
CHANGED
@@ -46,18 +46,18 @@ import torch
|
|
46 |
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
|
47 |
|
48 |
# load tokenizer and model
|
49 |
-
pretrained_model = "
|
50 |
|
51 |
special_tokens = ["<extra_id_{}>".format(i) for i in range(100)]
|
52 |
tokenizer = T5Tokenizer.from_pretrained(
|
53 |
-
|
54 |
do_lower_case=True,
|
55 |
max_length=512,
|
56 |
truncation=True,
|
57 |
additional_special_tokens=special_tokens,
|
58 |
)
|
59 |
-
config = T5Config.from_pretrained(
|
60 |
-
model = T5ForConditionalGeneration.from_pretrained(
|
61 |
model.resize_token_embeddings(len(tokenizer))
|
62 |
model.eval()
|
63 |
|
@@ -66,8 +66,8 @@ text = "新闻分类任务:【微软披露拓扑量子计算机计划!】这
|
|
66 |
encode_dict = tokenizer(text, max_length=512, padding='max_length',truncation=True)
|
67 |
|
68 |
inputs = {
|
69 |
-
"input_ids": torch.tensor(encode_dict['input_ids']).long(),
|
70 |
-
"attention_mask": torch.tensor(encode_dict['attention_mask']).long(),
|
71 |
}
|
72 |
|
73 |
# generate answer
|
@@ -80,8 +80,9 @@ logits = model.generate(
|
|
80 |
|
81 |
logits=logits[:,1:]
|
82 |
predict_label = [tokenizer.decode(i,skip_special_tokens=True) for i in logits]
|
|
|
83 |
|
84 |
-
# model
|
85 |
```
|
86 |
|
87 |
## 引用 Citation
|
|
|
46 |
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
|
47 |
|
48 |
# load tokenizer and model
|
49 |
+
pretrained_model = "/cognitive_comp/wuxiaojun/pretrained/pytorch/huggingface/Randeng-T5-784M-MultiTask-Chinese"
|
50 |
|
51 |
special_tokens = ["<extra_id_{}>".format(i) for i in range(100)]
|
52 |
tokenizer = T5Tokenizer.from_pretrained(
|
53 |
+
pretrained_model,
|
54 |
do_lower_case=True,
|
55 |
max_length=512,
|
56 |
truncation=True,
|
57 |
additional_special_tokens=special_tokens,
|
58 |
)
|
59 |
+
config = T5Config.from_pretrained(pretrained_model)
|
60 |
+
model = T5ForConditionalGeneration.from_pretrained(pretrained_model, config=config)
|
61 |
model.resize_token_embeddings(len(tokenizer))
|
62 |
model.eval()
|
63 |
|
|
|
66 |
encode_dict = tokenizer(text, max_length=512, padding='max_length',truncation=True)
|
67 |
|
68 |
inputs = {
|
69 |
+
"input_ids": torch.tensor([encode_dict['input_ids']]).long(),
|
70 |
+
"attention_mask": torch.tensor([encode_dict['attention_mask']]).long(),
|
71 |
}
|
72 |
|
73 |
# generate answer
|
|
|
80 |
|
81 |
logits=logits[:,1:]
|
82 |
predict_label = [tokenizer.decode(i,skip_special_tokens=True) for i in logits]
|
83 |
+
print(predict_label)
|
84 |
|
85 |
+
# model output: ['科技']
|
86 |
```
|
87 |
|
88 |
## 引用 Citation
|