Commit fa87815 (verified) by karimouda · Parent(s): aae3044

Update app.py

Files changed (1): app.py (+22, -12)
app.py:

@@ -8,11 +8,11 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 DESCRIPTION = """\
-# SILMA Kashif 2B Instruct V1.0 Playgroud
-
-This is a demo of [`silma-ai/SILMA-Kashif-2B-Instruct-v1.0`](https://huggingface.co/silma-ai/SILMA-Kashif-2B-Instruct-v1.0).
-
-** NOTE: this is a RAG model, it is only trained to answer questions based on context.
+# Gemma 2 9B IT
+Gemma 2 is Google's latest iteration of open LLMs.
+This is a demo of [`google/gemma-2-9b-it`](https://huggingface.co/google/gemma-2-9b-it), fine-tuned for instruction following.
+For more details, please check [our post](https://huggingface.co/blog/gemma2).
+👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it).
 """
 
 MAX_MAX_NEW_TOKENS = 2048
@@ -21,7 +21,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-model_id = "silma-ai/SILMA-Kashif-2B-Instruct-v1.0"
+model_id = "google/gemma-2-9b-it"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
@@ -37,7 +37,10 @@ def generate(
     message: str,
     chat_history: list[dict],
     max_new_tokens: int = 1024,
-    temperature: float = 0.01,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = chat_history.copy()
     conversation.append({"role": "user", "content": message})
@@ -47,21 +50,24 @@ def generate(
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
-    print(temperature)
+
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
-        temperature=temperature
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
-    print(streamer)
+
     outputs = []
     for text in streamer:
-        print(text)
         outputs.append(text)
         yield "".join(outputs)
 
@@ -108,6 +114,10 @@ demo = gr.ChatInterface(
     stop_btn=None,
     examples=[
         ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
     ],
     cache_examples=False,
     type="messages",
@@ -118,4 +128,4 @@ demo = gr.ChatInterface(
 
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
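
For reference, the snippet below is a minimal, self-contained sketch of the threaded streaming pattern the updated `generate()` relies on, using the sampling defaults introduced in this commit (`temperature=0.6`, `top_p=0.9`, `top_k=50`, `repetition_penalty=1.2`, `num_beams=1`). The `device_map`/`torch_dtype` loading arguments and the hard-coded prompt are illustrative assumptions, not taken from app.py.

```python
# Minimal sketch of the threaded streaming generation used by the updated generate()
# (assumes the google/gemma-2-9b-it weights are available and fit on the device).
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16  # illustrative loading args
)

# Build the prompt with the chat template, as the app does for each turn.
conversation = [{"role": "user", "content": "Hello there! How are you doing?"}]
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
input_ids = input_ids.to(model.device)

# The streamer yields decoded text chunks while generation runs in a background thread.
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=1024,
    do_sample=True,
    top_p=0.9,           # new default from the diff
    top_k=50,            # new default from the diff
    temperature=0.6,     # replaces the old 0.01 default
    num_beams=1,
    repetition_penalty=1.2,
)
Thread(target=model.generate, kwargs=generate_kwargs).start()

outputs = []
for text in streamer:
    outputs.append(text)
    print("".join(outputs), end="\r")  # stand-in for the Gradio generator's `yield`
```

Running `model.generate` in a background `Thread` keeps the caller free to yield partial output as soon as `TextIteratorStreamer` decodes it, which is what lets the chat UI stream text as it arrives.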
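
The diff does not show how the new keyword arguments reach the UI. With `gr.ChatInterface` they are typically exposed through `additional_inputs`; the sketch below illustrates that wiring under that assumption, with a stub `generate()` standing in for the app's real streaming function and slider ranges chosen purely for illustration.

```python
# Hypothetical wiring of the new sampling parameters as Gradio sliders.
# The real app passes its streaming generate() here; a stub stands in so the
# snippet runs on its own. Labels and ranges are illustrative assumptions.
from collections.abc import Iterator

import gradio as gr


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Stub: echo the chosen settings instead of running the model.
    yield (
        f"(max_new_tokens={max_new_tokens}, temperature={temperature}, "
        f"top_p={top_p}, top_k={top_k}, repetition_penalty={repetition_penalty})"
    )


demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    examples=[["Hello there! How are you doing?"]],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
```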