pro-grammer committed on
Commit 1a5c0c8 · verified · 1 Parent(s): fce2a2c

Update app.py

Files changed (1)
  1. app.py +96 -38
app.py CHANGED
@@ -1,52 +1,110 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- client = InferenceClient("pro-grammer/MindfulAI")
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     prompt = system_message.strip() + "\n\n"
-     for user_msg, assistant_msg in history:
-         if user_msg:
-             prompt += f"User: {user_msg.strip()}\n"
-         if assistant_msg:
-             prompt += f"Assistant: {assistant_msg.strip()}\n"
-     prompt += f"User: {message.strip()}\nAssistant:"
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # --- Model Initialization ---
+
+ # Paths for tokenizer and your model checkpoint
+ tokenizer_path = "facebook/opt-1.3b"
+ model_path = "transfer_learning_therapist.pth"
+
+ # Load tokenizer and set pad token if needed
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
+ # Set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ # Load the base model and then update with your checkpoint
+ model = AutoModelForCausalLM.from_pretrained(tokenizer_path)
+ checkpoint = torch.load(model_path, map_location=device)
+ model_dict = model.state_dict()
+ pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k in model_dict}
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ model.to(device)
+ model.eval()

-     response = ""
+ # --- Inference Function ---

-     for token in client.text_generation(
-         prompt,
+ def generate_response(prompt, max_new_tokens=150, temperature=0.7, top_p=0.9, repetition_penalty=1.2):
+     """Generates a response from your model based on the prompt."""
+     model.eval()
+     model.config.use_cache = True
+
+     prompt = prompt.strip()
+     if not prompt:
+         return "Please provide a valid input."
+
+     # Tokenize the input prompt
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+     try:
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs.input_ids,
+                 attention_mask=inputs.attention_mask,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=True,
+                 pad_token_id=tokenizer.pad_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+                 repetition_penalty=repetition_penalty,
+                 num_beams=1, # greedy decoding
+                 no_repeat_ngram_size=3, # avoid repeated phrases
+             )
+     except Exception as e:
+         return f"Error generating response: {e}"
+     finally:
+         model.config.use_cache = False
+
+     full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     # If your prompt is formatted with role markers (e.g., "Therapist:"), extract only that part:
+     if "Therapist:" in full_response:
+         therapist_response = full_response.split("Therapist:")[-1].strip()
+     else:
+         therapist_response = full_response.strip()
+     return therapist_response
+
+ # --- Gradio Interface Function ---
+
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
+     """
+     Build the conversation context by combining the system message and the dialogue history,
+     then generate a new response from the model.
+     """
+     # Create a conversation prompt with your desired role labels.
+     conversation = f"System: {system_message}\n"
+     for user_msg, assistant_msg in history:
+         conversation += f"Human: {user_msg}\nTherapist: {assistant_msg}\n"
+     conversation += f"Human: {message}\nTherapist:"
+
+     response = generate_response(
+         conversation,
          max_new_tokens=max_tokens,
-         stream=True,
          temperature=temperature,
          top_p=top_p,
-     ):
-         token_text = token.choices[0].text
-         response += token_text
-         yield response
+     )
+
+     history.append((message, response))
+     return history, history
+
+ # --- Gradio ChatInterface Setup ---

  demo = gr.ChatInterface(
-     respond,
+     fn=respond,
+     title="MindfulAI Chat",
+     description="Chat with MindfulAI – an AI Therapist powered by your custom model.",
      additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Textbox(value="You are a friendly AI Therapist.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
      ],
  )

  if __name__ == "__main__":
-     demo.launch(share=True)
+     demo.launch()
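
For reference, the checkpoint-merging step this commit introduces (filtering checkpoint['model_state_dict'] against the base model's state_dict() before calling load_state_dict) can be sanity-checked outside the Space. Below is a minimal sketch of such a check, assuming the same checkpoint layout as app.py; the report_checkpoint_coverage helper is hypothetical and not part of the commit.

import torch
from transformers import AutoModelForCausalLM

def report_checkpoint_coverage(base_model_id: str, checkpoint_path: str) -> None:
    # Hypothetical helper: report how many checkpoint tensors line up with the base model.
    model = AutoModelForCausalLM.from_pretrained(base_model_id)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(checkpoint["model_state_dict"].keys())  # same layout app.py assumes
    matched = model_keys & ckpt_keys
    print(f"{len(matched)}/{len(model_keys)} base-model tensors found in the checkpoint")
    print(f"{len(ckpt_keys - model_keys)} checkpoint tensors have no counterpart and would be ignored")

if __name__ == "__main__":
    report_checkpoint_coverage("facebook/opt-1.3b", "transfer_learning_therapist.pth")

If the matched count is far below the total, the merge silently falls back to mostly base-model weights, which is worth knowing before launching the demo.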