KeshavRa commited on
Commit
45c6f19
·
verified ·
1 Parent(s): f9462af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +247 -4
app.py CHANGED
@@ -256,7 +256,8 @@ if selected_app == "3) Upload Datasets":
256
  st.markdown("Go to this [google colab link](https://colab.research.google.com/drive/1eCpk9HUoCKZb--tiNyQSHFW2ojoaA35m) to get started")
257
 
258
  if selected_app == "4) Create Chatbot":
259
-
 
260
 
261
 
262
  requirements = '''
@@ -270,9 +271,251 @@ if selected_app == "4) Create Chatbot":
270
  st.write("requirements.txt")
271
  st.code(requirements, language='python')
272
 
273
- app = '''
274
- APP.PY
275
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  st.write("app.py")
278
  st.code(app, language='python')
 
256
  st.markdown("Go to this [google colab link](https://colab.research.google.com/drive/1eCpk9HUoCKZb--tiNyQSHFW2ojoaA35m) to get started")
257
 
258
  if selected_app == "4) Create Chatbot":
259
+ num_domains = st.number_input("Number sentences per Q/A pair", value=2, step=1, min_value=1, max_value=3)
260
+
261
 
262
 
263
  requirements = '''
 
271
  st.write("requirements.txt")
272
  st.code(requirements, language='python')
273
 
274
+ app = """
275
+ import os
276
+ import streamlit as st
277
+ from datasets import load_dataset
278
+ import chromadb
279
+ import string
280
+
281
+ from openai import OpenAI
282
+
283
+ import numpy as np
284
+ import pandas as pd
285
+
286
+ from scipy.spatial.distance import cosine
287
+
288
+ from typing import Dict, List
289
+
290
+ def merge_dataframes(dataframes):
291
+ # Concatenate the list of dataframes
292
+ combined_dataframe = pd.concat(dataframes, ignore_index=True)
293
+
294
+ # Ensure that the resulting dataframe only contains the columns "context", "questions", "answers"
295
+ combined_dataframe = combined_dataframe[['context', 'questions', 'answers']]
296
+
297
+ return combined_dataframe
298
+
299
+ def call_chatgpt(prompt: str, directions: str) -> str:
300
+ '''
301
+ Uses the OpenAI API to generate an AI response to a prompt.
302
+ Args:
303
+ prompt: A string representing the prompt to send to the OpenAI API.
304
+ Returns:
305
+ A string representing the AI's generated response.
306
+ '''
307
+
308
+ # Use the OpenAI API to generate a response based on the input prompt.
309
+ client = OpenAI(api_key = os.environ["OPENAI_API_KEY"])
310
+
311
+ completion = client.chat.completions.create(
312
+ model="gpt-3.5-turbo-0125",
313
+ messages=[
314
+ {"role": "system", "content": directions},
315
+ {"role": "user", "content": prompt}
316
+ ]
317
+ )
318
+
319
+ # Extract the text from the first (and only) choice in the response output.
320
+ ans = completion.choices[0].message.content
321
+
322
+ # Return the generated AI response.
323
+ return ans
324
+
325
+ def openai_text_embedding(prompt: str) -> str:
326
+ return openai.Embedding.create(input=prompt, model="text-embedding-ada-002")[
327
+ "data"
328
+ ][0]["embedding"]
329
+
330
+ def calculate_sts_openai_score(sentence1: str, sentence2: str) -> float:
331
+ # Compute sentence embeddings
332
+ embedding1 = openai_text_embedding(sentence1) # Flatten the embedding array
333
+ embedding2 = openai_text_embedding(sentence2) # Flatten the embedding array
334
+
335
+ # Convert to array
336
+ embedding1 = np.asarray(embedding1)
337
+ embedding2 = np.asarray(embedding2)
338
+
339
+ # Calculate cosine similarity between the embeddings
340
+ similarity_score = 1 - cosine(embedding1, embedding2)
341
+
342
+ return similarity_score
343
+
344
+ def add_dist_score_column(
345
+ dataframe: pd.DataFrame, sentence: str,
346
+ ) -> pd.DataFrame:
347
+ dataframe["stsopenai"] = dataframe["questions"].apply(
348
+ lambda x: calculate_sts_openai_score(str(x), sentence)
349
+ )
350
+
351
+ sorted_dataframe = dataframe.sort_values(by="stsopenai", ascending=False)
352
+
353
+
354
+ return sorted_dataframe.iloc[:5, :]
355
+
356
+ def convert_to_list_of_dict(df: pd.DataFrame) -> List[Dict[str, str]]:
357
+ '''
358
+ Reads in a pandas DataFrame and produces a list of dictionaries with two keys each, 'question' and 'answer.'
359
+ Args:
360
+ df: A pandas DataFrame with columns named 'questions' and 'answers'.
361
+ Returns:
362
+ A list of dictionaries, with each dictionary containing a 'question' and 'answer' key-value pair.
363
+ '''
364
+
365
+ # Initialize an empty list to store the dictionaries
366
+ result = []
367
+
368
+ # Loop through each row of the DataFrame
369
+ for index, row in df.iterrows():
370
+ # Create a dictionary with the current question and answer
371
+ qa_dict_quest = {"role": "user", "content": row["questions"]}
372
+ qa_dict_ans = {"role": "assistant", "content": row["answers"]}
373
+
374
+ # Add the dictionary to the result list
375
+ result.append(qa_dict_quest)
376
+ result.append(qa_dict_ans)
377
+
378
+ # Return the list of dictionaries
379
+ return result
380
+
381
+ st.sidebar.markdown(f'''This is a chatbot to help you learn more about {organization_name}''')
382
+
383
+ domain = st.sidebar.selectbox(f"Select a topic", {domains})
384
+
385
+ special_threshold = 0.3
386
+
387
+ n_results = 3
388
+
389
+ clear_button = st.sidebar.button("Clear Conversation", key="clear")
390
+
391
+ if clear_button:
392
+ st.session_state.messages = []
393
+ st.session_state.curr_domain = ""
394
+
395
+
396
+
397
+ ###
398
+ ###
399
+ ### Load the dataset from a provided source.
400
+ ###
401
+ ###
402
+
403
+ initial_input = f"Tell me about {organization_name}"
404
+
405
+ # Initialize a new client for ChromeDB.
406
+ client = chromadb.Client()
407
+
408
+ # Generate a random number between 1 billion and 10 billion.
409
+ random_number: int = np.random.randint(low=1e9, high=1e10)
410
+
411
+ # Generate a random string consisting of 10 uppercase letters and digits.
412
+ random_string: str = "".join(
413
+ np.random.choice(list(string.ascii_uppercase + string.digits), size=10)
414
+ )
415
+
416
+ # Combine the random number and random string into one identifier.
417
+ combined_string: str = f"{random_number}{random_string}"
418
+
419
+ # Create a new collection in ChromeDB with the combined string as its name.
420
+ collection = client.create_collection(combined_string)
421
+
422
+ st.title(f"{organization_name} Chatbot")
423
+
424
+ # Initialize chat history
425
+ if "messages" not in st.session_state:
426
+ st.session_state.messages = []
427
+
428
+ if "curr_domain" not in st.session_state:
429
+ st.session_state.curr_domain = ""
430
+
431
+ ###
432
+ ###
433
+ ### init_messages dict (one key per domain)
434
+ ###
435
+ ###
436
+
437
+ ###
438
+ ###
439
+ ### chatbot_instructions dict (one key per domain)
440
+ ###
441
+ ###
442
+
443
+ # Embed and store the first N supports for this demo
444
+ with st.spinner("Loading, please be patient with us ... 🙏"):
445
+ L = len(dataset["train"]["questions"])
446
+
447
+ collection.add(
448
+ ids=[str(i) for i in range(0, L)], # IDs are just strings
449
+ documents=dataset["train"]["questions"], # Enter questions here
450
+ metadatas=[{"type": "support"} for _ in range(0, L)],
451
+ )
452
+
453
+ if st.session_state.curr_domain != domain:
454
+ st.session_state.messages = []
455
+
456
+ init_message = init_messages[domain]
457
+ st.session_state.messages.append({"role": "assistant", "content": init_message})
458
+
459
+ st.session_state.curr_domain = domain
460
+
461
+ # Display chat messages from history on app rerun
462
+ for message in st.session_state.messages:
463
+ with st.chat_message(message["role"]):
464
+ st.markdown(message["content"])
465
+
466
+ # React to user input
467
+ if prompt := st.chat_input(f"Tell me about {organization_name"):
468
+ # Display user message in chat message container
469
+ st.chat_message("user").markdown(prompt)
470
+ # Add user message to chat history
471
+ st.session_state.messages.append({"role": "user", "content": prompt})
472
+
473
+ question = prompt
474
+
475
+ results = collection.query(query_texts=question, n_results=n_results)
476
+
477
+ idx = results["ids"][0]
478
+ idx = [int(i) for i in idx]
479
+ ref = pd.DataFrame(
480
+ {
481
+ "idx": idx,
482
+ "questions": [dataset["train"]["questions"][i] for i in idx],
483
+ "answers": [dataset["train"]["answers"][i] for i in idx],
484
+ "distances": results["distances"][0],
485
+ }
486
+ )
487
+ # special_threshold = st.sidebar.slider('How old are you?', 0, 0.6, 0.1) # 0.3
488
+ # special_threshold = 0.3
489
+ filtered_ref = ref[ref["distances"] < special_threshold]
490
+ if filtered_ref.shape[0] > 0:
491
+ # st.success("There are highly relevant information in our database.")
492
+ ref_from_db_search = filtered_ref["answers"].str.cat(sep=" ")
493
+ final_ref = filtered_ref
494
+ else:
495
+ # st.warning(
496
+ # "The database may not have relevant information to help your question so please be aware of hallucinations."
497
+ # )
498
+ ref_from_db_search = ref["answers"].str.cat(sep=" ")
499
+ final_ref = ref
500
+
501
+ engineered_prompt = f'''
502
+ Based on the context: {ref_from_db_search},
503
+ answer the user question: {question}.
504
+ '''
505
+
506
+ directions = chatbot_instructions[domain]
507
+
508
+ answer = call_chatgpt(engineered_prompt, directions)
509
+
510
+ response = answer
511
+ # Display assistant response in chat message container
512
+ with st.chat_message("assistant"):
513
+ st.markdown(response)
514
+ with st.expander("See reference:"):
515
+ st.table(final_ref)
516
+ # Add assistant response to chat history
517
+ st.session_state.messages.append({"role": "assistant", "content": response})
518
+ """
519
 
520
  st.write("app.py")
521
  st.code(app, language='python')