#!/usr/bin/env python
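"""Gradio chat demo comparing the base Greek Meltemi-7B-Instruct model with a
Cypriot Greek fine-tuned adapter (CyGr), both served from a single 4-bit
quantized copy of the base weights."""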
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
DESCRIPTION = "# Mistral-7B-CyGr v0.2"

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "512"))
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
else:
    base_model = "ilsp/Meltemi-7B-Instruct-v1.5"
    adapter_model = "CYENS/mistral-cygr-10epochs"
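    # Load the base model in 4-bit NF4 with double quantization and bfloat16
    # compute, which keeps the 7B weights small enough for a single GPU.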
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        token=os.getenv("HF_TOKEN"),
    )
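    # PeftModel wraps the quantized base model and loads the CyGr adapter on top,
    # so both variants share a single copy of the base weights in memory.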
    cygr_model = PeftModel.from_pretrained(model, adapter_model, token=os.getenv("HF_TOKEN"))
    tokenizer = AutoTokenizer.from_pretrained(base_model, token=os.getenv("HF_TOKEN"))
    tokenizer.padding_side = "left"
    #model = model.to("cuda")
    model.eval()
    #cygr_model = cygr_model.to("cuda")
    cygr_model.eval()


def respond(
    message: str,
    chat_history: list[tuple[str, str]],
    llm_choice: str,
    system_message: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = [{"role": "system", "content": system_message}]
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
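    # Serialize the conversation into the model's prompt format via its chat
    # template, then trim from the left if it exceeds the configured token budget.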
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device) if llm_choice == "Greek Meltemi" else input_ids.to(cygr_model.device)
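    # generate() runs in a background thread and pushes decoded tokens into the
    # streamer, so partial output can be yielded to the UI as it is produced.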
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # PEFT applies the CyGr adapter to the shared base model in place, so calling
    # model.generate directly may still produce adapter-modified output; disable the
    # adapter explicitly when the plain Greek Meltemi option is selected.
    def generate_base(**kwargs):
        with cygr_model.disable_adapter():
            cygr_model.generate(**kwargs)

    if llm_choice == "Greek Meltemi":
        print("using Meltemi")
        t = Thread(target=generate_base, kwargs=generate_kwargs)
    elif llm_choice == "Cypriot Meltemi":
        print("using fine-tuned Meltemi")
        t = Thread(target=cygr_model.generate, kwargs=generate_kwargs)
    else:
        raise ValueError(f"Not a valid LLM: {llm_choice}")
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Radio(['Greek Meltemi', 'Cypriot Meltemi'], value='Greek Meltemi', label='LLM'),
        gr.Textbox(
            # "You are a language model for the Cypriot language."
            value="Είσαι ένα γλωσσικό μοντέλο για την κυπριακή γλώσσα.",
            label="System message",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=1.0,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    #stop_btn=None
)
""" | |
- Redirect the user to https://huggingface.co/oauth/authorize?redirect_uri={REDIRECT_URI}&scope=openid%20profile&client_id={CLIENT_ID}&state={STATE}, | |
where STATE is a random string that you will need to verify later. | |
- Handle the callback on /auth/callback or /login/callback (or your own custom callback URL) and verify the state parameter. | |
- Use the code query parameter to get an access token and id token from https://huggingface.co/oauth/token (POST request with client_id, code, | |
grant_type=authorization_code and redirect_uri as form data, and with Authorization: Basic {base64(client_id:client_secret)} as a header). | |
""" | |


with gr.Blocks(css="style.css") as demo:
    #gr.Markdown(DESCRIPTION)
    """gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )"""
    #gr.LoginButton().activate()
    chat_interface.render()

if __name__ == "__main__":
    demo.launch()
    #demo.queue(max_size=20).launch(share=True)