File size: 4,183 Bytes
1d569c6
d5be079
 
1d569c6
d5be079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d569c6
d5be079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Checkpoint shared by every request. "Salesforce/codet5-base" is a pre-trained
# sequence-to-sequence model for code-related tasks; its tokenizer and model
# are loaded once, when this module is first executed, through the
# Transformers Auto* factory classes.
model_name = "Salesforce/codet5-base"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

def generate_code(prompt, max_length=100):
    """Generate text for *prompt* using the module-level CodeT5 model.

    Args:
        prompt: Input text (a code-related query) fed to the model.
        max_length: Maximum length, in tokens, of the generated sequence.

    Returns:
        The decoded text of the single generated sequence, with special
        tokens (e.g. start/end-of-sequence markers) removed.
    """
    # Calling the tokenizer directly (instead of ``encode``) also returns the
    # attention mask that ``generate`` expects. ``truncation=True`` keeps
    # over-long prompts within the tokenizer's configured maximum length.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Place the tensors on the model's device so this keeps working if the
    # model is ever moved off the CPU.
    inputs = inputs.to(model.device)

    # ``num_return_sequences=1``: only one candidate is decoded below.
    output = model.generate(**inputs, max_length=max_length, num_return_sequences=1)

    return tokenizer.decode(output[0], skip_special_tokens=True)

def chat_interaction(message, history):
    """Handle one chat turn: generate a reply and append it to the history.

    Args:
        message: Text the user submitted.
        history: Existing chat history as a list of ``(user, bot)`` pairs,
            or ``None``/empty for a fresh conversation.

    Returns:
        A tuple ``("", new_history)``. The empty string clears the input box,
        and ``new_history`` is a new list containing the prior turns plus
        this one. Blank messages are ignored and leave the history unchanged.
    """
    # Copy so the caller's list is never mutated in place.
    history = list(history or [])

    # A blank submission would send an empty prompt to the model and add a
    # meaningless turn to the transcript, so skip it.
    if not message or not message.strip():
        return "", history

    response = generate_code(message)
    history.append((message, response))
    return "", history

# Build the Gradio interface: a chat area on the left and a feature list with
# a reset button on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")

    with gr.Row():
        # Left column: transcript, input box and submit button.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=400)
            message = gr.Textbox(label="Enter your code-related query", placeholder="Type your message here...")
            submit_button = gr.Button("Submit")

        # Right column: static feature list and the clear-chat button.
        with gr.Column(scale=1):
            gr.Markdown("## Features")
            gr.Markdown("- Code generation")
            gr.Markdown("- Code completion")
            gr.Markdown("- Code explanation")
            gr.Markdown("- Error correction")

            clear_button = gr.Button("Clear Chat")

    # Clicking "Submit" and pressing Enter in the textbox both send the
    # message. chat_interaction returns ("", history), so the box is cleared
    # and the transcript is updated.
    submit_button.click(chat_interaction, inputs=[message, chatbot], outputs=[message, chatbot])
    message.submit(chat_interaction, inputs=[message, chatbot], outputs=[message, chatbot])

    # Returning None for the Chatbot output clears the chat.
    clear_button.click(lambda: None, outputs=[chatbot], inputs=[])

# Launch the server only when this file is run as a script, so importing the
# module does not start it as a side effect.
if __name__ == "__main__":
    demo.launch()