File size: 4,847 Bytes
0a023ed
1d569c6
d5be079
 
1d569c6
0a023ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5be079
0a023ed
 
 
 
 
 
 
 
 
 
 
d5be079
0a023ed
 
 
 
d5be079
0a023ed
 
d5be079
0a023ed
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

class CodeGenerator:
    """Generate code from a prompt using a pre-trained seq2seq model.

    Wraps a Hugging Face tokenizer/model pair loaded from ``model_name``
    (default ``"Salesforce/codet5-base"``).
    """

    def __init__(self, model_name="Salesforce/codet5-base"):
        """Load the tokenizer and sequence-to-sequence model.

        Args:
            model_name: Model identifier or local path passed to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def generate_code(self, prompt, max_length=100):
        """Generate text for ``prompt`` and return it decoded.

        Args:
            prompt: Input text to condition generation on.
            max_length: Maximum length of the generated sequence, forwarded
                to ``model.generate``.

        Returns:
            The first generated sequence decoded to a string, with special
            tokens removed.
        """
        # Call the tokenizer directly (instead of ``encode``) so the attention
        # mask is produced and passed to ``generate``. ``truncation=True``
        # caps the prompt at the tokenizer's configured maximum input length.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            output = self.model.generate(
                **inputs, max_length=max_length, num_return_sequences=1
            )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)

class ChatHandler:
    """Keep a running chat history and produce replies via a code generator.

    A single instance holds one ``history`` list, so every caller sharing the
    instance shares the same conversation.
    """

    def __init__(self):
        # (user_message, model_response) pairs, oldest first.
        self.history = []

    def handle_message(self, message, code_generator):
        """Generate a reply to ``message`` and record the exchange.

        Args:
            message: The user's input text.
            code_generator: Object exposing ``generate_code(prompt)`` (e.g. a
                ``CodeGenerator``) used to produce the reply.

        Returns:
            A ``("", history)`` tuple: an empty string to reset the input box
            and the current history list.
        """
        # Ignore blank submissions: generating for an empty prompt wastes a
        # model call and adds an empty turn to the transcript.
        if not message or not message.strip():
            return "", self.history
        response = code_generator.generate_code(message)
        self.history.append((message, response))
        return "", self.history

    def clear(self):
        """Forget all recorded exchanges and return the now-empty history."""
        self.history.clear()
        return self.history

def create_gradio_interface():
    """Build the code-generation chat UI and launch it.

    Creates a ``CodeGenerator`` (default model) and a ``ChatHandler``, lays
    out a chat column and a features sidebar, wires the controls to the
    handler, and calls ``demo.launch()``.
    """
    code_generator = CodeGenerator()
    # NOTE(review): this handler (and its history) is shared by every browser
    # session; per-user conversations would need gr.State instead.
    chat_handler = ChatHandler()

    def respond(message):
        # Gradio passes one argument per component in ``inputs``; bind the
        # generator here because ``handle_message`` also requires it.
        return chat_handler.handle_message(message, code_generator)

    def clear_chat():
        # Reset the server-side history as well; otherwise the next message
        # would bring the "cleared" conversation back into the chatbot.
        chat_handler.history.clear()
        return []

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")

        with gr.Row():
            # Left column: conversation view and input controls.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=400)
                message_input = gr.Textbox(label="Enter your code-related query", placeholder="Type your message here...")
                submit_button = gr.Button("Submit")

            # Right column: static feature list and the clear button.
            with gr.Column(scale=1):
                gr.Markdown("## Features")
                features = ["Code generation", "Code completion", "Code explanation", "Error correction"]
                for feature in features:
                    gr.Markdown(f"- {feature}")
                clear_button = gr.Button("Clear Chat")

        # Both the Submit button and pressing Enter in the textbox send the
        # message; the handler returns "" to empty the box plus the history.
        submit_button.click(respond, inputs=[message_input], outputs=[message_input, chatbot])
        message_input.submit(respond, inputs=[message_input], outputs=[message_input, chatbot])
        clear_button.click(clear_chat, inputs=[], outputs=[chatbot])

    demo.launch()

# Script entry point: build and launch the chat UI only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    create_gradio_interface()