cbs-tech-strategy — johnnyclee committed
Commit a62624f · 0 parents

Duplicate from johnnyclee/chatgpt_clone

Co-authored-by: Johnny Lee <[email protected]>

Files changed (8):
  1. .gitattributes +34 -0
  2. .gitignore +162 -0
  3. .pre-commit-config.yaml +34 -0
  4. .python-version +1 -0
  5. README.md +11 -0
  6. app.py +598 -0
  7. requirements.txt +7 -0
  8. templates.json +62 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
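As a quick sanity check (not part of the commit), the simpler glob patterns above can be matched against candidate filenames with Python's `fnmatch`. This is only an approximation: Git's actual gitattributes matching has extra path rules (e.g. for `saved_model/**/*`), and the pattern subset below is an assumption for illustration.

```python
from fnmatch import fnmatch

# Subset of the LFS patterns from the .gitattributes above
lfs_patterns = ["*.7z", "*.bin", "*.safetensors", "*.zip", "*tfevents*"]


def is_lfs_tracked(filename: str) -> bool:
    # Approximate check: fnmatch covers the simple basename globs,
    # not Git's full attribute-matching semantics.
    return any(fnmatch(filename, pat) for pat in lfs_patterns)


print(is_lfs_tracked("model.safetensors"))  # True
print(is_lfs_tracked("app.py"))             # False
```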
.gitignore ADDED
@@ -0,0 +1,162 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ chats/*
.pre-commit-config.yaml ADDED
@@ -0,0 +1,34 @@
+ default_language_version:
+   node: system
+   python: python3.8
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.4.0
+     hooks:
+       - id: debug-statements
+       - id: end-of-file-fixer
+         exclude: "^.*.crt|^.*.json|^.*.svg|^.*.txt"
+       - id: requirements-txt-fixer
+       - id: trailing-whitespace
+       # - id: pretty-format-json
+       #   args: [--autofix]
+
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     # Ruff version.
+     rev: v0.0.282
+     hooks:
+       - id: ruff
+         args: [--fix, --exit-non-zero-on-fix]
+
+   - repo: https://github.com/psf/black
+     rev: 23.7.0
+     # TODO: ensure that the black version is aligned, with another hook?
+     hooks:
+       - id: black
+         language: python
+         types: [python]
+
+   - repo: meta
+     hooks:
+       - id: check-useless-excludes
+       - id: check-hooks-apply
.python-version ADDED
@@ -0,0 +1 @@
+ chatgpt-clone
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Chatgpt Clone
+ emoji: 🏆
+ colorFrom: gray
+ colorTo: yellow
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ license: cc
+ duplicated_from: johnnyclee/chatgpt_clone
+ ---
app.py ADDED
@@ -0,0 +1,598 @@
+ # ruff: noqa: E501
+ from __future__ import annotations
+ import asyncio
+ import datetime
+ import pytz
+ import logging
+ import os
+ from enum import Enum
+ import json
+ import uuid
+ from pydantic import BaseModel
+ import gspread
+
+ from copy import deepcopy
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import gradio as gr
+ import tiktoken
+
+ # from dotenv import load_dotenv
+
+ # load_dotenv()
+
+ from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
+ from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
+ from langchain.chains import ConversationChain
+ from langsmith import Client
+ from langchain.chat_models import ChatAnthropic, ChatOpenAI
+ from langchain.memory import ConversationTokenBufferMemory
+ from langchain.prompts.chat import (
+     ChatPromptTemplate,
+     HumanMessagePromptTemplate,
+     MessagesPlaceholder,
+     SystemMessagePromptTemplate,
+ )
+ from langchain.schema import BaseMessage
+
+
+ logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
+ LOG = logging.getLogger(__name__)
+ LOG.setLevel(logging.INFO)
+
+ GPT_3_5_CONTEXT_LENGTH = 4096
+ CLAUDE_2_CONTEXT_LENGTH = 100000  # need to use claude tokenizer
+
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ GS_CREDS = json.loads(rf"""{os.getenv("GSPREAD_SERVICE")}""")
+ GSHEET_ID = os.getenv("GSHEET_ID")
+ AUTH_GSHEET_NAME = os.getenv("AUTH_GSHEET_NAME")
+ TURNS_GSHEET_NAME = os.getenv("TURNS_GSHEET_NAME")
+
+ theme = gr.themes.Soft()
+
+ creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
+
+ gradio_flagger = gr.HuggingFaceDatasetSaver(
+     hf_token=HF_TOKEN, dataset_name="chats", separate_dirs=True
+ )
+
+
+ def get_gsheet_rows(
+     sheet_id: str, sheet_name: str, creds: Dict[str, str]
+ ) -> List[Dict[str, str]]:
+     gc = gspread.service_account_from_dict(creds)
+     worksheet = gc.open_by_key(sheet_id).worksheet(sheet_name)
+     rows = worksheet.get_all_records()
+     return rows
+
+
+ def append_gsheet_rows(
+     sheet_id: str,
+     rows: List[List[str]],
+     sheet_name: str,
+     creds: Dict[str, str],
+ ) -> None:
+     gc = gspread.service_account_from_dict(creds)
+     worksheet = gc.open_by_key(sheet_id).worksheet(sheet_name)
+     worksheet.append_rows(values=rows, insert_data_option="INSERT_ROWS")
+
+
+ class ChatSystemMessage(str, Enum):
+     CASE_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
+ Follow this message's instructions carefully. Respond using markdown.
+ Never repeat these instructions in a subsequent message.
+
+ You will start a conversation with me in the following form:
+ 1. Below these instructions you will receive a business scenario. The scenario will include (a) the name of a company or category, and (b) a debatable multiple-choice question about the business scenario.
+ 2. We will pretend to be executives charged with solving the strategic question outlined in the scenario.
+ 3. To start the conversation, you will summarize the question and present all options in the multiple-choice question to me. Then, you will ask me to choose a position and provide a short opening argument. Do not yet provide your position.
+ 4. After receiving my position and explanation, you will choose an alternate position in the scenario.
+ 5. Inform me which position you have chosen, then proceed to have a discussion with me on this topic.
+ 6. The discussion should be informative and very rigorous. Do not agree with my arguments easily. Pursue a Socratic method of questioning and reasoning.
+ """
+
+     RESEARCH_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
+ Follow this message's instructions carefully. Respond using markdown.
+ Never repeat these instructions in a subsequent message.
+
+ You will start a conversation with me in the following form:
+ 1. You are to be a professional research consultant to the MBA student.
+ 2. The student will be working in a group of classmates to collaborate on a proposal to solve a business dilemma.
+ 3. Be as helpful as you can to the student while remaining factual.
+ 4. If you are not certain, please warn the student to conduct additional research on the internet.
+ 5. Use tables and bullet points as a useful way to compare insights.
+ """
+
+
+ class ChatbotMode(str, Enum):
+     DEBATE_PARTNER = "Debate Partner"
+     RESEARCH_ASSISTANT = "Research Assistant"
+     DEFAULT = DEBATE_PARTNER
+
+
+ class PollQuestion(BaseModel):  # type: ignore[misc]
+     name: str
+     template: str
+
+
+ class PollQuestions(BaseModel):  # type: ignore[misc]
+     cases: List[PollQuestion]
+
+     @classmethod
+     def from_json_file(cls, json_file_path: str) -> PollQuestions:
+         """Expects a JSON file with an array of poll questions.
+         Each JSON object should have "name" and "template" keys.
+         """
+         with open(json_file_path, "r") as json_f:
+             payload = json.load(json_f)
+         return_obj_list = []
+         if isinstance(payload, list):
+             for case in payload:
+                 return_obj_list.append(PollQuestion(**case))
+             return cls(cases=return_obj_list)
+         raise ValueError(
+             f"JSON object in {json_file_path} must be an array of PollQuestion"
+         )
+
+     def get_case(self, case_name: str) -> Optional[PollQuestion]:
+         """Searches cases to return the template for a poll question."""
+         for case in self.cases:
+             if case.name == case_name:
+                 return case
+         return None
+
+     def get_case_names(self) -> List[str]:
+         """Returns the names in cases."""
+         return [case.name for case in self.cases]
+
+
+ poll_questions = PollQuestions.from_json_file("templates.json")
+
+
+ def reset_textbox():
+     return gr.update(value=""), gr.update(value=""), gr.update(value="")
+
+
+ def auth(username, password):
+     try:
+         auth_records = get_gsheet_rows(
+             sheet_id=GSHEET_ID, sheet_name=AUTH_GSHEET_NAME, creds=GS_CREDS
+         )
+         auth_dict = {user["username"]: user["password"] for user in auth_records}
+         search_auth_user = auth_dict.get(username)
+         if search_auth_user:
+             authenticated = search_auth_user == password
+             if authenticated:
+                 LOG.info(f"{username} successfully logged in.")
+             return authenticated
+         else:
+             LOG.info(f"{username} failed to login.")
+             return False
+
+     except Exception as exc:
+         LOG.info(f"{username} failed to login")
+         LOG.error(exc)
+         return (username, password) in creds
+
+
+ class ChatSession(BaseModel):
+     class Config:
+         arbitrary_types_allowed = True
+
+     context_length: int
+     tokenizer: tiktoken.Encoding
+     chain: ConversationChain
+     history: List[BaseMessage] = []
+     session_id: str = str(uuid.uuid4())
+
+     @staticmethod
+     def set_metadata(
+         username: str,
+         chatbot_mode: str,
+         turns_completed: int,
+         case: Optional[str] = None,
+     ) -> Dict[str, Union[str, int]]:
+         metadata = dict(
+             username=username,
+             chatbot_mode=chatbot_mode,
+             turns_completed=turns_completed,
+             case=case,
+         )
+         return metadata
+
+     @staticmethod
+     def _make_template(
+         system_msg: str, poll_question_name: Optional[str] = None
+     ) -> ChatPromptTemplate:
+         knowledge_cutoff = "Sept 2021"
+         current_date = datetime.datetime.now(
+             pytz.timezone("America/New_York")
+         ).strftime("%Y-%m-%d")
+         if poll_question_name:
+             poll_question = poll_questions.get_case(poll_question_name)
+             if poll_question:
+                 message_template = poll_question.template
+                 system_msg += f"""
+ {message_template}
+
+ Knowledge cutoff: {knowledge_cutoff}
+ Current date: {current_date}
+ """
+         else:
+             knowledge_cutoff = "Early 2023"
+             system_msg += f"""
+
+ Knowledge cutoff: {knowledge_cutoff}
+ Current date: {current_date}
+ """
+
+         human_template = "{input}"
+         return ChatPromptTemplate.from_messages(
+             [
+                 SystemMessagePromptTemplate.from_template(system_msg),
+                 MessagesPlaceholder(variable_name="history"),
+                 HumanMessagePromptTemplate.from_template(human_template),
+             ]
+         )
+
+     @staticmethod
+     def _set_llm(
+         use_claude: bool,
+     ) -> Tuple[Union[ChatOpenAI, ChatAnthropic], int, tiktoken.Encoding]:
+         if use_claude:
+             llm = ChatAnthropic(
+                 model="claude-2",
+                 anthropic_api_key=ANTHROPIC_API_KEY,
+                 temperature=1,
+                 max_tokens_to_sample=5000,
+                 streaming=True,
+             )
+             context_length = CLAUDE_2_CONTEXT_LENGTH
+             tokenizer = tiktoken.get_encoding("cl100k_base")
+             return llm, context_length, tokenizer
+         else:
+             llm = ChatOpenAI(
+                 model_name="gpt-4",
+                 temperature=1,
+                 openai_api_key=OPENAI_API_KEY,
+                 max_retries=6,
+                 request_timeout=100,
+                 streaming=True,
+             )
+             context_length = GPT_3_5_CONTEXT_LENGTH
+             _, tokenizer = llm._get_encoding_model()
+             return llm, context_length, tokenizer
+
+     def update_system_prompt(
+         self, system_msg: str, poll_question_name: Optional[str] = None
+     ) -> None:
+         self.chain.prompt = self._make_template(system_msg, poll_question_name)
+
+     def change_llm(self, use_claude: bool) -> None:
+         llm, self.context_length, self.tokenizer = self._set_llm(use_claude)
+         self.chain.llm = llm
+
+     def clear_memory(self) -> None:
+         self.chain.memory.clear()
+         self.history = []
+
+     def set_chatbot_mode(
+         self, case_mode: bool, poll_question_name: Optional[str] = None
+     ) -> None:
+         if case_mode and poll_question_name:
+             self.change_llm(use_claude=False)
+             self.update_system_prompt(
+                 system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
+                 poll_question_name=poll_question_name,
+             )
+         else:
+             self.change_llm(use_claude=True)
+             self.update_system_prompt(
+                 system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE
+             )
+
+     @classmethod
+     def new(
+         cls,
+         use_claude: bool,
+         system_msg: str,
+         metadata: Dict[str, Any],
+         poll_question_name: Optional[str] = None,
+     ) -> ChatSession:
+         llm, context_length, tokenizer = cls._set_llm(use_claude)
+         memory = ConversationTokenBufferMemory(
+             llm=llm, max_token_limit=context_length, return_messages=True
+         )
+         template = cls._make_template(
+             system_msg=system_msg, poll_question_name=poll_question_name
+         )
+         chain = ConversationChain(
+             memory=memory,
+             prompt=template,
+             llm=llm,
+             metadata=metadata,
+         )
+         return cls(
+             context_length=context_length,
+             tokenizer=tokenizer,
+             chain=chain,
+         )
+
+
+ async def respond(
+     chat_input: str,
+     chatbot_mode: str,
+     case_input: str,
+     state: ChatSession,
+     request: gr.Request,
+ ) -> Tuple[List[str], ChatSession, str]:
+     """Execute the chat functionality."""
+
+     def prep_messages(
+         user_msg: str, memory_buffer: List[BaseMessage]
+     ) -> Tuple[str, List[BaseMessage]]:
+         messages_to_send = state.chain.prompt.format_messages(
+             input=user_msg, history=memory_buffer
+         )
+         user_msg_token_count = state.chain.llm.get_num_tokens_from_messages(
+             [messages_to_send[-1]]
+         )
+         total_token_count = state.chain.llm.get_num_tokens_from_messages(
+             messages_to_send
+         )
+         while user_msg_token_count > state.context_length:
+             LOG.warning(
+                 f"Pruning user message due to user message token length of {user_msg_token_count}"
+             )
+             user_msg = state.tokenizer.decode(
+                 state.chain.llm.get_token_ids(user_msg)[: state.context_length - 100]
+             )
+             messages_to_send = state.chain.prompt.format_messages(
+                 input=user_msg, history=memory_buffer
+             )
+             user_msg_token_count = state.chain.llm.get_num_tokens_from_messages(
+                 [messages_to_send[-1]]
+             )
+             total_token_count = state.chain.llm.get_num_tokens_from_messages(
+                 messages_to_send
+             )
+         while total_token_count > state.context_length:
+             LOG.warning(
+                 f"Pruning memory due to total token length of {total_token_count}"
+             )
+             if len(memory_buffer) == 1:
+                 memory_buffer.pop(0)
+                 continue
+             memory_buffer = memory_buffer[1:]
+             messages_to_send = state.chain.prompt.format_messages(
+                 input=user_msg, history=memory_buffer
+             )
+             total_token_count = state.chain.llm.get_num_tokens_from_messages(
+                 messages_to_send
+             )
+         return user_msg, memory_buffer
+
+     try:
+         if state is None:
+             if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
+                 new_session = ChatSession.new(
+                     use_claude=False,
+                     system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
+                     metadata=ChatSession.set_metadata(
+                         username=request.username,
+                         chatbot_mode=chatbot_mode,
+                         turns_completed=0,
+                         case=case_input,
+                     ),
+                     poll_question_name=case_input,
+                 )
+             else:
+                 new_session = ChatSession.new(
+                     use_claude=True,
+                     system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE,
+                     metadata=ChatSession.set_metadata(
+                         username=request.username,
+                         chatbot_mode=chatbot_mode,
+                         turns_completed=0,
+                     ),
+                     poll_question_name=None,
+                 )
+             state = new_session
+         state.chain.metadata = ChatSession.set_metadata(
+             username=request.username,
+             chatbot_mode=chatbot_mode,
+             turns_completed=len(state.history) + 1,
+             case=case_input,
+         )
+         LOG.info(f"""[{request.username}] STARTING CHAIN""")
+         LOG.debug(f"History: {state.history}")
+         LOG.debug(f"User input: {chat_input}")
+         chat_input, state.chain.memory.chat_memory.messages = prep_messages(
+             chat_input, state.chain.memory.buffer
+         )
+         messages_to_send = state.chain.prompt.format_messages(
+             input=chat_input, history=state.chain.memory.buffer
+         )
+         total_token_count = state.chain.llm.get_num_tokens_from_messages(
+             messages_to_send
+         )
+         LOG.debug(f"Messages to send: {messages_to_send}")
+         LOG.debug(f"Tokens to send: {total_token_count}")
+         callback = AsyncIteratorCallbackHandler()
+         run_collector = RunCollectorCallbackHandler()
+         run = asyncio.create_task(
+             state.chain.apredict(
+                 input=chat_input,
+                 callbacks=[callback, run_collector],
+             )
+         )
+         state.history.append((chat_input, ""))
+         run_id = None
+         langsmith_url = None
+         async for tok in callback.aiter():
+             user, bot = state.history[-1]
+             bot += tok
+             state.history[-1] = (user, bot)
+             yield state.history, state, None
+         await run
+         if run_collector.traced_runs and run_id is None:
+             run_id = run_collector.traced_runs[0].id
+             LOG.info(f"RUNID: {run_id}")
+         if run_id:
+             run_collector.traced_runs = []
+             try:
+                 langsmith_url = Client().share_run(run_id)
+                 LOG.info(f"""Run ID: {run_id} \n URL : {langsmith_url}""")
+                 url_markdown = (
+                     f"""[Click to view shareable chat]({langsmith_url})"""
+                 )
+             except Exception as exc:
+                 LOG.error(exc)
+                 url_markdown = "Share link not currently available"
+             if (
+                 len(state.history) > 9
+                 and chatbot_mode == ChatbotMode.DEBATE_PARTNER
+             ):
+                 url_markdown += """\n
+ 🙌 You have completed 10 exchanges with the chatbot."""
+             yield state.history, state, url_markdown
+         LOG.info(f"""[{request.username}] ENDING CHAIN""")
+         LOG.debug(f"History: {state.history}")
+         LOG.debug(f"Memory: {state.chain.memory.json()}")
+         current_timestamp = datetime.datetime.now(pytz.timezone("US/Eastern")).replace(
+             tzinfo=None
+         )
+         timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")
+         data_to_flag = (
+             {
+                 "history": deepcopy(state.history),
+                 "username": request.username,
+                 "timestamp": timestamp_string,
+                 "session_id": state.session_id,
+                 "metadata": state.chain.metadata,
+                 "langsmith_url": langsmith_url,
+             },
+         )
+         gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
+         (flagged_data,) = data_to_flag
+         metadata_to_gsheet = flagged_data.get("metadata").values()
+         gsheet_row = [[timestamp_string, *metadata_to_gsheet, langsmith_url]]
+         LOG.info(f"Data to GSHEET: {gsheet_row}")
+         append_gsheet_rows(
+             sheet_id=GSHEET_ID,
+             sheet_name=TURNS_GSHEET_NAME,
+             rows=gsheet_row,
+             creds=GS_CREDS,
+         )
+     except Exception as e:
+         LOG.error(e)
+         raise e
+
+
+ class ChatbotConfig(BaseModel):
+     app_title: str = "CBS Technology Strategy - Fall 2023"
+     chatbot_modes: List[ChatbotMode] = [mode.value for mode in ChatbotMode]
+     case_options: List[str] = poll_questions.get_case_names()
+     default_case_option: str = "Netflix"
+
+
+ def change_chatbot_mode(
+     state: ChatSession, chatbot_mode: str, poll_question_name: str, request: gr.Request
+ ) -> Tuple[Any, ChatSession]:
+     """Sets the visibility of the case input field and updates the session state."""
+     if state is None:
+         if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
+             new_session = ChatSession.new(
+                 use_claude=False,
+                 system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
+                 metadata=ChatSession.set_metadata(
+                     username=request.username,
+                     chatbot_mode=chatbot_mode,
+                     turns_completed=0,
+                     case=poll_question_name,
+                 ),
+                 poll_question_name=poll_question_name,
+             )
+         else:
+             new_session = ChatSession.new(
+                 use_claude=True,
+                 system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE,
+                 metadata=ChatSession.set_metadata(
+                     username=request.username,
+                     chatbot_mode=chatbot_mode,
+                     turns_completed=0,
+                 ),
+                 poll_question_name=None,
+             )
+         state = new_session
+     if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
+         state.set_chatbot_mode(case_mode=True, poll_question_name=poll_question_name)
+         state.clear_memory()
+         return gr.update(visible=True), state
+     elif chatbot_mode == ChatbotMode.RESEARCH_ASSISTANT:
+         state.set_chatbot_mode(case_mode=False)
+         state.clear_memory()
+         return gr.update(visible=False), state
+     else:
+         raise ValueError("chatbot_mode is not correctly set")
+
+
+ config = ChatbotConfig()
+ with gr.Blocks(
+     theme=theme,
+     analytics_enabled=False,
+     title=config.app_title,
+ ) as demo:
+     state = gr.State()
+     gr.Markdown(f"""### {config.app_title}""")
+     with gr.Tab("Chatbot"):
+         with gr.Row():
+             chatbot_mode = gr.Radio(
+                 label="Mode",
+                 choices=config.chatbot_modes,
+                 value=ChatbotMode.DEFAULT,
+             )
+             case_input = gr.Dropdown(
+                 label="Case",
+                 choices=config.case_options,
+                 value=config.default_case_option,
+                 multiselect=False,
+             )
+         chatbot = gr.Chatbot(label="ChatBot", show_share_button=False)
+         with gr.Row():
+             input_message = gr.Textbox(
+                 placeholder="Send a message.",
+                 label="Type a message to begin",
+                 scale=5,
+             )
+             chat_submit_button = gr.Button(value="Submit")
+         status_message = gr.Markdown()
+     gradio_flagger.setup([chatbot], "chats")
+
+     chatbot_submit_params = dict(
+         fn=respond,
+         inputs=[input_message, chatbot_mode, case_input, state],
+         outputs=[chatbot, state, status_message],
+     )
+     input_message.submit(**chatbot_submit_params)
+     chat_submit_button.click(**chatbot_submit_params)
+     chatbot_mode_params = dict(
+         fn=change_chatbot_mode,
+         inputs=[state, chatbot_mode, case_input],
+         outputs=[case_input, state],
+     )
+     chatbot_mode.change(**chatbot_mode_params)
+     case_input.change(**chatbot_mode_params)
+     clear_chatbot_messages_params = dict(
+         fn=reset_textbox, inputs=[], outputs=[input_message, chatbot, status_message]
+     )
+     chatbot_mode.change(**clear_chatbot_messages_params)
+     case_input.change(**clear_chatbot_messages_params)
+     chat_submit_button.click(**clear_chatbot_messages_params)
+     input_message.submit(**clear_chatbot_messages_params)
+
+ demo.queue(max_size=99, concurrency_count=99, api_open=False).launch(
+     debug=True, auth=auth
+ )
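The `prep_messages` helper inside `respond` trims both the user message and the memory buffer until everything fits in the model's context window. A minimal standalone sketch of the same idea, using a toy whitespace "tokenizer" in place of tiktoken and plain strings in place of LangChain messages (both assumptions for illustration):

```python
from typing import List, Tuple


def count_tokens(text: str) -> int:
    # Toy stand-in for a real tokenizer such as tiktoken
    return len(text.split())


def prune_to_context(
    user_msg: str, history: List[str], context_length: int
) -> Tuple[str, List[str]]:
    # First, hard-truncate the user message if it alone exceeds the window
    if count_tokens(user_msg) > context_length:
        user_msg = " ".join(user_msg.split()[:context_length])
    # Then drop the oldest history entries until the total fits
    total = count_tokens(user_msg) + sum(count_tokens(m) for m in history)
    while history and total > context_length:
        dropped = history.pop(0)
        total -= count_tokens(dropped)
    return user_msg, history


msg, hist = prune_to_context("a b c", ["one two", "three four five"], 6)
print(hist)  # the oldest entry is dropped: ['three four five']
```

The real implementation re-counts tokens on fully formatted prompt messages (system prompt included), which is why it reformats and re-measures inside the loop rather than doing simple arithmetic as here.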
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ anthropic==0.3.7
+ gradio==3.39.0
+ gspread==5.10.0
+ langchain==0.0.265
+ openai==0.27.8
+ pytz==2023.3
+ tiktoken==0.4.0
templates.json ADDED
@@ -0,0 +1,62 @@
+ [
+   {
+     "name": "Rivian",
+     "template": "Through which of the following options should Rivian primarily secure supply for its EV batteries?\n1. Buying from multiple battery suppliers\n2. Partnering with a battery supplier (Joint Venture, for example)\n3. Developing in-house manufacturing capabilities"
+   },
+   {
+     "name": "LEGO",
+     "template": "How should LEGO protect its plastic molding manufacturing platform?\n1. Patent\n2. Trade Secret\n3. Publish"
+   },
+   {
+     "name": "Netflix",
+     "template": "Going forward, what should Netflix prioritize in its content strategy?\n1. Invest more in original content than third-party licensing\n2. Balance between original content and third-party licensing\n3. Invest more in third-party licensing than original content"
+   },
+   {
+     "name": "LinkedIn",
+     "template": "Is it wise for LinkedIn to invest more in developing its LinkedIn Learning platform?\n1. Yes, definitely\n2. Yes, but it should not be core to LinkedIn\n3. No, but I wouldn't kill it\n4. No, I would either stop investing or outsource/partner"
+   },
+   {
+     "name": "Uber",
+     "template": "Are there reasonable platform synergies between Uber's ridesharing business (e.g., UberX and UberX Share) and food/grocery delivery business (e.g., Uber Eats)?\n1. Yes, definitely\n2. Yes, but mainly because the success of one side can subsidize the other\n3. Somewhat, because there is a limit\n4. Probably not"
+   },
+   {
+     "name": "DigitalOcean",
+     "template": "Should Digital Ocean dedicate resources towards 1) gaining more market share among small businesses, or 2) getting existing customers to spend more?\n1. Market share\n2. Higher spend"
+   },
+   {
+     "name": "Alphabet",
+     "template": "Given competition from Amazon, TikTok, and new technologies like Generative AI, how important will search be as a core part of Google's platform business going forward?\n1. Still very important\n2. Gradually become less important\n3. Quickly become less important"
+   },
+   {
+     "name": "Twitch",
+     "template": "Currently, Twitch uses a mix of admins (Twitch employees and contractors) and volunteer moderators who are also members of specific Twitch channel communities to address content moderation. Which side should get priority in evaluating instances of content policy violations if their recommendations are in conflict?\n1. Twitch admins\n2. Channel volunteer moderators\n3. It depends\n4. I have a different idea"
+   },
+   {
+     "name": "Kakao",
+     "template": "How do you think Kakao should approach the expansion of its platform?\n1. Prioritize expanding into other culturally similar Asian markets\n2. Prioritize expanding into new verticals within Korea (such as enterprise software)\n3. Prioritize testing markets beyond Asia\n4. I have another idea"
+   },
+   {
+     "name": "StitchFix",
+     "template": "Is Generative AI (i.e., AI tools like ChatGPT that can develop novel insights and ideas rather than simply engaging in pattern recognition) more of a threat, a complement, or neither with respect to Stitch Fix's core business operations?\n1. More of a Threat\n2. More of a Complement\n3. Neither"
+   },
+   {
+     "name": "Hubspot",
+     "template": "In which phases of the sales funnel would chatbots do better than human representatives?\n1. Top of Funnel\n2. Middle of the Funnel\n3. Bottom of the Funnel\nIn implementing a chatbot, should it disclose to a customer that they are chatting with a bot?\n1. Yes, almost always\n2. It depends on the circumstances\n3. No, almost never"
+   },
+   {
+     "name": "Mastercard",
+     "template": "To effectively build an 'AI Powerhouse', would you recommend that MasterCard create a central AI team (Centralized AI) or staff AI experts within each business unit (Decentralized AI)?\n1. Centralized AI\n2. Decentralized AI\n3. I have a different idea"
+   },
+   {
+     "name": "Autonomous Vehicles",
+     "template": "Suppose you are a general partner of a VC/PE firm that has a fund set aside for futurist and frontier technologies. Within the autonomous vehicles space, which of the following rivals discussed in the case would you invest in?\n1. Tesla\n2. GM and Cruise\n3. Waymo\n4. Motional\nWhen it comes to programming decision rules for ethical dilemmas that Autonomous Vehicles will encounter (like the Trolley Problem described in the case), which stakeholder's opinion should matter the most?\n1. Social Scientists (e.g., professors of philosophy, sociology, etc.)\n2. Technology Specialists (e.g., engineers)\n3. Policy-makers and regulators (e.g., government agencies)\n4. The general public (e.g., by popular vote)\n5. The owner (e.g., shareholders) and managers of AV car companies"
+   },
+   {
+     "name": "Apple Privacy",
+     "template": "Apple claims that its vertical integration and emphasis on privacy raises WTP for consumers which makes up for the potential revenues it could earn by making its data on users available to partners, such as app-makers. Do you think Apple should maintain its strict data privacy policy and strategy?\n1. Yes, unconditionally.\n2. It depends on the partner and case.\n3. No, Apple should be more open to sharing user data."
+   },
+   {
+     "name": "Meta",
+     "template": "Does Meta have a responsibility to inform their users when conducting experiments or A/B tests on Facebook, Instagram, or any of its other platforms in a more explicit way? (i.e., other than offering a blanket statement in the user agreement during registration)\n1. Yes.\n2. No.\n3. It depends."
+   }
+ ]
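`PollQuestions.from_json_file` in app.py expects this file to be a JSON array of objects, each with `name` and `template` keys. A stdlib-only validation sketch of that contract (`load_poll_questions` is a hypothetical helper, not part of the repo, and it takes a raw JSON string rather than a file path to stay self-contained):

```python
import json
from typing import Dict, List


def load_poll_questions(raw_json: str) -> List[Dict[str, str]]:
    # Mirrors the shape PollQuestions.from_json_file expects:
    # a top-level array of {"name": ..., "template": ...} objects.
    payload = json.loads(raw_json)
    if not isinstance(payload, list):
        raise ValueError("templates.json must be a JSON array")
    for case in payload:
        if not {"name", "template"} <= set(case):
            raise ValueError(f"missing keys in case: {case!r}")
    return payload


sample = '[{"name": "Netflix", "template": "What should Netflix prioritize?"}]'
cases = load_poll_questions(sample)
print([c["name"] for c in cases])  # ['Netflix']
```

In the app itself the same validation is done through pydantic: each array element is passed to `PollQuestion(**case)`, which raises if a key is missing.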