File size: 4,847 Bytes
0a023ed 1d569c6 d5be079 1d569c6 0a023ed d5be079 0a023ed d5be079 0a023ed d5be079 0a023ed d5be079 0a023ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
class CodeGenerator:
    """Generate code from a text prompt using a pre-trained seq2seq model.

    The tokenizer and model are loaded once, when the object is constructed,
    via Hugging Face ``from_pretrained``.

    Args:
        model_name: Hub id or local path of the checkpoint to load.
            Defaults to ``"Salesforce/codet5-base"``.
        device: Optional torch device (e.g. ``"cuda"`` or
            ``torch.device("cuda:0")``) to move the model to after loading.
            When ``None`` (the default) the model is left where
            ``from_pretrained`` placed it.
    """

    def __init__(self, model_name="Salesforce/codet5-base", device=None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if device is not None:
            self.model.to(device)

    def generate_code(self, prompt, max_length=100):
        """Return the model's generated text for *prompt*.

        Args:
            prompt: Input text to condition generation on.
            max_length: Maximum length, in tokens, of the generated sequence.

        Returns:
            The single generated sequence decoded to a string, with special
            tokens removed.
        """
        # Calling the tokenizer (rather than .encode) also yields the
        # attention mask, which is handed to generate() explicitly instead of
        # leaving the model to infer it.
        encoded = self.tokenizer(prompt, return_tensors="pt")
        # Inputs have to be on the same device as the model weights.
        target_device = self.model.device
        input_ids = encoded["input_ids"].to(target_device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(target_device)
        output = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=1,
        )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)
class ChatHandler:
    """Keep a chat transcript and produce replies with a code generator.

    ``history`` is a list of ``(user_message, generated_reply)`` tuples in the
    order the exchanges happened.

    Args:
        code_generator: Optional default generator (any object with a
            ``generate_code(prompt)`` method) used when
            :meth:`handle_message` is called without one.
    """

    def __init__(self, code_generator=None):
        self.history = []
        self.code_generator = code_generator

    def handle_message(self, message, code_generator=None):
        """Generate a reply to *message* and record the exchange.

        Empty or whitespace-only messages are ignored. Nothing is generated and
        the history is returned unchanged.

        Args:
            message: The user's input text.
            code_generator: Object with a ``generate_code(prompt)`` method.
                Falls back to the generator given to the constructor.

        Returns:
            ``("", history)``: an empty string (to clear the input box) and
            the updated history list.

        Raises:
            ValueError: If no generator was passed here or to the constructor.
        """
        # Skip blank submissions so the model is not run on an empty prompt
        # and no empty exchange lands in the transcript.
        if not message or not message.strip():
            return "", self.history
        generator = code_generator if code_generator is not None else self.code_generator
        if generator is None:
            raise ValueError("handle_message() needs a code_generator")
        response = generator.generate_code(message)
        self.history.append((message, response))
        return "", self.history

    def clear(self):
        """Remove every recorded exchange and return the now-empty history."""
        self.history.clear()
        return self.history
def create_gradio_interface():
    """Build the Gradio chat UI around a CodeGenerator and launch it.

    Loads the default CodeGenerator model and lays out two columns: a chat
    column (chatbot, input box, submit button) and a static feature list with
    a clear button. It then wires the events and calls ``demo.launch()``.
    """
    code_generator = CodeGenerator()
    # NOTE(review): this single ChatHandler is shared by every browser session,
    # so all visitors see and extend the same history. Per-session history
    # would need gr.State; left as-is to keep the current design.
    chat_handler = ChatHandler()

    def respond(message):
        # Gradio passes only the values of the components listed in `inputs`,
        # so the generator is bound here instead of being expected as a second
        # event argument (which previously raised a TypeError on every submit).
        return chat_handler.handle_message(message, code_generator)

    def clear_chat():
        # Reset the server-side history as well as the display; otherwise the
        # next reply would bring back every exchange the user just cleared.
        chat_handler.history.clear()
        return []

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")
        with gr.Row():
            # Left column: the conversation itself.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=400)
                message_input = gr.Textbox(
                    label="Enter your code-related query",
                    placeholder="Type your message here...",
                )
                submit_button = gr.Button("Submit")
            # Right column: static feature list and the clear control.
            with gr.Column(scale=1):
                gr.Markdown("## Features")
                features = ["Code generation", "Code completion", "Code explanation", "Error correction"]
                for feature in features:
                    gr.Markdown(f"- {feature}")
                clear_button = gr.Button("Clear Chat")

        # Both the button and pressing Enter in the textbox send the message.
        submit_button.click(respond, inputs=[message_input], outputs=[message_input, chatbot])
        message_input.submit(respond, inputs=[message_input], outputs=[message_input, chatbot])
        clear_button.click(clear_chat, inputs=[], outputs=[chatbot])

    demo.launch()
if __name__ == "__main__":
    # Executed only when the file is run directly (not on import):
    # build the chat UI and launch it.
    create_gradio_interface()
|