import gradio as gr
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
import torch, transformers
from threading import Thread
import time

# Load the 1-bit HQQ-quantized Llama-2-7B chat model together with its LoRA adapter
model_id  = 'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq'
model     = HQQModelForCausalLM.from_quantized(model_id, adapter='adapter_v0.1.lora', device='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Set up inference mode: the chat template below inserts "<s>" and "[INST]" markers
# manually, so disable automatic BOS/EOS insertion and make sure a pad token exists.
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
if not tokenizer.pad_token:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.config.use_cache = True
model.eval()

# Optional: torch compile for faster inference
model = torch.compile(model)
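# Note: the first generations after torch.compile() are slower while the model is compiled.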

def chat_processor(chat, max_new_tokens=100, do_sample=True, device='cuda'):
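    # Tokenize the message using the Llama-2 [INST] chat format, then run
    # model.generate() in a background thread and return a streamer for the output.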
    tokenizer.use_default_system_prompt = False
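    # TextIteratorStreamer exposes the generated text as a Python iterator, so it can be
    # consumed from another thread while generation is still in progress.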
    streamer = transformers.TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_params = dict(
        tokenizer("<s> [INST] " + chat + " [/INST] ", return_tensors="pt").to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        top_p=0.90 if do_sample else None,
        top_k=50 if do_sample else None,
        temperature=0.6 if do_sample else None,
        num_beams=1,
        repetition_penalty=1.2,
    )

    # Run generation in a background thread so the caller can stream partial output
    # while model.generate() is still producing tokens.
    # (The streamer could also be drained here to print output to the console.)
    t = Thread(target=model.generate, kwargs=generate_params)
    t.start()

    return t, streamer

def chat(message, history):
    t, streamer = chat_processor(chat=message)
    response = ""
    # The streamer yields decoded text chunks (not single characters) as they arrive.
    for chunk in streamer:
        if chunk is not None:
            response += chunk
            yield response
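    # Wait for the generation thread to finish, then free any cached GPU memory.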
    time.sleep(0.1)
    t.join()
    torch.cuda.empty_cache()

gr.ChatInterface(chat).launch()
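
# To get a temporary public link for the demo, launch() also accepts share=True:
# gr.ChatInterface(chat).launch(share=True)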