|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
|
|
class CodeGenerator:
    """Wraps a Hugging Face seq2seq checkpoint for prompt-to-code generation.

    Loads the tokenizer and model once at construction; `generate_code`
    turns a free-text prompt into decoded model output.
    """

    def __init__(self, model_name="Salesforce/codet5-base", device=None):
        """Load tokenizer and model; optionally move the model to *device*.

        Args:
            model_name: Hub identifier of a seq2seq checkpoint.
            device: torch device string (e.g. "cuda"/"cpu") or None to
                leave the model where `from_pretrained` put it.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Remember the device so generate_code can place inputs on it too.
        self.device = device
        if device:
            self.model = self.model.to(device)

    def generate_code(self, prompt, max_length=100):
        """Generate a single decoded sequence for *prompt*.

        Returns the decoded text, or an "Error generating code: ..." string
        if anything raises (kept best-effort so the UI never crashes).
        """
        try:
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
            # Bug fix: inputs must live on the same device as the model,
            # otherwise generation fails with a device-mismatch error on GPU.
            if getattr(self, "device", None):
                input_ids = input_ids.to(self.device)
            # Inference only — skip autograd bookkeeping.
            with torch.no_grad():
                output = self.model.generate(
                    input_ids, max_length=max_length, num_return_sequences=1
                )
            return self.tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            return f"Error generating code: {str(e)}"
|
|
|
class ChatHandler:
    """Keeps the running chat transcript and delegates replies to a generator."""

    def __init__(self, code_generator):
        self.code_generator = code_generator
        self.history = []

    def handle_message(self, message):
        """Answer *message* and record the turn; blank input is ignored.

        Returns a pair: the cleared textbox value ("") and the transcript,
        matching the (input, chatbot) outputs wired up in the UI.
        """
        if message.strip():
            reply = self.code_generator.generate_code(message)
            self.history.append((message, reply))
        return "", self.history

    def clear_history(self):
        """Forget every stored turn and return the empty transcript."""
        self.history = []
        return self.history
|
|
|
def create_gradio_interface():
    """Assemble the Gradio chat UI around a CodeGenerator and launch it."""
    # Prefer the GPU when available; fall back to CPU otherwise.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    handler = ChatHandler(CodeGenerator(device=target_device))

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")

        with gr.Row():
            # Main column: transcript plus input controls.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=400)
                message_input = gr.Textbox(
                    label="Enter your code-related query",
                    placeholder="Type your message here...",
                )
                submit_button = gr.Button("Submit")

            # Sidebar: static feature list and the clear control.
            with gr.Column(scale=1):
                gr.Markdown("## Features")
                for feature in (
                    "Code generation",
                    "Code completion",
                    "Code explanation",
                    "Error correction",
                ):
                    gr.Markdown(f"- {feature}")
                clear_button = gr.Button("Clear Chat")

        # handle_message returns ("", history): clears the box, updates the chat.
        submit_button.click(
            handler.handle_message,
            inputs=message_input,
            outputs=[message_input, chatbot],
        )
        # Clearing resets both the textbox (None) and the transcript.
        clear_button.click(
            lambda: (None, handler.clear_history()),
            inputs=[],
            outputs=[message_input, chatbot],
        )

    demo.launch()
|
|
|
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":

    create_gradio_interface()
|
|