metastable-void commited on
Commit
e38ab6b
·
1 Parent(s): 6c44471
Files changed (2) hide show
  1. README.md +2 -3
  2. app.py +18 -36
README.md CHANGED
@@ -1,12 +1,11 @@
1
  ---
2
- title: 真空ジェネレータ
3
  emoji: ⚡
4
  colorFrom: red
5
  colorTo: purple
6
  python_version: 3.11
7
  models:
8
- - llm-jp/llm-jp-3-1.8b-instruct
9
- - vericava/llm-jp-3-1.8b-instruct-lora-vericava17
10
  sdk: gradio
11
  sdk_version: 5.23.1
12
  app_file: app.py
 
1
  ---
2
+ title: 真空ジェネレータ v2
3
  emoji: ⚡
4
  colorFrom: red
5
  colorTo: purple
6
  python_version: 3.11
7
  models:
8
+ - vericava/gpt2-medium-vericava-posts-v3
 
9
  sdk: gradio
10
  sdk_version: 5.23.1
11
  app_file: app.py
app.py CHANGED
@@ -21,19 +21,9 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))
21
 
22
 
23
  if torch.cuda.is_available():
24
- model_id = "vericava/llm-jp-3-1.8b-instruct-lora-vericava17"
25
- base_model_id = "llm-jp/llm-jp-3-1.8b-instruct"
26
- tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
27
- tokenizer.chat_template = "{{bos_token}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\\n\\n### 前の投稿:\\n' + message['content'] + '' }}{% elif message['role'] == 'system' %}{{ '以下は、SNS上の投稿です。あなたはSNSの投稿生成botとして、次に続く投稿を考えなさい。説明はせず、投稿の内容のみを鉤括弧をつけずに答えよ。' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 次の投稿:\\n' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 次の投稿:\\n' }}{% endif %}{% endfor %}"
28
- model = AutoModelForCausalLM.from_pretrained(
29
- base_model_id,
30
- trust_remote_code=True,
31
- )
32
- model.load_adapter(model_id)
33
  my_pipeline=pipeline(
34
  task="text-generation",
35
- model=model,
36
- tokenizer=tokenizer,
37
  do_sample=True,
38
  num_beams=1,
39
  )
@@ -49,37 +39,29 @@ def generate(
49
  top_k: int = 50,
50
  repetition_penalty: float = 1.0,
51
  ) -> Iterator[str]:
52
- from datetime import datetime, timezone, timedelta
53
-
54
- d=datetime.now(timezone(timedelta(hours=9), 'JST'))
55
- m=d.month
56
- if m < 3 or m > 11:
57
- season = '冬'
58
- elif m < 6:
59
- season = '春'
60
- elif m < 9:
61
- season = '夏'
62
- else:
63
- season = '秋'
64
-
65
- h=d.hour
66
- go = '午前' if h < 12 else '午後'
67
- h = h % 12
68
- minute = d.minute
69
- time = go + str(h) + '時' + str(minute) + '分'
70
-
71
- messages = [
72
- {"role": "system", "content": "なお今は日本の" + season + "で、時刻は" + time + "であるものとする。また、あなたは真空という名前のユーザであるとする。"},
73
- {"role": "user", "content": message},
74
- ]
75
 
76
  output = my_pipeline(
77
- messages,
78
  temperature=temperature,
79
  max_new_tokens=max_new_tokens,
 
 
 
80
  )
81
  print(output)
82
- yield output[-1]["generated_text"][-1]["content"]
 
 
 
83
 
84
  demo = gr.ChatInterface(
85
  fn=generate,
 
21
 
22
 
23
  if torch.cuda.is_available():
 
 
 
 
 
 
 
 
 
24
  my_pipeline=pipeline(
25
  task="text-generation",
26
+ model="vericava/gpt2-medium-vericava-posts-v3",
 
27
  do_sample=True,
28
  num_beams=1,
29
  )
 
39
  top_k: int = 50,
40
  repetition_penalty: float = 1.0,
41
  ) -> Iterator[str]:
42
+ user_input = " ".join(message.strip().split("\n"))
43
+
44
+ user_input = user_input if (
45
+ user_input.endswith("。")
46
+ or user_input.endswith("?")
47
+ or user_input.endswith("!")
48
+ or user_input.endswith("?")
49
+ or user_input.endswith("!")
50
+ ) else user_input + "。"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  output = my_pipeline(
53
+ user_input,
54
  temperature=temperature,
55
  max_new_tokens=max_new_tokens,
56
+ repetition_penalty=repetition_penalty,
57
+ top_k=top_k,
58
+ top_p=top_p,
59
  )
60
  print(output)
61
+ gen_text = output[len(user_input):]
62
+ gen_text = gen_text[:gen_text.find("\n")] if "\n" in gen_text else gen_text
63
+ gen_text = gen_text[:(gen_text.rfind("。") + 1)] if "。" in gen_text else gen_text
64
+ yield gen_text
65
 
66
  demo = gr.ChatInterface(
67
  fn=generate,