ppsingh commited on
Commit
f1afeff
·
1 Parent(s): 056423f

soruces and metadata filtering

Browse files
app.py CHANGED
@@ -1,20 +1,21 @@
1
  import streamlit as st
2
  from utils.retriever import retrieve_paragraphs
3
  from utils.generator import build_messages, _call_llm
 
4
  import ast
5
  import time
6
  import asyncio
7
- import logging
8
  import logging
9
  logging.basicConfig(level=logging.INFO)
10
 
11
 
12
-
13
- def chat_response(query):
14
  """Generate chat response based on method and inputs"""
15
 
16
  try:
17
- retrieved_paragraphs = retrieve_paragraphs(query)
18
  context_retrieved = ast.literal_eval(retrieved_paragraphs)
19
 
20
  # Build list of only content, no metadata
@@ -24,13 +25,66 @@ def chat_response(query):
24
 
25
  messages = build_messages(query, context_retrieved_lst)
26
  answer = asyncio.run(_call_llm(messages))
27
- return answer
28
 
29
 
30
  except Exception as e:
31
  error_message = f"Error processing request: {str(e)}"
32
  return error_message
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  col_title, col_about = st.columns([8, 2])
35
  with col_title:
36
  st.markdown(
@@ -42,12 +96,48 @@ with col_title:
42
  query = st.text_input(
43
  label="Enter your question:",
44
  key="query",
 
45
  )
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Only run search & display if user has entered something
48
  if not query.strip():
49
  st.info("Please enter a question to see results.")
50
  st.stop()
51
  else:
52
- st.write(chat_response(query))
53
-
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from utils.retriever import retrieve_paragraphs
3
  from utils.generator import build_messages, _call_llm
4
+ from utils.utils import meetings_list, countries_list, projects_list
5
  import ast
6
  import time
7
  import asyncio
8
+ import re
9
  import logging
10
  logging.basicConfig(level=logging.INFO)
11
 
12
 
13
+ ########### Function for getting response #######################
14
+ def chat_response(query, filter_metadata=None):
15
  """Generate chat response based on method and inputs"""
16
 
17
  try:
18
+ retrieved_paragraphs = retrieve_paragraphs(query, filter_metadata=filter_metadata)
19
  context_retrieved = ast.literal_eval(retrieved_paragraphs)
20
 
21
  # Build list of only content, no metadata
 
25
 
26
  messages = build_messages(query, context_retrieved_lst)
27
  answer = asyncio.run(_call_llm(messages))
28
+ return answer, context_retrieved
29
 
30
 
31
  except Exception as e:
32
  error_message = f"Error processing request: {str(e)}"
33
  return error_message
34
 
35
+
36
+ ############## UI related functions #####################
37
+
38
+ def reset_page():
39
+ """
40
+ Reset pagination back to the first page; used as on_change callback.
41
+ """
42
+ st.session_state["page"] = 1
43
+
44
+ def contruct_metadata_filter():
45
+ filter_metadata = {}
46
+ if st.session_state['meetings_filter'] != 'All':
47
+ filter_metadata['meeting_id'] = st.session_state['meetings_filter']
48
+ ## need to change the filter for coutnry and project tolist
49
+ if st.session_state['country_filter'] != 'All':
50
+ filter_metadata['Countries'] = st.session_state['country_filter']
51
+ if st.session_state['project_filter'] != 'All':
52
+ filter_metadata['Projects'] = st.session_state['project_filter']
53
+
54
+ return filter_metadata
55
+
56
+
57
+ def render_sources(chunks, query):
58
+ # 11.7. Render each result chunk
59
+ st.write("Sources")
60
+ st.write("======================================")
61
+ start_idx = 0
62
+ for idx, doc in enumerate(chunks, start=start_idx + 1):
63
+ meta = doc.get('answer_metadata', {})
64
+ title = meta.get('Decision Number', 'Unknown Project')
65
+ agencies = meta.get('Agencies', 'Unknown Agencies')
66
+ country = meta.get('country', 'Unknown Country')
67
+ snippet = doc.get('answer', '')
68
+ preview = snippet.split(maxsplit=90)[:90]
69
+ remainder = snippet[len(" ".join(preview)):]
70
+
71
+ # Title + metadata
72
+ st.markdown(f"#### {idx}. {title}", unsafe_allow_html=True)
73
+ st.markdown(f"**Agencies:** {agencies} | **Country:** {country}")
74
+
75
+ # Snippet + optional expander
76
+ st.markdown(" ".join(preview), unsafe_allow_html=True)
77
+ if remainder:
78
+ with st.expander("Show more"):
79
+ st.markdown(remainder, unsafe_allow_html=True)
80
+ st.divider()
81
+
82
+ st.set_page_config(page_title="Montreal AI Decisions (MVP)")
83
+ for key in ('meetings_filter', 'country_filter', 'project_filter'):
84
+ if key not in st.session_state:
85
+ st.session_state[key] = 'All'
86
+ if 'page' not in st.session_state:
87
+ st.session_state['page'] = 1
88
  col_title, col_about = st.columns([8, 2])
89
  with col_title:
90
  st.markdown(
 
96
  query = st.text_input(
97
  label="Enter your question:",
98
  key="query",
99
+ on_change = reset_page
100
  )
101
 
102
+ # 10.2. Filter widgets
103
+ col1, col2, col3, col4 = st.columns(4)
104
+ with col1:
105
+ meetings = sorted(meetings_list)
106
+ st.selectbox(
107
+ "Meeting",
108
+ options=['All'] + meetings,
109
+ key='meetings_filter',
110
+ on_change=reset_page
111
+ )
112
+ with col2:
113
+ countries = sorted(countries_list)
114
+ st.selectbox(
115
+ "Country",
116
+ options=['All'] + countries,
117
+ key='country_filter',
118
+ on_change=reset_page
119
+ )
120
+ with col3:
121
+ projects = sorted(projects_list)
122
+ st.selectbox(
123
+ "Projects",
124
+ options=['All'] + projects,
125
+ key='project_filter',
126
+ on_change=reset_page
127
+ )
128
+
129
  # Only run search & display if user has entered something
130
  if not query.strip():
131
  st.info("Please enter a question to see results.")
132
  st.stop()
133
  else:
134
+ filter_metadata = contruct_metadata_filter()
135
+ if filter_metadata:
136
+ answer, context_retrieved = chat_response(query, filter_metadata)
137
+ st.write(answer)
138
+ render_sources(context_retrieved, query)
139
+
140
+ else:
141
+ answer, context_retrieved = chat_response(query)
142
+ st.write(answer)
143
+ render_sources(context_retrieved, query)
utils/__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/utils.cpython-311.pyc and b/utils/__pycache__/utils.cpython-311.pyc differ
 
utils/utils.py CHANGED
@@ -38,4 +38,49 @@ def get_auth(provider: str) -> dict:
38
  if not api_key:
39
  raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
40
 
41
- return auth_config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  if not api_key:
39
  raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
40
 
41
+ return auth_config
42
+
43
+
44
+ meetings_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',
45
+ '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23',
46
+ '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34',
47
+ '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45',
48
+ '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56',
49
+ '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67',
50
+ '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78',
51
+ '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89',
52
+ '90', '91', '92', '93', '94']
53
+
54
+ countries_list= ['Afghanistan', 'Africa', 'Albania', 'Algeria', 'American Samoa', 'Angola', 'Anguilla',
55
+ 'Antigua and Barbuda', 'Argentina', 'Armenia', 'Asia Pacific', 'Australia', 'Azerbaijan', 'Bahamas',
56
+ 'Bahrain', 'Bangladesh', 'Barbados', 'Belgium', 'Belize', 'Benin', 'Bhutan', 'Bolivia (Plurinational State of)',
57
+ 'Bosnia and Herzegovina', 'Botswana', 'Brazil', 'British Indian Ocean Territory', 'Brunei Darussalam',
58
+ 'Burkina Faso', 'Burundi', 'Cabo Verde', 'Cambodia', 'Cameroon', 'Canada', 'Central African Republic',
59
+ 'Chad', 'Chile', 'China', 'Colombia', 'Comoros', 'Congo', 'Cook Islands', 'Costa Rica', 'Croatia', 'Cuba',
60
+ 'Cyprus', 'Czechia', "C├┤te d'Ivoire", "Democratic People's Republic of Korea", 'Democratic Republic of the Congo',
61
+ 'Djibouti', 'Dominica', 'Dominican Republic', 'Ecuador', 'Egypt', 'El Salvador', 'Equatorial Guinea', 'Eritrea',
62
+ 'Eswatini', 'Ethiopia', 'Europe', 'Fiji', 'Finland', 'France', 'French Southern and Antarctic Lands', 'Gabon',
63
+ 'Gambia', 'Georgia', 'Germany', 'Ghana', 'Global', 'Grenada', 'Guatemala', 'Guinea', 'Guinea-Bissau', 'Guyana',
64
+ 'Haiti', 'Honduras', 'India', 'Indonesia', 'Iran (Islamic Republic of)', 'Iraq', 'Israel', 'Jamaica', 'Japan',
65
+ 'Jordan', 'Kazakhstan', 'Kenya', 'Kiribati', 'Kuwait', 'Kyrgyzstan', "Lao People's Democratic Republic",
66
+ 'Latin America and Caribbean', 'Lebanon', 'Lesotho', 'Liberia', 'Libya', 'Madagascar', 'Malawi', 'Malaysia',
67
+ 'Maldives', 'Mali', 'Malta', 'Marshall Islands', 'Mauritania', 'Mauritius', 'Mexico', 'Micronesia (Federated States of)',
68
+ 'Mongolia', 'Montenegro', 'Morocco', 'Mozambique', 'Myanmar', 'Namibia', 'Nauru', 'Nepal', 'Netherlands', 'Nicaragua',
69
+ 'Niger', 'Nigeria', 'Niue', 'North Macedonia', 'Oman', 'Pakistan', 'Palau', 'Panama', 'Papua New Guinea', 'Paraguay',
70
+ 'Peru', 'Philippines', 'Portugal', 'Qatar', 'Republic of Korea', 'Republic of Moldova', 'Romania', 'Russian Federation',
71
+ 'Rwanda', 'Saint Kitts and Nevis', 'Saint Lucia', 'Saint Vincent and the Grenadines', 'Samoa', 'Sao Tome and Principe',
72
+ 'Saudi Arabia', 'Senegal', 'Serbia', 'Seychelles', 'Sierra Leone', 'Slovenia', 'Solomon Islands', 'Somalia', 'South Africa',
73
+ 'South Sudan', 'Spain', 'Sri Lanka', 'Sudan', 'Suriname', 'Sweden', 'Switzerland', 'Syrian Arab Republic', 'Thailand',
74
+ 'Timor-Leste', 'Togo', 'Tonga', 'Trinidad and Tobago', 'Tunisia', 'Turkmenistan', 'Tuvalu', 'T├╝rkiye', 'Uganda', 'Ukraine',
75
+ 'United Arab Emirates', 'United Kingdom', 'United Republic of Tanzania', 'United States of America', 'Uruguay', 'Uzbekistan',
76
+ 'Vanuatu', 'Venezuela (Bolivarian Republic of)', 'Viet Nam', 'Yemen', 'Zambia', 'Zimbabwe']
77
+
78
+ projects_list = [ "Other ODS phaseout",
79
+ "Other ODS production phaseout",
80
+ "CFC production phaseout ",
81
+ "CFC phaseout",
82
+ "HCFC phaseout stage 1",
83
+ "HCFC phaseout stage 2",
84
+ "HCFC phaseout stage 3",
85
+ "HCFC production phaseout stage 2",
86
+ ]