dgjx committed · Commit 08343a9 · verified · 1 Parent(s): fa87edc

Update app.py

Files changed (1): app.py (+25 -6)
app.py CHANGED
@@ -5,7 +5,13 @@ import torch
 # Load the model and tokenizer
 model_name = "defog/llama-3-sqlcoder-8b"  # use an updated model to improve performance
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")  # reduce memory usage
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    trust_remote_code=True,
+    torch_dtype=torch.float16,
+    device_map="auto",
+    use_cache=True,
+)
 
 def generate_sql(user_question, instructions, create_table_statements):
     prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
@@ -20,11 +26,24 @@ The following SQL query best answers the question `{user_question}`:
 ```sql
 """
 
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(**inputs, max_length=150)
-    sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    return sql_query
+    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+    generated_ids = model.generate(
+        **inputs,
+        num_return_sequences=1,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.eos_token_id,
+        max_new_tokens=400,
+        do_sample=False,
+        num_beams=1,
+    )
+    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    # empty the cache so that you can generate more results w/o memory crashing
+    # (particularly important on Colab; memory management is much more
+    # straightforward when running on an inference service)
+    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)
 
 
 question = f"What are our top 3 products by revenue in the New York region?"
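Note: the new return line calls sqlparse.format and splits on an "[SQL]" marker, but neither the sqlparse import nor anything emitting "[SQL]" appears in the hunks above, so both presumably live elsewhere in app.py. For reference, here is a minimal self-contained sketch of the updated inference path. The model id, generation settings, and example question come from the diff; the trimmed prompt, the split on the prompt's ```sql marker instead of "[SQL]", and .to(model.device) instead of .to("cuda") are this sketch's own assumptions (it also leaves trust_remote_code and use_cache at their defaults):

import sqlparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "defog/llama-3-sqlcoder-8b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # half-precision weights to cut memory use
    device_map="auto",          # requires the accelerate package
)

def generate_sql(user_question):
    # Trimmed version of the diff's prompt (instructions and schema omitted).
    prompt = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"The following SQL query best answers the question `{user_question}`:\n"
        "```sql\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=400,                   # caps new tokens only, unlike the old max_length=150
        do_sample=False,                      # greedy decoding for deterministic SQL
        num_beams=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad-token warning
    )
    text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    torch.cuda.empty_cache()                  # release cached GPU blocks between requests
    # This trimmed prompt ends with a ```sql fence, so split on that marker
    # (the diff splits on "[SQL]", which its own prompt never emits).
    return sqlparse.format(text.split("```sql")[-1], reindent=True)

print(generate_sql("What are our top 3 products by revenue in the New York region?"))

One design note from the diff itself: max_new_tokens=400 bounds only the newly generated tokens, whereas the old max_length=150 counted the prompt as well, so a long schema could previously leave little or no room for the generated query.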