import os

import docx2txt
from dotenv import load_dotenv

from langchain.chat_models import ChatOpenAI
from langchain.schema import (
    SystemMessage,
    HumanMessage,
    AIMessage
)
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.callbacks.base import BaseCallbackHandler

import streamlit as st
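
# Assumed dependency set, inferred from the imports above (versions are not
# pinned anywhere in this file):
#     pip install streamlit langchain openai python-dotenv docx2txt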

load_dotenv()
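
# load_dotenv() pulls variables from a local .env file. A minimal .env for this
# app would hold the key that ChatOpenAI and OpenAIEmbeddings read by default:
#
#     OPENAI_API_KEY=sk-...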


class StreamHandler(BaseCallbackHandler):
    """Streams LLM tokens into a Streamlit container as they are generated."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Append each new token and re-render the accumulated text
        self.text += token
        self.container.markdown(self.text)

def init_gpt(gpt_model, stream_handler):
    """Initialise the chat model, streaming its output through the handler."""
    global llm
    llm = ChatOpenAI(
        temperature=0.3,
        model=gpt_model,
        streaming=True,
        callbacks=[stream_handler]
    )


# Embeddings client; currently unused, kept for the commented-out Chroma
# retrieval path referenced in generate_content() and main()
embeddings = OpenAIEmbeddings()


def generate_content(query, knowledge_base):
    """Rebuild the chat history and ask the LLM to draft the letter."""
    # relevant_docs = db_chroma.similarity_search(query)
    system_prompt = f"""You are a professional writer of motivational letters. \
You will be given content from a knowledge base below, delimited by triple \
backticks. Your job is to use knowledge from this data and write a \
motivational letter for a graduate school application. Only write content \
using data from the knowledge base; do not claim facts from outside of it. \
Make the letter very personal with regard to the knowledge base.

Knowledge Base: ```{knowledge_base}```
"""
    # system_message = SystemMessage(content=system_prompt)
    # human_message = HumanMessage(content=query[-1]['content'])
    # message = [system_message, human_message]
    messages = [SystemMessage(content=system_prompt)]
    # The chat history alternates turns: even indices are the user's messages,
    # odd indices are the assistant's replies
    for i, turn in enumerate(query):
        if i % 2 == 0:
            messages.append(HumanMessage(content=turn['content']))
        else:
            messages.append(AIMessage(content=turn['content']))
    response = llm(messages)
    return response.content
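
# Sketch of the commented-out retrieval path above (an assumption, not wired in):
#     from langchain.vectorstores import Chroma
#     db_chroma = Chroma.from_texts([knowledge_base], embeddings)
#     relevant_docs = db_chroma.similarity_search(query[-1]['content'])
# The retrieved chunks would then stand in for the raw knowledge_base text in
# the system prompt.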


def main():
    st.title("GradGPT 🤖")
    st.header("ChatGPT-Powered Motivational Letter Writer")

    uploaded_file = st.file_uploader("Upload a Word file", type="docx")
    knowledge_base = ""
    if uploaded_file is not None:
        # Extract the raw text from the uploaded Word document
        knowledge_base = docx2txt.process(uploaded_file)
        # load_into_chroma(call_transcript)

    if "messages" not in st.session_state:
        st.session_state.messages = []

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Enter your queries here."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            # message_placeholder = st.empty()
            stream_handler = StreamHandler(st.empty())
            init_gpt("gpt-3.5-turbo-16k", stream_handler)
            content = generate_content(
                st.session_state.messages, knowledge_base
            )
        st.session_state.messages.append(
            {"role": "assistant", "content": content}
        )
        # message_placeholder.markdown(content)

    with st.sidebar:
        # Remove the last user/assistant exchange (two messages)
        if st.button("Remove previous message"):
            if len(st.session_state.messages) >= 2:
                st.session_state.messages = st.session_state.messages[:-2]


if __name__ == '__main__':
    main()
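
# To run locally (assuming this file is saved as app.py):
#     streamlit run app.py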