import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "thrishala/mental_health_chatbot"

try:
    # Load model with appropriate settings
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={0: "15GiB"} if device == "cuda" else None,
        offload_folder="offload",
    ).eval()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 4096  # Set to model's actual context length

except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)

def generate_text_streaming(prompt, max_new_tokens=128):
    inputs = tokenizer(
        prompt, 
        return_tensors="pt", 
        truncation=True, 
        max_length=4096  # Match model's context length
    ).to(model.device)
    
    generated_tokens = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True
            )
            
            # Take the newly generated token id (shape [1, 1]) from the output
            new_token = outputs.sequences[:, -1:]
            generated_tokens.append(new_token.item())

            # Append the new token to the running context for the next iteration,
            # keeping the attention mask the same dtype/device as the ids
            inputs = {
                "input_ids": torch.cat([inputs["input_ids"], new_token], dim=-1),
                "attention_mask": torch.cat(
                    [inputs["attention_mask"], torch.ones_like(new_token)], dim=-1
                ),
            }

            # Decode the accumulated tokens
            current_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            yield current_text  # Yield the full text so far

            # Stop once the model emits its end-of-sequence token
            if new_token.item() == tokenizer.eos_token_id:
                break

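# NOTE: the loop above re-runs a full forward pass for every new token. Below is
# a minimal, untested sketch of the same streaming behavior using transformers'
# TextIteratorStreamer, which lets a single generate() call stream tokens while
# reusing its KV cache. The helper name generate_text_streaming_v2 is
# illustrative only and is not wired into the Gradio app.
def generate_text_streaming_v2(prompt, max_new_tokens=128):
    from threading import Thread

    from transformers import TextIteratorStreamer

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=4096,
    ).to(model.device)

    # The streamer decodes tokens as generate() produces them on a worker thread
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # Yield the full text so far, matching generate_text_streaming
    thread.join()
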
def respond(message, history, system_message, max_tokens):
    # Build prompt with full history
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"
    
    # Keep track of the full response
    full_response = ""
    
    try:
        for token_chunk in generate_text_streaming(prompt, max_tokens):
            # Each chunk is already the full response so far, so pass it through
            full_response = token_chunk
            yield full_response
            
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch()