mervp committed on
Commit
e9114a0
·
verified ·
1 Parent(s): db7a8da

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +58 -22
README.md CHANGED
@@ -53,30 +53,66 @@ Always validate and test generated queries before execution in a production data
53
  ## How to Get Started with the Model
54
 
55
  ```python
56
- from transformers import AutoTokenizer, AutoModelForCausalLM
57
- from peft import PeftModel
 
 
 
 
 
58
 
59
- base_model = AutoModelForCausalLM.from_pretrained(
60
- "unsloth/llama-3.2-3b-unsloth-bnb-4bit",
61
- device_map="auto",
62
- trust_remote_code=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
 
65
- tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3.2-3b-unsloth-bnb-4bit")
66
- model = PeftModel.from_pretrained(base_model, "mervp/SQLGenie")
67
 
68
- prompt = "List the customers from Canada."
69
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
70
- outputs = model.generate(**inputs, max_new_tokens=100)
71
- print(tokenizer.decode(outputs[0], skip_special_tokens=True))
72
 
73
- #
74
- OR
75
- #
76
- from unsloth import FastLanguageModel
77
- model, tokenizer = FastLanguageModel.from_pretrained(
78
- model_name="mervp/SQLGenie",
79
- max_seq_length=2048,
80
- dtype=None,
81
- # load_in_4bit=True,
82
- )
 
53
  ## How to Get Started with the Model
54
 
55
  ```python
56
+ from unsloth import FastLanguageModel
57
+
58
+ model, tokenizer = FastLanguageModel.from_pretrained(
59
+ model_name="mervp/SQLGenie",
60
+ max_seq_length=2048,
61
+ dtype=None,
62
+ )
63
 
64
+ prompt = """ You are an text to SQL query translator.
65
+ Users will ask you questions in English
66
+ and you will generate a SQL query based on their question
67
+ SQL has to be simple, The schema context has been provided to you.
68
+
69
+
70
+ ### User Question:
71
+ {}
72
+
73
+ ### Sql Context:
74
+ {}
75
+
76
+ ### Sql Query:
77
+ {}
78
+ """
79
+
80
+ question = "List the names of customers who have an account balance greater than 6000."
81
+ schema = """
82
+ CREATE TABLE socially_responsible_lending (
83
+ customer_id INT,
84
+ name VARCHAR(50),
85
+ account_balance DECIMAL(10, 2)
86
+ );
87
+
88
+ INSERT INTO socially_responsible_lending VALUES
89
+ (1, 'james Chad', 5000),
90
+ (2, 'Jane Rajesh', 7000),
91
+ (3, 'Alia Kapoor', 6000),
92
+ (4, 'Fatima Patil', 8000);
93
+ """
94
+
95
+ inputs = tokenizer(
96
+ [prompt.format(question, schema, "")],
97
+ return_tensors="pt",
98
+ padding=True,
99
+ truncation=True
100
+ ).to("cuda")
101
+
102
+ output = model.generate(
103
+ **inputs,
104
+ max_new_tokens=256,
105
+ temperature=0.2,
106
+ top_p=0.9,
107
+ top_k=50,
108
+ do_sample=True
109
  )
110
 
111
+ decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
 
112
 
113
+ if "### Sql Query:" in decoded_output:
114
+ sql_query = decoded_output.split("### Sql Query:")[-1].strip()
115
+ else:
116
+ sql_query = decoded_output.strip()
117
 
118
+ print(sql_query)