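# Hugging Face Space: streaming chat app for the "thrishala/mental_health_chatbot"
# causal LM, served through a Gradio ChatInterface.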
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Determine device (informational only; device_map="auto" below handles model placement)
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "thrishala/mental_health_chatbot"
try:
    # Load model with appropriate settings
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={0: "15GiB"} if torch.cuda.is_available() else None,
        offload_folder="offload",
    ).eval()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 4096  # Set to model's actual context length
except Exception as e:
print(f"Error loading model: {e}")
exit()
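
# Stream text by greedily decoding one token per model.generate() call and yielding the
# decoded text so far. Simple, but it re-encodes the growing prompt on every step, so it
# is slower than transformers' built-in streaming helpers (e.g. TextIteratorStreamer).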
def generate_text_streaming(prompt, max_new_tokens=128):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=4096,  # Match model's context length
    ).to(model.device)

    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Generate exactly one new token so partial output can be yielded immediately
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,
            )

            new_token = outputs.sequences[0, -1]
            generated_tokens.append(new_token)

            # Update inputs for next iteration
            inputs = {
                "input_ids": torch.cat(
                    [inputs["input_ids"], new_token.unsqueeze(0).unsqueeze(0)], dim=-1
                ),
                "attention_mask": torch.cat(
                    [
                        inputs["attention_mask"],
                        # Keep the mask dtype consistent with the tokenizer output
                        torch.ones(1, 1, dtype=inputs["attention_mask"].dtype, device=model.device),
                    ],
                    dim=-1,
                ),
            }

            # Decode the accumulated tokens
            current_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            yield current_text  # Yield the full text so far

            if new_token.item() == tokenizer.eos_token_id:
                break
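
# Gradio callback: rebuild the conversation as a plain "User:/Assistant:" prompt and
# stream the model's reply back to the chat window chunk by chunk.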
def respond(message, history, system_message, max_tokens):
    # Build prompt with full history
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    # Keep track of the full response
    full_response = ""

    try:
        for token_chunk in generate_text_streaming(prompt, max_tokens):
            # Update the full response and yield incremental changes
            full_response = token_chunk
            yield full_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch()