"""Gradio chat interface that generates code with a Salesforce CodeT5 seq2seq model."""

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch  # noqa: F401  (kept from the original file; not referenced directly)


class CodeGenerator:
    """Generate text/code from a prompt with a pre-trained sequence-to-sequence model."""

    def __init__(self, model_name="Salesforce/codet5-base"):
        """Load the tokenizer and model named *model_name* via ``from_pretrained``.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
                Defaults to ``"Salesforce/codet5-base"``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def generate_code(self, prompt, max_length=100):
        """Return the model's decoded output for *prompt*.

        Args:
            prompt: The input text.
            max_length: Maximum length (in tokens) of the generated sequence.

        Returns:
            The first generated sequence, decoded with special tokens removed.
        """
        # Calling the tokenizer (rather than ``encode``) also produces an
        # attention mask for ``generate``. ``truncation=True`` clips prompts to
        # the tokenizer's configured maximum length instead of feeding the
        # model an over-long input.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
        output = self.model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
        )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)


class ChatHandler:
    """Keep the chat history as a list of (message, response) pairs."""

    def __init__(self):
        """Start with an empty history."""
        self.history = []

    def handle_message(self, message, code_generator, history=None):
        """Generate a response to *message* and append the pair to the history.

        Args:
            message: The user's input text.
            code_generator: Object exposing ``generate_code(prompt)``, e.g. a
                ``CodeGenerator``.
            history: Optional conversation currently displayed in the UI. When
                given, it replaces the stored history first, so the stored and
                displayed conversations cannot drift apart (e.g. after a page
                reload).

        Returns:
            ``("", history)``: an empty string that clears the input box, and
            the updated history list.
        """
        if history is not None:
            self.history = list(history)
        # Nothing to answer: skip the (slow) model call.
        if not message or not message.strip():
            return "", self.history
        response = code_generator.generate_code(message)
        self.history.append((message, response))
        return "", self.history

    def clear(self):
        """Discard the stored history and return the new, empty history."""
        self.history = []
        return self.history


def create_gradio_interface():
    """Build the chat UI around a ``CodeGenerator`` and launch it with ``demo.launch()``."""
    code_generator = CodeGenerator()
    # NOTE(review): one ChatHandler serves every browser session of this
    # process; per-session state (e.g. ``gr.State``) would isolate users.
    chat_handler = ChatHandler()

    def on_submit(message, history):
        # Gradio passes only the values of the components listed in `inputs`,
        # so the generator is bound here instead of expected as an argument.
        return chat_handler.handle_message(message, code_generator, history)

    def on_clear():
        # Reset the stored history as well as the widget; clearing only the
        # widget would let the old conversation return on the next submit.
        return chat_handler.clear()

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")

        with gr.Row():
            # Left column: the conversation and the input controls.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=400)
                message_input = gr.Textbox(
                    label="Enter your code-related query",
                    placeholder="Type your message here...",
                )
                submit_button = gr.Button("Submit")

            # Right column: feature list and the clear button.
            with gr.Column(scale=1):
                gr.Markdown("## Features")
                features = ["Code generation", "Code completion", "Code explanation", "Error correction"]
                for feature in features:
                    gr.Markdown(f"- {feature}")
                clear_button = gr.Button("Clear Chat")

        # The displayed conversation is sent along so the handler stays in sync.
        submit_button.click(on_submit, inputs=[message_input, chatbot], outputs=[message_input, chatbot])
        # Pressing Enter in the textbox behaves like clicking Submit.
        message_input.submit(on_submit, inputs=[message_input, chatbot], outputs=[message_input, chatbot])
        clear_button.click(on_clear, inputs=[], outputs=[chatbot])

    demo.launch()


if __name__ == "__main__":
    create_gradio_interface()