# chat/lib/generate.py
# Last commit: "Add config" (3eb01b6, verified) by adamelliotfields
from threading import Thread
from typing import Iterator
from spaces import GPU, config
from transformers import TextIteratorStreamer
from .loader import get_loader
@GPU
def generate(
    message: str,
    chat_history: list[dict[str, str]],
    system_message="",
    model="Qwen/Qwen2.5-0.5B-Instruct",
    max_tokens=512,
    temperature=0.6,
    repetition_penalty=1.2,
    top_p=0.9,
    top_k=50,
) -> Iterator[str]:
    """Stream a chat reply to `message` given the prior conversation.

    The prompt is assembled as: one system message (always first), the prior
    turns from `chat_history`, then `message` as the newest user turn. The
    caller's `chat_history` list and the dicts inside it are left untouched;
    the prompt is built on a private copy.

    Args:
        message: The new user message to answer.
        chat_history: Prior turns as dicts with "role" and "content" keys.
            A leading "system" entry, if present, has its content replaced by
            `system_message`. An empty or missing history is treated as no
            prior turns.
        system_message: Content of the system prompt.
        model: Model identifier handed to `transformers_generate`.
        max_tokens: Passed through to `transformers_generate`.
        temperature: Passed through to `transformers_generate`.
        repetition_penalty: Passed through to `transformers_generate`.
        top_p: Passed through to `transformers_generate`.
        top_k: Passed through to `transformers_generate`.

    Yields:
        Whatever `transformers_generate` yields for the assembled prompt.
    """
    # Shallow copy: inserting/appending below must not leak into the caller's
    # list, which may be reused for the next turn.
    messages = list(chat_history or [])

    # Prepend system prompt (or refresh the existing one)
    if messages and messages[0].get("role") == "system":
        # Build a new dict instead of assigning into the existing one, which is
        # still shared with the caller's history.
        messages[0] = {**messages[0], "content": system_message}
    else:
        messages.insert(0, {"role": "system", "content": system_message})

    # Append user message before generating
    messages.append({"role": "user", "content": message})

    yield from transformers_generate(
        messages,
        model,
        max_tokens,
        temperature,
        repetition_penalty,
        top_p,
        top_k,
    )
def transformers_generate(
    chat_history: list[dict[str, str]],
    model: str,
    max_tokens: int,
    temperature: float,
    repetition_penalty: float,
    top_p: float,
    top_k: int,
) -> Iterator[str]:
    """Run sampled generation for a chat and stream the growing reply.

    Loads `model` through the loader, renders `chat_history` with the
    tokenizer's chat template, starts `llm.generate` on a background thread,
    and yields the accumulated reply each time the streamer emits new text.

    Args:
        chat_history: Messages passed to the tokenizer's chat template.
        model: Identifier handed to `loader.load`.
        max_tokens: Forwarded to generation as `max_new_tokens`.
        temperature: Sampling temperature.
        repetition_penalty: Repetition penalty forwarded to generation.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cutoff.

    Yields:
        The full reply text produced so far (cumulative, not per-chunk deltas).
    """
    loader = get_loader(singleton=not config.Config.zero_gpu)
    loader.load(model)
    llm, tokenizer = loader.llm, loader.tokenizer

    # Some models ship without a padding token; fall back to EOS
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template
    # https://huggingface.co/docs/transformers/chat_templating#what-are-generation-prompts
    encoded = tokenizer.apply_chat_template(
        chat_history,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,  # needed so the attention mask comes back too
    )

    # Move the prompt tensors onto the model's device
    device = llm.device
    prompt_ids = encoded["input_ids"].to(device)
    prompt_mask = encoded["attention_mask"].to(device)

    # Emits decoded text for newly generated tokens only (prompt is skipped)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # https://huggingface.co/blog/how-to-generate
    generate_kwargs = {
        "input_ids": prompt_ids,
        "attention_mask": prompt_mask,
        "pad_token_id": tokenizer.pad_token_id,
        "streamer": streamer,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_tokens,
    }

    # Generation runs on a worker thread so this generator can consume the
    # streamer while tokens are still being produced
    worker = Thread(target=llm.generate, kwargs=generate_kwargs)
    worker.start()

    # Yield the cumulative reply after every streamed chunk
    reply = ""
    for chunk in streamer:
        reply += chunk
        yield reply