cyyeh commited on
Commit
bfd6209
·
1 Parent(s): f0a9867
Files changed (11) hide show
  1. .dockerignore +1 -0
  2. .env.example +2 -0
  3. .gitignore +2 -1
  4. Dockerfile +0 -1
  5. Makefile +1 -1
  6. poetry.lock +15 -1
  7. pyproject.toml +3 -0
  8. requirements.txt +1 -1
  9. src/__init__.py +0 -0
  10. src/apis.py +399 -5
  11. src/app.py +8 -1
.dockerignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
.env.example ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ WRENAI_CLOUD_API_KEY=
2
+ WRENAI_CLOUD_PROJECT_ID=
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .DS_Store
2
- __pycache__
 
 
1
  .DS_Store
2
+ __pycache__
3
+ .env
Dockerfile CHANGED
@@ -18,7 +18,6 @@ ENV PATH $HOME/.local/bin:$PATH
18
  WORKDIR $HOME
19
  RUN mkdir app
20
  WORKDIR $HOME/app
21
- COPY . $HOME/app
22
 
23
  RUN mkdir -p $HOME/app/.streamlit
24
  COPY .streamlit/config.toml $HOME/app/.streamlit/
 
18
  WORKDIR $HOME
19
  RUN mkdir app
20
  WORKDIR $HOME/app
 
21
 
22
  RUN mkdir -p $HOME/app/.streamlit
23
  COPY .streamlit/config.toml $HOME/app/.streamlit/
Makefile CHANGED
@@ -2,4 +2,4 @@ run:
2
  poetry run streamlit run src/app.py
3
 
4
  deps:
5
- poetry export --without-hashes --format=requirements.txt > requirements.txt
 
2
  poetry run streamlit run src/app.py
3
 
4
  deps:
