---
license: mit
base_model:
- paust/pko-t5-base
pipeline_tag: text2text-generation
---

This model is based on paust/pko-t5-base.

Since this model uses the paust/pko-t5-base tokenizer, you need to load that tokenizer alongside the model:
```python
from transformers import T5TokenizerFast, T5ForConditionalGeneration

tokenizer = T5TokenizerFast.from_pretrained("paust/pko-t5-base")
model = T5ForConditionalGeneration.from_pretrained("emotionanalysis/diaryempathizer-t5-ko")
```

The following test code runs end-to-end inference on a sample diary entry:
```python
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("emotionanalysis/diaryempathizer-t5-ko")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

tokenizer = T5TokenizerFast.from_pretrained("paust/pko-t5-base")

# Sample Korean diary entry used as model input.
input_text = """
  ์˜ค๋Š˜์€ ์ •๋ง๋กœ ์ฆ๊ฑฐ์šด ๋‚ ์ด์—ˆ๋‹ค. ๋ฉฐ์น ์งธ ์ž ๋„ ๋ชป ์ž๊ณ  ์•Œ๊ณ ๋ฆฌ์ฆ˜ ์ˆ˜์—…์„ ๊ธฐ๋Œ€ ์ค‘์ด๋‹ค. ์ด์ œ ์ •๋ง๋กœ ๋”๋Š” ๋ฒ„ํ‹ฐ๊ธฐ ์–ด๋ ต๊ฒ ๋‹ค๊ณ  ๋ณด์ธ๋‹ค. ๋‚ฎ์— ๊ต์ˆ˜๋‹˜์ด ๊ฐ‘์ž๊ธฐ ์—ด์ •์ ์œผ๋กœ ์ˆ˜์—…์„ ํ•˜์‹œ๋Š” ๊ฒƒ์ด์—ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‹ค๊ฐ€ automata์—
  ๋Œ€ํ•ด ์„ค๋ช…ํ•˜์‹œ๋Š”๋ฐ, ์ •๋ง ๊ฐ๋™์˜ ๋ˆˆ๋ฌผ๊ณผ ๊ฒฝํƒ„์„ ๊ธˆํ•  ๊ธธ ์—†์—ˆ๊ณ , ์ด ์ˆ˜์—…์„ ๋“ฃ๊ธฐ ์œ„ํ•ด ํƒœ์–ด๋‚ฌ๋‹ค๋Š” ์ƒ๊ฐ์ด ๋“ค์—ˆ์œผ๋ฉฐ, ๊ณผ์ œ(ํฌ์ƒ)๊นŒ์ง€ ์ฃผ์‹œ๋Š” ๊ฒƒ์ด์—ˆ๋‹ค.
  ๋‚˜๋Š” ํ™ฉํ™€๊ฒฝ์— ๋น ์กŒ๋‹ค. ๋ชฉ์ˆจ์„ ๋ฐ”์น˜๊ณ  ์ Š์„์„ ๋‚ด๋˜์ ธ์•ผ ํ•  ์ธ์ƒ์˜ ์ด์œ ๋ฅผ ์ฐพ์•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.
"""

inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
inputs = {key: value.to(device) for key, value in inputs.items()}

# Pass the attention mask so padded positions are ignored during generation.
with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=128,
        num_beams=4,
        early_stopping=True,
    )

generated_comment = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_comment)
```
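
Since the card declares `pipeline_tag: text2text-generation`, the high-level Transformers pipeline API should also work. Below is a minimal sketch under that assumption; the tokenizer is passed explicitly because it comes from the base model, and the one-line Korean diary sentence is a hypothetical sample input.

```python
from transformers import pipeline, T5TokenizerFast

# Minimal sketch: same model and tokenizer as above, wrapped in a pipeline.
generator = pipeline(
    "text2text-generation",
    model="emotionanalysis/diaryempathizer-t5-ko",
    tokenizer=T5TokenizerFast.from_pretrained("paust/pko-t5-base"),
)

# Hypothetical sample input: "Today was a really hard day."
result = generator("์˜ค๋Š˜์€ ์ •๋ง ํž˜๋“  ํ•˜๋ฃจ์˜€๋‹ค.", max_length=128, num_beams=4)
print(result[0]["generated_text"])
```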