Spaces:
Sleeping
Sleeping
sources and metadata filtering
Browse files- app.py +97 -7
- utils/__pycache__/utils.cpython-311.pyc +0 -0
- utils/utils.py +46 -1
app.py
CHANGED
@@ -1,20 +1,21 @@
|
|
1 |
import streamlit as st
|
2 |
from utils.retriever import retrieve_paragraphs
|
3 |
from utils.generator import build_messages, _call_llm
|
|
|
4 |
import ast
|
5 |
import time
|
6 |
import asyncio
|
7 |
-
import
|
8 |
import logging
|
9 |
logging.basicConfig(level=logging.INFO)
|
10 |
|
11 |
|
12 |
-
|
13 |
-
def chat_response(query):
|
14 |
"""Generate chat response based on method and inputs"""
|
15 |
|
16 |
try:
|
17 |
-
retrieved_paragraphs = retrieve_paragraphs(query)
|
18 |
context_retrieved = ast.literal_eval(retrieved_paragraphs)
|
19 |
|
20 |
# Build list of only content, no metadata
|
@@ -24,13 +25,66 @@ def chat_response(query):
|
|
24 |
|
25 |
messages = build_messages(query, context_retrieved_lst)
|
26 |
answer = asyncio.run(_call_llm(messages))
|
27 |
-
return answer
|
28 |
|
29 |
|
30 |
except Exception as e:
|
31 |
error_message = f"Error processing request: {str(e)}"
|
32 |
return error_message
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
col_title, col_about = st.columns([8, 2])
|
35 |
with col_title:
|
36 |
st.markdown(
|
@@ -42,12 +96,48 @@ with col_title:
|
|
42 |
query = st.text_input(
|
43 |
label="Enter your question:",
|
44 |
key="query",
|
|
|
45 |
)
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
# Only run search & display if user has entered something
|
48 |
if not query.strip():
|
49 |
st.info("Please enter a question to see results.")
|
50 |
st.stop()
|
51 |
else:
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from utils.retriever import retrieve_paragraphs
|
3 |
from utils.generator import build_messages, _call_llm
|
4 |
+
from utils.utils import meetings_list, countries_list, projects_list
|
5 |
import ast
|
6 |
import time
|
7 |
import asyncio
|
8 |
+
import re
|
9 |
import logging
|
10 |
logging.basicConfig(level=logging.INFO)
|
11 |
|
12 |
|
13 |
+
########### Function for getting response #######################
|
14 |
+
def chat_response(query, filter_metadata=None):
|
15 |
"""Generate chat response based on method and inputs"""
|
16 |
|
17 |
try:
|
18 |
+
retrieved_paragraphs = retrieve_paragraphs(query, filter_metadata=filter_metadata)
|
19 |
context_retrieved = ast.literal_eval(retrieved_paragraphs)
|
20 |
|
21 |
# Build list of only content, no metadata
|
|
|
25 |
|
26 |
messages = build_messages(query, context_retrieved_lst)
|
27 |
answer = asyncio.run(_call_llm(messages))
|
28 |
+
return answer, context_retrieved
|
29 |
|
30 |
|
31 |
except Exception as e:
|
32 |
error_message = f"Error processing request: {str(e)}"
|
33 |
return error_message
|
34 |
|
35 |
+
|
36 |
+
############## UI related functions #####################
|
37 |
+
|
38 |
+
def reset_page():
    """Jump pagination back to page one.

    Registered as the ``on_change`` callback of the query box and of the
    filter select boxes, so any change to the search input restarts the
    result view at the first page.
    """
    first_page = 1
    st.session_state["page"] = first_page
|
43 |
+
|
44 |
+
def contruct_metadata_filter():
    """Build the retrieval metadata filter from the filter select boxes.

    Reads ``meetings_filter``, ``country_filter`` and ``project_filter`` from
    ``st.session_state``. Every selection other than ``'All'`` is stored under
    its metadata field (``meeting_id``, ``Countries``, ``Projects``).

    Returns:
        dict: metadata field -> selected value; empty when every filter is
        set to ``'All'``.
    """
    # (session-state key, metadata field) pairs, in the order they are applied.
    # TODO: the country and project filters still need to be changed to lists.
    widget_to_field = (
        ('meetings_filter', 'meeting_id'),
        ('country_filter', 'Countries'),
        ('project_filter', 'Projects'),
    )
    return {
        field: st.session_state[widget_key]
        for widget_key, field in widget_to_field
        if st.session_state[widget_key] != 'All'
    }
|
55 |
+
|
56 |
+
|
57 |
+
def render_sources(chunks, query, preview_words=90):
    """Render the retrieved source chunks underneath the generated answer.

    For every chunk a numbered heading, an agencies/country line and the first
    ``preview_words`` words of the text are shown; any text beyond the preview
    goes into a collapsed "Show more" expander.

    Args:
        chunks: iterable of dicts with an ``'answer'`` text and an
            ``'answer_metadata'`` dict. ``None`` is treated as no chunks.
        query: the user query. Currently unused; kept so existing callers
            keep working.
        preview_words: number of words shown before the expander.
    """
    # 11.7. Render each result chunk
    st.write("Sources")
    st.write("======================================")
    for idx, doc in enumerate(chunks or [], start=1):
        meta = doc.get('answer_metadata') or {}
        title = meta.get('Decision Number', 'Unknown Project')
        agencies = meta.get('Agencies', 'Unknown Agencies')
        # NOTE(review): the metadata filter uses the key 'Countries' while this
        # reads 'country' -- confirm which key the retriever actually returns.
        country = meta.get('country', 'Unknown Country')
        snippet = doc.get('answer') or ''

        # Split once: the first `preview_words` items are the preview words and
        # a trailing item, if present, is the untouched rest of the text.
        # Slicing the snippet by the length of the re-joined preview (the old
        # approach) is wrong whenever the text contains newlines, repeated
        # spaces or leading whitespace.
        words = snippet.split(maxsplit=preview_words)
        preview = " ".join(words[:preview_words])
        remainder = words[preview_words] if len(words) > preview_words else ""

        # Title + metadata
        st.markdown(f"#### {idx}. {title}", unsafe_allow_html=True)
        st.markdown(f"**Agencies:** {agencies} | **Country:** {country}")

        # Snippet + optional expander
        # NOTE(review): chunk text is rendered with raw HTML enabled; confirm
        # the indexed documents are trusted.
        st.markdown(preview, unsafe_allow_html=True)
        if remainder:
            with st.expander("Show more"):
                st.markdown(remainder, unsafe_allow_html=True)
        st.divider()
|
81 |
+
|
82 |
+
# Page setup
st.set_page_config(page_title="Montreal AI Decisions (MVP)")

# Seed session state only for keys that are not present yet; values already
# stored in the session are left untouched.
_session_defaults = {
    'meetings_filter': 'All',
    'country_filter': 'All',
    'project_filter': 'All',
    'page': 1,
}
for _name, _default in _session_defaults.items():
    if _name not in st.session_state:
        st.session_state[_name] = _default
|
88 |
col_title, col_about = st.columns([8, 2])
|
89 |
with col_title:
|
90 |
st.markdown(
|
|
|
96 |
query = st.text_input(
|
97 |
label="Enter your question:",
|
98 |
key="query",
|
99 |
+
on_change = reset_page
|
100 |
)
|
101 |
|
102 |
+
# 10.2. Filter widgets
# Four equal-width columns; only the first three hold a filter.
col1, col2, col3, col4 = st.columns(4)
with col1:
    # Meeting ids are numeric strings. A plain string sort would order them
    # 1, 10, 11, ..., 2, 20, ...; sort by integer value instead and put any
    # non-numeric id after the numeric ones.
    meetings = sorted(
        meetings_list,
        key=lambda m: (0, int(m), m) if m.isdecimal() else (1, 0, m),
    )
    st.selectbox(
        "Meeting",
        options=['All'] + meetings,
        key='meetings_filter',
        on_change=reset_page
    )
with col2:
    countries = sorted(countries_list)
    st.selectbox(
        "Country",
        options=['All'] + countries,
        key='country_filter',
        on_change=reset_page
    )
with col3:
    projects = sorted(projects_list)
    st.selectbox(
        "Projects",
        options=['All'] + projects,
        key='project_filter',
        on_change=reset_page
    )
|
128 |
+
|
129 |
# Only run search & display if user has entered something
if not query.strip():
    st.info("Please enter a question to see results.")
    st.stop()  # halts this script run, so no `else` branch is needed

# An empty filter dict means "no filtering"; pass None in that case, which is
# what chat_response uses as its default.
filter_metadata = contruct_metadata_filter() or None
result = chat_response(query, filter_metadata)

# On failure chat_response returns a bare error string instead of an
# (answer, sources) pair; unpacking that string would raise ValueError.
if isinstance(result, str):
    st.error(result)
    st.stop()

answer, context_retrieved = result
st.write(answer)
render_sources(context_retrieved, query)
|
utils/__pycache__/utils.cpython-311.pyc
CHANGED
Binary files a/utils/__pycache__/utils.cpython-311.pyc and b/utils/__pycache__/utils.cpython-311.pyc differ
|
|
utils/utils.py
CHANGED
@@ -38,4 +38,49 @@ def get_auth(provider: str) -> dict:
|
|
38 |
if not api_key:
|
39 |
raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
|
40 |
|
41 |
-
return auth_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
if not api_key:
|
39 |
raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
|
40 |
|
41 |
+
return auth_config
|
42 |
+
|
43 |
+
|
44 |
+
# Meeting ids offered by the "Meeting" filter: the strings '1' through '94'.
meetings_list = [str(number) for number in range(1, 95)]
|
53 |
+
|
54 |
+
# Country / region names offered by the "Country" filter.
# Two entries were mis-encoded (UTF-8 bytes shown as CP437 box-drawing
# characters) and are restored here: "Côte d'Ivoire" and "Türkiye".
# NOTE(review): confirm the indexed metadata uses the correctly encoded
# spelling, otherwise these two filter values will not match.
countries_list = [
    'Afghanistan', 'Africa', 'Albania', 'Algeria', 'American Samoa', 'Angola', 'Anguilla',
    'Antigua and Barbuda', 'Argentina', 'Armenia', 'Asia Pacific', 'Australia', 'Azerbaijan', 'Bahamas',
    'Bahrain', 'Bangladesh', 'Barbados', 'Belgium', 'Belize', 'Benin', 'Bhutan', 'Bolivia (Plurinational State of)',
    'Bosnia and Herzegovina', 'Botswana', 'Brazil', 'British Indian Ocean Territory', 'Brunei Darussalam',
    'Burkina Faso', 'Burundi', 'Cabo Verde', 'Cambodia', 'Cameroon', 'Canada', 'Central African Republic',
    'Chad', 'Chile', 'China', 'Colombia', 'Comoros', 'Congo', 'Cook Islands', 'Costa Rica', 'Croatia', 'Cuba',
    'Cyprus', 'Czechia', "Côte d'Ivoire", "Democratic People's Republic of Korea", 'Democratic Republic of the Congo',
    'Djibouti', 'Dominica', 'Dominican Republic', 'Ecuador', 'Egypt', 'El Salvador', 'Equatorial Guinea', 'Eritrea',
    'Eswatini', 'Ethiopia', 'Europe', 'Fiji', 'Finland', 'France', 'French Southern and Antarctic Lands', 'Gabon',
    'Gambia', 'Georgia', 'Germany', 'Ghana', 'Global', 'Grenada', 'Guatemala', 'Guinea', 'Guinea-Bissau', 'Guyana',
    'Haiti', 'Honduras', 'India', 'Indonesia', 'Iran (Islamic Republic of)', 'Iraq', 'Israel', 'Jamaica', 'Japan',
    'Jordan', 'Kazakhstan', 'Kenya', 'Kiribati', 'Kuwait', 'Kyrgyzstan', "Lao People's Democratic Republic",
    'Latin America and Caribbean', 'Lebanon', 'Lesotho', 'Liberia', 'Libya', 'Madagascar', 'Malawi', 'Malaysia',
    'Maldives', 'Mali', 'Malta', 'Marshall Islands', 'Mauritania', 'Mauritius', 'Mexico', 'Micronesia (Federated States of)',
    'Mongolia', 'Montenegro', 'Morocco', 'Mozambique', 'Myanmar', 'Namibia', 'Nauru', 'Nepal', 'Netherlands', 'Nicaragua',
    'Niger', 'Nigeria', 'Niue', 'North Macedonia', 'Oman', 'Pakistan', 'Palau', 'Panama', 'Papua New Guinea', 'Paraguay',
    'Peru', 'Philippines', 'Portugal', 'Qatar', 'Republic of Korea', 'Republic of Moldova', 'Romania', 'Russian Federation',
    'Rwanda', 'Saint Kitts and Nevis', 'Saint Lucia', 'Saint Vincent and the Grenadines', 'Samoa', 'Sao Tome and Principe',
    'Saudi Arabia', 'Senegal', 'Serbia', 'Seychelles', 'Sierra Leone', 'Slovenia', 'Solomon Islands', 'Somalia', 'South Africa',
    'South Sudan', 'Spain', 'Sri Lanka', 'Sudan', 'Suriname', 'Sweden', 'Switzerland', 'Syrian Arab Republic', 'Thailand',
    'Timor-Leste', 'Togo', 'Tonga', 'Trinidad and Tobago', 'Tunisia', 'Turkmenistan', 'Tuvalu', 'Türkiye', 'Uganda', 'Ukraine',
    'United Arab Emirates', 'United Kingdom', 'United Republic of Tanzania', 'United States of America', 'Uruguay', 'Uzbekistan',
    'Vanuatu', 'Venezuela (Bolivarian Republic of)', 'Viet Nam', 'Yemen', 'Zambia', 'Zimbabwe',
]
|
77 |
+
|
78 |
+
# Project types offered by the "Projects" filter select box.
# NOTE(review): "CFC production phaseout " carries a trailing space -- confirm
# whether the indexed metadata has the same value before trimming it.
projects_list = [ "Other ODS phaseout",
    "Other ODS production phaseout",
    "CFC production phaseout ",
    "CFC phaseout",
    "HCFC phaseout stage 1",
    "HCFC phaseout stage 2",
    "HCFC phaseout stage 3",
    "HCFC production phaseout stage 2",
    ]
|