kdevoe committed
Commit 07b00c0 · verified · 1 parent: 0e4ab50

Converting model to DistilGPT2

Files changed (1): app.py (+20, -15)
app.py CHANGED
@@ -1,28 +1,32 @@
 import gradio as gr
-from transformers import T5Tokenizer, T5ForConditionalGeneration
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+import torch
 from langchain.memory import ConversationBufferMemory
-from langchain.prompts import PromptTemplate
 
-# Load the tokenizer and model for flan-t5
-tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
-model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
+# Load the tokenizer and model for DistilGPT-2
+tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
+model = GPT2LMHeadModel.from_pretrained("distilgpt2")
+
+# Move model to device (GPU if available)
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+model.to(device)
 
 # Set up conversational memory using LangChain's ConversationBufferMemory
 memory = ConversationBufferMemory()
 
 # Define the chatbot function with memory
-def chat_with_flan(input_text):
+def chat_with_distilgpt2(input_text):
     # Retrieve conversation history and append the current user input
     conversation_history = memory.load_memory_variables({})['history']
 
     # Combine the history with the current user input
     full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
 
-    # Tokenize the input for the model
-    input_ids = tokenizer.encode(full_input, return_tensors="pt")
+    # Tokenize the input and convert to tensor
+    input_ids = tokenizer.encode(full_input, return_tensors="pt").to(device)
 
-    # Generate the response from the model
-    outputs = model.generate(input_ids, max_length=200, num_return_sequences=1)
+    # Generate the response using the model
+    outputs = model.generate(input_ids, max_length=200, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
 
     # Decode the model output
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -34,13 +38,14 @@ def chat_with_flan(input_text):
 
 # Set up the Gradio interface
 interface = gr.Interface(
-    fn=chat_with_flan,
-    inputs=gr.Textbox(label="Chat with FLAN-T5"),
-    outputs=gr.Textbox(label="FLAN-T5's Response"),
-    title="FLAN-T5 Chatbot with Memory",
-    description="This is a simple chatbot powered by the FLAN-T5 model with conversational memory, using LangChain.",
+    fn=chat_with_distilgpt2,
+    inputs=gr.Textbox(label="Chat with DistilGPT-2"),
+    outputs=gr.Textbox(label="DistilGPT-2's Response"),
+    title="DistilGPT-2 Chatbot with Memory",
+    description="This is a simple chatbot powered by the DistilGPT-2 model with conversational memory, using LangChain.",
 )
 
 # Launch the Gradio app
 interface.launch()
+
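
For anyone wanting to sanity-check the conversion outside the Gradio app, here is a minimal smoke-test sketch of the new model path. It reuses the same calls as the committed code; the prompt string is illustrative only.

# Smoke test for the DistilGPT-2 path (assumes transformers and torch are
# installed; the model weights download on first run).
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Illustrative prompt in the same User/Assistant format the app builds.
prompt = "User: Hello, who are you?\nAssistant:"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

# Same generate() call as in the committed app.py; pad_token_id is set to
# eos_token_id because GPT-2 tokenizers define no pad token of their own.
outputs = model.generate(input_ids, max_length=200, num_return_sequences=1,
                         pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

One caveat worth noting: max_length=200 bounds prompt plus generated tokens, so as the LangChain conversation history grows, the budget left for new tokens shrinks; max_new_tokens is the usual alternative when only the completion length should be capped.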