File size: 1,386 Bytes
3227532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b496cf
 
3227532
9b496cf
 
 
 
 
3227532
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import gradio as gr

# The tokenizer is loaded from the base 't5-small' repo rather than from the
# fine-tuned checkpoint below; this assumes the fine-tuned model kept the
# t5-small vocabulary (TODO confirm against the checkpoint's model card).
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Load the text-to-SQL checkpoint, place it on the selected device, and switch
# it to evaluation mode (disables dropout) since it is only used for inference.
model = T5ForConditionalGeneration.from_pretrained(
    'cssupport/t5-small-awesome-text-to-sql'
).to(device)
model.eval()

def generate_sql(input_prompt, *, max_length=512):
    """Generate a SQL string from a text-to-SQL prompt using the loaded model.

    Args:
        input_prompt: Prompt text. In this app it is built by
            ``gradio_interface`` as ``"tables:\\n<tables>\\nquery for:<query>"``.
        max_length: Maximum length, in tokens, of the generated sequence
            (forwarded to ``model.generate``). Defaults to 512, the value that
            was previously hard-coded.

    Returns:
        The first generated sequence decoded to text, with special tokens
        (padding / end-of-sequence) stripped.
    """
    # truncation=True keeps over-long prompts within the tokenizer's maximum
    # input length instead of failing; the encoded batch is moved to the
    # same device as the model.
    inputs = tokenizer(
        input_prompt, padding=True, truncation=True, return_tensors="pt"
    ).to(device)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)
    # A single prompt is encoded, so only the first output sequence is used.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def gradio_interface(tables, query):
    """Gradio callback: build the model prompt from the two inputs and run it.

    Args:
        tables: Text from the "Context Tables" box (table definitions).
        query: Text from the "Query Description" box.

    Returns:
        The SQL text returned by ``generate_sql`` for the assembled prompt.
    """
    schema_section = f"tables:\n{tables}"
    question_section = f"query for:{query}"
    # Sections are newline-separated, with no extra whitespace after "query for:".
    return generate_sql(f"{schema_section}\n{question_section}")

# Web UI: two text inputs (table definitions + a plain-language description)
# are passed to gradio_interface, and the generated SQL is shown in one
# output textbox.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=5, label="Context Tables", placeholder="Enter table definitions here..."),
        # This field takes a natural-language description of the query, not
        # SQL, so the placeholder must not ask the user for SQL.
        gr.Textbox(lines=2, label="Query Description", placeholder="Describe the query you want in plain English..."),
    ],
    outputs=gr.Textbox(label="Generated SQL Query", placeholder=""),
    title="Text to SQL Generator",
    examples=[
        ["CREATE TABLE student_course_attendance (student_id VARCHAR); CREATE TABLE students (student_id VARCHAR);", "List the id of students who never attends courses?"]
    ],
)

# Start the web server only when this file is executed as a script, so that
# importing the module does not launch the UI as a side effect.
if __name__ == "__main__":
    iface.launch()