5
+ poetry export --without-hashes --without=dev --format=requirements.txt > requirements.txt
poetry.lock CHANGED
@@ -764,6 +764,20 @@ files = [
764
  [package.dependencies]
765
  six = ">=1.5"
766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
767
  [[package]]
768
  name = "pytz"
769
  version = "2025.2"
@@ -1125,4 +1139,4 @@ watchmedo = ["PyYAML (>=3.10)"]
1125
  [metadata]
1126
  lock-version = "2.0"
1127
  python-versions = ">=3.12,<3.13"
1128
- content-hash = "0ea34f1300b1f1a7efcf97c0a55dd9d2b2d0b989ac9f8f94d808c57e1bac1af0"
 
764
  [package.dependencies]
765
  six = ">=1.5"
766
 
767
+ [[package]]
768
+ name = "python-dotenv"
769
+ version = "1.1.1"
770
+ description = "Read key-value pairs from a .env file and set them as environment variables"
771
+ optional = false
772
+ python-versions = ">=3.9"
773
+ files = [
774
+ {file = "python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc"},
775
+ {file = "python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab"},
776
+ ]
777
+
778
+ [package.extras]
779
+ cli = ["click (>=5.0)"]
780
+
781
  [[package]]
782
  name = "pytz"
783
  version = "2025.2"
 
1139
  [metadata]
1140
  lock-version = "2.0"
1141
  python-versions = ">=3.12,<3.13"
1142
+ content-hash = "d3018cb9ea02785fe38a30f0f0e1196ad02058106619e7ba3a2a11c4a78753a2"
pyproject.toml CHANGED
@@ -11,7 +11,10 @@ package-mode = false
11
  python = ">=3.12,<3.13"
12
  streamlit = "^1.46.1"
13
  requests = "^2.32.4"
 
14
 
 
 
15
 
16
  [build-system]
17
  requires = ["poetry-core"]
 
11
  python = ">=3.12,<3.13"
12
  streamlit = "^1.46.1"
13
  requests = "^2.32.4"
14
+ watchdog = "^6.0.0"
15
 
16
+ [tool.poetry.group.dev.dependencies]
17
+ python-dotenv = "^1.1.1"
18
 
19
  [build-system]
20
  requires = ["poetry-core"]
requirements.txt CHANGED
@@ -35,4 +35,4 @@ tornado==6.5.1 ; python_version >= "3.12" and python_version < "3.13"
35
  typing-extensions==4.14.0 ; python_version >= "3.12" and python_version < "3.13"
36
  tzdata==2025.2 ; python_version >= "3.12" and python_version < "3.13"
37
  urllib3==2.5.0 ; python_version >= "3.12" and python_version < "3.13"
38
- watchdog==6.0.0 ; python_version >= "3.12" and python_version < "3.13" and platform_system != "Darwin"
 
35
  typing-extensions==4.14.0 ; python_version >= "3.12" and python_version < "3.13"
36
  tzdata==2025.2 ; python_version >= "3.12" and python_version < "3.13"
37
  urllib3==2.5.0 ; python_version >= "3.12" and python_version < "3.13"
38
+ watchdog==6.0.0 ; python_version >= "3.12" and python_version < "3.13"
src/__init__.py ADDED
File without changes
src/apis.py CHANGED
@@ -1,11 +1,47 @@
 
 
1
  import requests
2
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def generate_sql(
5
  api_key: str,
6
  project_id: str,
7
- query: str,
8
  thread_id: str = "",
 
 
9
  ) -> tuple[dict, str]:
10
  """Generate SQL from natural language query."""
11
  base_url = "https://cloud.getwren.ai/api/v1"
@@ -16,7 +52,39 @@ def generate_sql(
16
  }
17
  payload = {
18
  "projectId": project_id,
19
- "question": query,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  }
21
  if thread_id:
22
  payload["threadId"] = thread_id
@@ -26,7 +94,60 @@ def generate_sql(
26
  response.raise_for_status()
27
  return response.json(), ""
28
  except requests.exceptions.RequestException as e:
29
- return {}, e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  def generate_chart(
@@ -34,8 +155,8 @@ def generate_chart(
34
  project_id: str,
35
  question: str,
36
  sql: str,
 
37
  sample_size: int = 1000,
38
- thread_id: str = ""
39
  ) -> tuple[dict, str]:
40
  """Generate a chart from query results."""
41
  base_url = "https://cloud.getwren.ai/api/v1"
@@ -58,4 +179,277 @@ def generate_chart(
58
  response.raise_for_status()
59
  return response.json(), ""
60
  except requests.exceptions.RequestException as e:
61
- return {}, e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
  import requests
4
 
5
 
6
+ def ask(
7
+ api_key: str,
8
+ project_id: str,
9
+ question: str,
10
+ thread_id: str = "",
11
+ language: str = "English",
12
+ sample_size: int = 1000,
13
+ ) -> tuple[dict, str]:
14
+ """Ask a question and get an answer with SQL and explanation."""
15
+ base_url = "https://cloud.getwren.ai/api/v1"
16
+ endpoint = f"{base_url}/ask"
17
+ headers = {
18
+ "Authorization": f"Bearer {api_key}",
19
+ "Content-Type": "application/json"
20
+ }
21
+ payload = {
22
+ "projectId": project_id,
23
+ "question": question,
24
+ "sampleSize": sample_size,
25
+ "language": language,
26
+ }
27
+ if thread_id:
28
+ payload["threadId"] = thread_id
29
+
30
+ try:
31
+ response = requests.post(endpoint, json=payload, headers=headers)
32
+ response.raise_for_status()
33
+ return response.json(), ""
34
+ except requests.exceptions.RequestException as e:
35
+ return {}, str(e)
36
+
37
+
38
  def generate_sql(
39
  api_key: str,
40
  project_id: str,
41
+ question: str,
42
  thread_id: str = "",
43
+ language: str = "English",
44
+ return_sql_dialect: bool = False,
45
  ) -> tuple[dict, str]:
46
  """Generate SQL from natural language query."""
47
  base_url = "https://cloud.getwren.ai/api/v1"
 
52
  }
53
  payload = {
54
  "projectId": project_id,
55
+ "question": question,
56
+ "language": language,
57
+ "returnSqlDialect": return_sql_dialect,
58
+ }
59
+ if thread_id:
60
+ payload["threadId"] = thread_id
61
+
62
+ try:
63
+ response = requests.post(endpoint, json=payload, headers=headers)
64
+ response.raise_for_status()
65
+ return response.json(), ""
66
+ except requests.exceptions.RequestException as e:
67
+ return {}, str(e)
68
+
69
+
70
+ def run_sql(
71
+ api_key: str,
72
+ project_id: str,
73
+ sql: str,
74
+ thread_id: str = "",
75
+ limit: int = 1000,
76
+ ) -> tuple[dict, str]:
77
+ """Execute SQL query and return results."""
78
+ base_url = "https://cloud.getwren.ai/api/v1"
79
+ endpoint = f"{base_url}/run_sql"
80
+ headers = {
81
+ "Authorization": f"Bearer {api_key}",
82
+ "Content-Type": "application/json"
83
+ }
84
+ payload = {
85
+ "projectId": project_id,
86
+ "sql": sql,
87
+ "limit": limit,
88
  }
89
  if thread_id:
90
  payload["threadId"] = thread_id
 
94
  response.raise_for_status()
95
  return response.json(), ""
96
  except requests.exceptions.RequestException as e:
97
+ return {}, str(e)
98
+
99
+
100
+ def generate_summary(
101
+ api_key: str,
102
+ project_id: str,
103
+ question: str,
104
+ sql: str,
105
+ thread_id: str = "",
106
+ language: str = "English",
107
+ sample_size: int = 1000,
108
+ ) -> tuple[dict, str]:
109
+ """Generate a summary of query results."""
110
+ base_url = "https://cloud.getwren.ai/api/v1"
111
+ endpoint = f"{base_url}/generate_summary"
112
+ headers = {
113
+ "Authorization": f"Bearer {api_key}",
114
+ "Content-Type": "application/json"
115
+ }
116
+ payload = {
117
+ "projectId": project_id,
118
+ "question": question,
119
+ "sql": sql,
120
+ "language": language,
121
+ "sampleSize": sample_size,
122
+ }
123
+ if thread_id:
124
+ payload["threadId"] = thread_id
125
+
126
+ try:
127
+ response = requests.post(endpoint, json=payload, headers=headers)
128
+ response.raise_for_status()
129
+ return response.json(), ""
130
+ except requests.exceptions.RequestException as e:
131
+ return {}, str(e)
132
+
133
+
134
+ def stream_explanation(
135
+ api_key: str
136
+ ) -> tuple[dict, str]:
137
+ """Stream explanation for a query."""
138
+ base_url = "https://cloud.getwren.ai/api/v1"
139
+ endpoint = f"{base_url}/stream_explanation"
140
+ headers = {
141
+ "Authorization": f"Bearer {api_key}",
142
+ "Accept": "text/event-stream",
143
+ }
144
+
145
+ try:
146
+ response = requests.get(endpoint, headers=headers, stream=True)
147
+ response.raise_for_status()
148
+ return response.json(), ""
149
+ except requests.exceptions.RequestException as e:
150
+ return {}, str(e)
151
 
152
 
153
  def generate_chart(
 
155
  project_id: str,
156
  question: str,
157
  sql: str,
158
+ thread_id: str = "",
159
  sample_size: int = 1000,
 
160
  ) -> tuple[dict, str]:
161
  """Generate a chart from query results."""
162
  base_url = "https://cloud.getwren.ai/api/v1"
 
179
  response.raise_for_status()
180
  return response.json(), ""
181
  except requests.exceptions.RequestException as e:
182
+ return {}, str(e)
183
+
184
+
185
+ def stream_ask(
186
+ api_key: str,
187
+ project_id: str,
188
+ question: str,
189
+ thread_id: str = "",
190
+ language: str = "English",
191
+ sample_size: int = 1000,
192
+ ) -> tuple[dict, str]:
193
+ """Stream ask endpoint for real-time responses."""
194
+ base_url = "https://cloud.getwren.ai/api/v1"
195
+ endpoint = f"{base_url}/stream/ask"
196
+ headers = {
197
+ "Authorization": f"Bearer {api_key}",
198
+ "Content-Type": "application/json",
199
+ "Accept": "text/event-stream",
200
+ }
201
+ payload = {
202
+ "projectId": project_id,
203
+ "question": question,
204
+ "language": language,
205
+ "sampleSize": sample_size,
206
+ }
207
+ if thread_id:
208
+ payload["threadId"] = thread_id
209
+
210
+ try:
211
+ response = requests.post(endpoint, json=payload, headers=headers, stream=True)
212
+ response.raise_for_status()
213
+ return response.json(), ""
214
+ except requests.exceptions.RequestException as e:
215
+ return {}, str(e)
216
+
217
+
218
+ def stream_generate_sql(
219
+ api_key: str,
220
+ project_id: str,
221
+ question: str,
222
+ thread_id: str = "",
223
+ language: str = "English",
224
+ return_sql_dialect: bool = False,
225
+ ) -> tuple[dict, str]:
226
+ """Stream SQL generation endpoint for real-time responses."""
227
+ base_url = "https://cloud.getwren.ai/api/v1"
228
+ endpoint = f"{base_url}/stream/generate_sql"
229
+ headers = {
230
+ "Authorization": f"Bearer {api_key}",
231
+ "Content-Type": "application/json",
232
+ "Accept": "text/event-stream",
233
+ }
234
+ payload = {
235
+ "projectId": project_id,
236
+ "question": question,
237
+ "language": language,
238
+ "returnSqlDialect": return_sql_dialect,
239
+ }
240
+ if thread_id:
241
+ payload["threadId"] = thread_id
242
+
243
+ try:
244
+ response = requests.post(endpoint, json=payload, headers=headers, stream=True)
245
+ response.raise_for_status()
246
+ return response.json(), ""
247
+ except requests.exceptions.RequestException as e:
248
+ return {}, str(e)
249
+
250
+
251
+ def get_models(
252
+ api_key: str,
253
+ project_id: str,
254
+ ) -> tuple[dict, str]:
255
+ """Get latest deployed models for a project."""
256
+ base_url = "https://cloud.getwren.ai/api/v1"
257
+ endpoint = f"{base_url}/projects/{project_id}/models"
258
+ headers = {
259
+ "Authorization": f"Bearer {api_key}",
260
+ }
261
+
262
+ try:
263
+ response = requests.get(endpoint, headers=headers)
264
+ response.raise_for_status()
265
+ return response.json(), ""
266
+ except requests.exceptions.RequestException as e:
267
+ return {}, str(e)
268
+
269
+
270
+ def get_instructions(
271
+ api_key: str,
272
+ project_id: str,
273
+ ) -> tuple[dict, str]:
274
+ """Get all instructions for a project."""
275
+ base_url = "https://cloud.getwren.ai/api/v1"
276
+ endpoint = f"{base_url}/projects/{project_id}/knowledge/instructions"
277
+ headers = {
278
+ "Authorization": f"Bearer {api_key}",
279
+ }
280
+
281
+ try:
282
+ response = requests.get(endpoint, headers=headers)
283
+ response.raise_for_status()
284
+ return response.json(), ""
285
+ except requests.exceptions.RequestException as e:
286
+ return {}, str(e)
287
+
288
+
289
+ def create_instruction(
290
+ api_key: str,
291
+ project_id: str,
292
+ instruction: str,
293
+ is_global: bool = True,
294
+ questions: Optional[list[str]] = None,
295
+ ) -> tuple[dict, str]:
296
+ """Create a new instruction for a project."""
297
+ base_url = "https://cloud.getwren.ai/api/v1"
298
+ endpoint = f"{base_url}/projects/{project_id}/knowledge/instructions"
299
+ headers = {
300
+ "Authorization": f"Bearer {api_key}",
301
+ "Content-Type": "application/json"
302
+ }
303
+ payload = {
304
+ "instruction": instruction,
305
+ "isGlobal": is_global,
306
+ "questions": questions or [],
307
+ }
308
+
309
+ try:
310
+ response = requests.post(endpoint, json=payload, headers=headers)
311
+ response.raise_for_status()
312
+ return response.json(), ""
313
+ except requests.exceptions.RequestException as e:
314
+ return {}, str(e)
315
+
316
+
317
+ def update_instruction(
318
+ api_key: str,
319
+ project_id: str,
320
+ instruction_id: str,
321
+ instruction: str,
322
+ is_global: bool = True,
323
+ questions: Optional[list[str]] = None,
324
+ ) -> tuple[dict, str]:
325
+ """Update an existing instruction."""
326
+ base_url = "https://cloud.getwren.ai/api/v1"
327
+ endpoint = f"{base_url}/projects/{project_id}/knowledge/instructions/{instruction_id}"
328
+ headers = {
329
+ "Authorization": f"Bearer {api_key}",
330
+ "Content-Type": "application/json"
331
+ }
332
+ payload = {
333
+ "instruction": instruction,
334
+ "isGlobal": is_global,
335
+ "questions": questions or [],
336
+ }
337
+
338
+ try:
339
+ response = requests.put(endpoint, json=payload, headers=headers)
340
+ response.raise_for_status()
341
+ return response.json(), ""
342
+ except requests.exceptions.RequestException as e:
343
+ return {}, str(e)
344
+
345
+
346
+ def delete_instruction(
347
+ api_key: str,
348
+ project_id: str,
349
+ instruction_id: str,
350
+ ) -> tuple[dict, str]:
351
+ """Delete an instruction."""
352
+ base_url = "https://cloud.getwren.ai/api/v1"
353
+ endpoint = f"{base_url}/projects/{project_id}/knowledge/instructions/{instruction_id}"
354
+ headers = {
355
+ "Authorization": f"Bearer {api_key}",
356
+ }
357
+
358
+ try:
359
+ response = requests.delete(endpoint, headers=headers)
360
+ response.raise_for_status()
361
+ return response.json(), ""
362
+ except requests.exceptions.RequestException as e:
363
+ return {}, str(e)
364
+
365
+
366
+ def get_sql_pairs(
367
+ api_key: str,
368
+ project_id: str,
369
+ ) -> tuple[dict, str]:
370
+ """Get all SQL pairs for a project."""
371
+ base_url = "https://cloud.getwren.ai/api/v1"
372
+ endpoint = f"{base_url}/projects/{project_id}/knowledge/sql_pairs"
373
+ headers = {
374
+ "Authorization": f"Bearer {api_key}",
375
+ }
376
+
377
+ try:
378
+ response = requests.get(endpoint, headers=headers)
379
+ response.raise_for_status()
380
+ return response.json(), ""
381
+ except requests.exceptions.RequestException as e:
382
+ return {}, str(e)
383
+
384
+
385
+ def create_sql_pair(
386
+ api_key: str,
387
+ project_id: str,
388
+ question: str,
389
+ sql: str,
390
+ ) -> tuple[dict, str]:
391
+ """Create a new SQL pair for a project."""
392
+ base_url = "https://cloud.getwren.ai/api/v1"
393
+ endpoint = f"{base_url}/projects/{project_id}/knowledge/sql_pairs"
394
+ headers = {
395
+ "Authorization": f"Bearer {api_key}",
396
+ "Content-Type": "application/json"
397
+ }
398
+ payload = {
399
+ "question": question,
400
+ "sql": sql,
401
+ }
402
+
403
+ try:
404
+ response = requests.post(endpoint, json=payload, headers=headers)
405
+ response.raise_for_status()
406
+ return response.json(), ""
407
+ except requests.exceptions.RequestException as e:
408
+ return {}, str(e)
409
+
410
+
411
+ def update_sql_pair(
412
+ api_key: str,
413
+ project_id: str,
414
+ sql_pair_id: str,
415
+ question: str,
416
+ sql: str,
417
+ ) -> tuple[dict, str]:
418
+ """Update an existing SQL pair."""
419
+ base_url = "https://cloud.getwren.ai/api/v1"
420
+ endpoint = f"{base_url}/projects/{project_id}/knowledge/sql_pairs/{sql_pair_id}"
421
+ headers = {
422
+ "Authorization": f"Bearer {api_key}",
423
+ "Content-Type": "application/json"
424
+ }
425
+ payload = {
426
+ "question": question,
427
+ "sql": sql,
428
+ }
429
+
430
+ try:
431
+ response = requests.put(endpoint, json=payload, headers=headers)
432
+ response.raise_for_status()
433
+ return response.json(), ""
434
+ except requests.exceptions.RequestException as e:
435
+ return {}, str(e)
436
+
437
+
438
+ def delete_sql_pair(
439
+ api_key: str,
440
+ project_id: str,
441
+ sql_pair_id: str,
442
+ ) -> tuple[dict, str]:
443
+ """Delete an SQL pair."""
444
+ base_url = "https://cloud.getwren.ai/api/v1"
445
+ endpoint = f"{base_url}/projects/{project_id}/knowledge/sql_pairs/{sql_pair_id}"
446
+ headers = {
447
+ "Authorization": f"Bearer {api_key}",
448
+ }
449
+
450
+ try:
451
+ response = requests.delete(endpoint, headers=headers)
452
+ response.raise_for_status()
453
+ return response.json(), ""
454
+ except requests.exceptions.RequestException as e:
455
+ return {}, str(e)
src/app.py CHANGED
@@ -112,7 +112,14 @@ def main():
112
 
113
  # Generate chart
114
  with st.spinner("Generating chart..."):
115
- chart_response, error = generate_chart(api_key, project_id, prompt, sql_query, sample_size, st.session_state.thread_id)
 
 
 
 
 
 
 
116
 
117
  if chart_response:
118
  vega_spec = chart_response.get("vegaSpec", {})
 
112
 
113
  # Generate chart
114
  with st.spinner("Generating chart..."):
115
+ chart_response, error = generate_chart(
116
+ api_key,
117
+ project_id,
118
+ prompt,
119
+ sql_query,
120
+ thread_id=st.session_state.thread_id,
121
+ sample_size=sample_size,
122
+ )
123
 
124
  if chart_response:
125
  vega_spec = chart_response.get("vegaSpec", {})