suolyer committed on
Commit 6c4d83b · 1 Parent(s): 3c9e67b

Update README.md

Files changed (1)
  1. README.md +32 -0
README.md CHANGED
@@ -27,6 +27,38 @@ model = RoFormerModel.from_pretrained("IDEA-CCNL/Zhouwenwang-110M")
 
 
  ```
+
+ ### Generation task
+ You can use Zhouwenwang-110M to continue a text prompt:
+
+ ```python
+ from model.roformer.modeling_roformer import RoFormerModel
+ from transformers import AutoTokenizer
+ import torch
+ import numpy as np
+
+ sentence = '清华大学位于'  # prompt: "Tsinghua University is located in"
+ max_length = 32  # maximum number of tokens to append
+ model_pretrained_weight_path = '/home/'  # path to the pretrained model weights
+
+ tokenizer = AutoTokenizer.from_pretrained(model_pretrained_weight_path)
+ model = RoFormerModel.from_pretrained(model_pretrained_weight_path)
+
+ for i in range(max_length):
+     encode = torch.tensor(
+         [[tokenizer.cls_token_id]+tokenizer.encode(sentence, add_special_tokens=False)]).long()
+     logits = model(encode)[0]  # first model output for the whole sequence
+     logits = torch.nn.functional.linear(
+         logits, model.embeddings.word_embeddings.weight)  # project onto the vocabulary with the word embeddings
+     logits = torch.nn.functional.softmax(
+         logits, dim=-1).cpu().detach().numpy()[0]  # per-position probabilities for the single batch item
+     sentence = sentence + \
+         tokenizer.decode(int(np.random.choice(logits.shape[1], p=logits[-1])))  # sample from the last position
+     if sentence[-1] == '。':  # stop at the Chinese full stop
+         break
+ print(sentence)
+ ```
+
  ## Scores on downstream Chinese tasks (without any data augmentation)
  | Model| afqmc | tnews | iflytek | ocnli | cmnli | wsc | csl |
  | :--------: | :-----: | :----: | :-----: | :----: | :----: | :----: | :----: |
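
Note on the added snippet: each iteration of its loop scores every vocabulary entry against the hidden state of the last position, turns the scores into probabilities with a softmax, samples one token, and stops at a full stop or after `max_length` tokens. The following is a minimal, dependency-free sketch of that logic, not the code from the commit. All function names here are hypothetical, and `encode_last` is an assumed stand-in for the model call that returns the hidden state of the last position.

```python
import math
import random


def project_to_vocab(hidden, embedding):
    """Score each vocabulary entry by its dot product with the hidden state."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]


def softmax(scores):
    """Convert raw scores to probabilities (max subtracted for stability)."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [x / total for x in exps]


def sample_next_token(hidden, embedding, rng):
    """Draw one token id from the softmax distribution over the vocabulary."""
    probs = softmax(project_to_vocab(hidden, embedding))
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def continue_text(token_ids, encode_last, embedding, stop_id, max_new_tokens, rng):
    """Append sampled tokens until stop_id appears or the budget runs out.

    encode_last is a hypothetical stand-in for the encoder: it maps the
    current token ids to the hidden state of the last position.
    """
    token_ids = list(token_ids)
    for _ in range(max_new_tokens):
        next_id = sample_next_token(encode_last(token_ids), embedding, rng)
        token_ids.append(next_id)
        if next_id == stop_id:
            break
    return token_ids
```

Because the next token is sampled rather than chosen by argmax, repeated runs with the same prompt can produce different continuations; passing a seeded `random.Random` makes a run reproducible.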