kdevoe committed on
Commit
75c6b52
·
verified ·
1 Parent(s): 2bafb40

Adding ability to switch between small, base and large models

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -1,16 +1,25 @@
1
  import gradio as gr
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration
3
  from langchain.memory import ConversationBufferMemory
 
4
 
5
- # Load the tokenizer and model for flan-t5
 
 
 
 
 
 
 
 
 
6
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
7
- model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
8
 
9
  # Set up conversational memory using LangChain's ConversationBufferMemory
10
  memory = ConversationBufferMemory()
11
 
12
- # Define the chatbot function with memory
13
- def chat_with_flan(input_text):
14
  # Retrieve conversation history and append the current user input
15
  conversation_history = memory.load_memory_variables({})['history']
16
 
@@ -20,6 +29,9 @@ def chat_with_flan(input_text):
20
  # Tokenize the input for the model
21
  input_ids = tokenizer.encode(full_input, return_tensors="pt")
22
 
 
 
 
23
  # Generate the response from the model
24
  outputs = model.generate(input_ids, max_length=200, num_return_sequences=1)
25
 
@@ -38,17 +50,22 @@ with gr.Blocks() as interface:
38
  # Add the instruction message above the input box
39
  gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
40
 
 
 
 
41
  # Input box for the user
42
  user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)
43
 
44
- def update_chat(input_text, chat_history):
45
- updated_history = chat_with_flan(input_text)
 
46
  return updated_history, ""
47
 
48
  # Submit when pressing Shift + Enter (plain Enter inserts a new line in this multi-line textbox)
49
- user_input.submit(update_chat, inputs=[user_input, chatbot_output], outputs=[chatbot_output, user_input])
50
 
 
 
 
51
  # Launch the Gradio app
52
  interface.launch()
53
-
54
-
 
1
  import gradio as gr
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration
3
  from langchain.memory import ConversationBufferMemory
4
+ import torch
5
 
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ # Load all three Flan-T5 models (small, base, large)
9
+ models = {
10
+ "small": T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to(device),
11
+ "base": T5ForConditionalGeneration.from_pretrained("google/flan-t5-base").to(device),
12
+ "large": T5ForConditionalGeneration.from_pretrained("google/flan-t5-large").to(device)
13
+ }
14
+
15
+ # Load the tokenizer (same tokenizer for all models)
16
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
 
17
 
18
  # Set up conversational memory using LangChain's ConversationBufferMemory
19
  memory = ConversationBufferMemory()
20
 
21
+ # Define the chatbot function with memory and model size selection
22
+ def chat_with_flan(input_text, model_size):
23
  # Retrieve conversation history and append the current user input
24
  conversation_history = memory.load_memory_variables({})['history']
25
 
 
29
  # Tokenize the input for the model
30
  input_ids = tokenizer.encode(full_input, return_tensors="pt")
31
 
32
+ # Get the model based on the selected size
33
+ model = models[model_size]
34
+
35
  # Generate the response from the model
36
  outputs = model.generate(input_ids, max_length=200, num_return_sequences=1)
37
 
 
50
  # Add the instruction message above the input box
51
  gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
52
 
53
+ # Add a dropdown for selecting the model size (small, base, large)
54
+ model_selector = gr.Dropdown(choices=["small", "base", "large"], value="base", label="Select Model Size")
55
+
56
  # Input box for the user
57
  user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)
58
 
59
+ # Define the function to update the chat based on selected model
60
+ def update_chat(input_text, model_size):
61
+ updated_history = chat_with_flan(input_text, model_size)
62
  return updated_history, ""
63
 
64
  # Submit when pressing Shift + Enter (plain Enter inserts a new line in this multi-line textbox)
65
+ user_input.submit(update_chat, inputs=[user_input, model_selector], outputs=[chatbot_output, user_input])
66
 
67
+ # Layout for model selector and chatbot UI
68
+ gr.Row([model_selector])
69
+
70
  # Launch the Gradio app
71
  interface.launch()