import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


# Define a class named `CodeGenerator` that will be responsible for generating
# code based on a given prompt.
class CodeGenerator:
    # The constructor initializes the CodeGenerator object with a pre-trained model name.
    # The default model name is "Salesforce/codet5-base".
    def __init__(self, model_name="Salesforce/codet5-base"):
        # Load the pre-trained tokenizer from the specified model name.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Load the pre-trained sequence-to-sequence language model from the specified model name.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
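        # Optional, and only an assumption (nothing else in this app relies on it):
        # with a GPU available, the model could be moved there, e.g.
        #     self.model.to("cuda" if torch.cuda.is_available() else "cpu")
        # `generate_code` would then also need to move its input IDs to the same device.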

    # This method generates code based on the given prompt.
    # It takes two parameters: `prompt` (the input text) and `max_length`
    # (the maximum length, in tokens, of the generated sequence).
    def generate_code(self, prompt, max_length=100):
        # Encode the prompt into input IDs that the model can understand.
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        # Generate the output sequence with the pre-trained model. Gradients are not
        # needed for inference, so generation runs under `torch.no_grad()`.
        # `generate` takes the input IDs, the maximum output length, and the number
        # of output sequences to return (in this case, 1).
        with torch.no_grad():
            output = self.model.generate(input_ids, max_length=max_length, num_return_sequences=1)
        # Decode the output sequence and return the generated code.
        return self.tokenizer.decode(output[0], skip_special_tokens=True)
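
# A minimal standalone sketch (an illustration, not part of the Gradio app below):
# assuming the checkpoint downloads successfully, the generator can be exercised
# directly from a Python shell, e.g.
#
#     generator = CodeGenerator()
#     print(generator.generate_code("def add(a, b):"))
#
# The base CodeT5 checkpoint is a general pre-trained model, so the output for an
# arbitrary prompt is illustrative rather than guaranteed to be useful code.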


# Define a class named `ChatHandler` that will be responsible for managing the chat history.
class ChatHandler:
    # The constructor initializes the ChatHandler object with an empty chat history.
    def __init__(self):
        self.history = []

    # This method handles an incoming message and generates a response using the provided CodeGenerator.
    # It takes two parameters: `message` (the user's input message) and
    # `code_generator` (an instance of the CodeGenerator class).
    def handle_message(self, message, code_generator):
        # Generate the response using the provided CodeGenerator.
        response = code_generator.generate_code(message)
        # Append the message-response pair to the chat history.
        self.history.append((message, response))
        # Return an empty string to reset the message input, plus the updated chat history.
        return "", self.history

    # Clear the stored chat history and return an empty list so the chatbot display resets too.
    def clear(self):
        self.history = []
        return []


# Define a function named `create_gradio_interface` that creates a Gradio interface
# for the chat application.
def create_gradio_interface():
    # Create an instance of the CodeGenerator class.
    code_generator = CodeGenerator()
    # Create an instance of the ChatHandler class.
    chat_handler = ChatHandler()

    # Create a Gradio Blocks interface with a soft theme.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # Display a Markdown title for the chat interface.
        gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")
        # Create a row with two columns.
        with gr.Row():
            # The first column will contain the chat interface.
            with gr.Column(scale=3):
                # Create a chatbot component to display the chat history.
                chatbot = gr.Chatbot(height=400)
                # Create a textbox for the user to input their message.
                message_input = gr.Textbox(label="Enter your code-related query", placeholder="Type your message here...")
                # Create a submit button to send the message.
                submit_button = gr.Button("Submit")
            # The second column will contain the features.
            with gr.Column(scale=1):
                # Display a Markdown title for the features section.
                gr.Markdown("## Features")
                # Define a list of features.
                features = ["Code generation", "Code completion", "Code explanation", "Error correction"]
                # Display each feature as a Markdown list item.
                for feature in features:
                    gr.Markdown(f"- {feature}")
                # Create a button to clear the chat history.
                clear_button = gr.Button("Clear Chat")

        # Connect the submit button to the `handle_message` method of the ChatHandler.
        # A lambda supplies the CodeGenerator instance that `handle_message` expects
        # alongside the user's message.
        submit_button.click(lambda message: chat_handler.handle_message(message, code_generator), inputs=[message_input], outputs=[message_input, chatbot])
        # Connect the clear button to the ChatHandler method that resets both the
        # stored history and the chatbot display.
        clear_button.click(chat_handler.clear, inputs=[], outputs=[chatbot])

    # Launch the Gradio interface.
    demo.launch()


# This is the entry point of the application.
if __name__ == "__main__":
    # Call the `create_gradio_interface` function to start the chat application.
    create_gradio_interface()