Rsr2425 committed
Commit 0f046d0 · 1 Parent(s): a2e4e8f

Added simple RAG component to web app

Dockerfile CHANGED
@@ -13,6 +13,8 @@ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
 
 WORKDIR /app
 
+RUN mkdir -p /app/static/data
+
 # Create a non-root user
 RUN useradd -m -u 1000 user
 RUN chown -R user:user /app
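
The new directory is created before the chown -R user:user /app step, so the non-root user owns it at runtime. A minimal sketch of the assumed use, presumably persisting static/crawled content for the RAG component (the helper name and layout are illustrative, not part of this commit):

    from pathlib import Path

    # Matches the directory created in the Dockerfile; writable by `user`
    # because the chown runs after the mkdir.
    DATA_DIR = Path("/app/static/data")

    def save_document(name: str, text: str) -> Path:
        # Hypothetical helper: persist content for later indexing.
        DATA_DIR.mkdir(parents=True, exist_ok=True)  # no-op inside the image
        path = DATA_DIR / name
        path.write_text(text, encoding="utf-8")
        return path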
backend/app/problem_generator.py CHANGED
@@ -1,24 +1,61 @@
 from typing import List
+import json
 
-# from backend.app.vectorstore import get_vector_db
-from langchain.agents import AgentExecutor, create_openai_functions_agent
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from langchain_core.prompts import ChatPromptTemplate
 from langchain_openai import ChatOpenAI
-
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.output_parsers import StrOutputParser
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from backend.app.vectorstore import get_vector_db
 
 class ProblemGenerator:
+    def __init__(self):
+        # Initialize prompts
+        self.system_role_prompt = """
+        You are a helpful assistant that generates questions based on a given context.
+        """
+
+        self.user_role_prompt = """
+        Based on the following context about {query}, generate 5 relevant and specific questions.
+        Make sure the questions can be answered using only the provided context.
+
+        Context: {context}
+
+        Generate 5 questions that test understanding of the material in the context.
+
+        Return only a json object with the following format:
+        {{
+            "questions": ["question1", "question2", "question3", "question4", "question5"]
+        }}
+        """
+
+        # Initialize chain components
+        self.chat_prompt = ChatPromptTemplate.from_messages([
+            ("system", self.system_role_prompt),
+            ("user", self.user_role_prompt)
+        ])
+
+        self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
+        self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
+
+        # Build the RAG chain
+        self.rag_chain = (
+            {"context": self.retriever, "query": RunnablePassthrough()}
+            | self.chat_prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
     def generate_problems(self, query: str) -> List[str]:
         """
-        Generate problems based on the user's query.
+        Generate problems based on the user's query using RAG.
+
+        Args:
+            query (str): The topic to generate questions about
+
+        Returns:
+            List[str]: A list of generated questions
         """
-        # For MVP, returning random sample questions
-        sample_questions = [
-            "What is the main purpose of this framework?",
-            "How do you install this tool?",
-            "What are the key components?",
-            "Explain the basic workflow",
-            "What are the best practices?"
-        ]
-        return sample_questions
+        raw_result = self.rag_chain.invoke(query)
+        result = json.loads(raw_result)
+        return result["questions"]
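
For reference, a hypothetical usage sketch of the reworked class (it assumes OPENAI_API_KEY is set and that get_vector_db() returns an already-populated store; neither is shown in this diff):

    from backend.app.problem_generator import ProblemGenerator

    # Requires OPENAI_API_KEY and a populated vector store (assumptions).
    generator = ProblemGenerator()
    for question in generator.generate_problems("RAG"):
        print(question)

Two caveats worth flagging: gpt-3.5-turbo is not guaranteed to emit valid JSON, so json.loads can raise JSONDecodeError on a malformed completion, and JsonOutputFunctionsParser is imported but never used. Wrapping the parse in a try/except (or switching to a structured output parser) would keep the endpoint from returning a 500 on bad model output.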
backend/tests/test_api.py CHANGED
@@ -14,7 +14,7 @@ def test_crawl_endpoint():
 def test_problems_endpoint():
     response = client.post(
         "/api/problems/",
-        json={"user_query": "test query"}
+        json={"user_query": "RAG"}
     )
     assert response.status_code == 200
     assert "Problems" in response.json()
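
With the RAG chain wired in, this test now calls OpenAI for real, so it needs network access and an API key. A sketch of an offline variant using pytest's monkeypatch (the patch target assumes the route goes through ProblemGenerator.generate_problems; adjust to the actual wiring):

    from backend.app.problem_generator import ProblemGenerator

    def test_problems_endpoint_offline(monkeypatch):
        # Stub the LLM-backed method so the test runs without network access.
        monkeypatch.setattr(
            ProblemGenerator,
            "generate_problems",
            lambda self, query: ["stub question"],
        )
        response = client.post("/api/problems/", json={"user_query": "RAG"})
        assert response.status_code == 200
        assert "Problems" in response.json()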
test_problem_gen_rag.ipynb ADDED
@@ -0,0 +1,153 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[nltk_data] Downloading package punkt_tab to\n",
+      "[nltk_data]     /Users/ryanrodriguez/nltk_data...\n",
+      "[nltk_data]   Package punkt_tab is already up-to-date!\n",
+      "[nltk_data] Downloading package averaged_perceptron_tagger_eng to\n",
+      "[nltk_data]     /Users/ryanrodriguez/nltk_data...\n",
+      "[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-\n",
+      "[nltk_data]       date!\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser\n",
+    "from langchain_core.prompts import ChatPromptTemplate\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "from langchain.chains import create_retrieval_chain\n",
+    "from langchain.chains.combine_documents import create_stuff_documents_chain\n",
+    "from backend.app.vectorstore import get_vector_db"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "system_role_prompt = \"\"\"\n",
+    "    You are a helpful assistant that generates questions based on a given context.\n",
+    "\"\"\"\n",
+    "\n",
+    "user_role_prompt = \"\"\"\n",
+    "    Based on the following context about {query}, generate 5 relevant and specific questions.\n",
+    "    Make sure the questions can be answered using only the provided context.\n",
+    "\n",
+    "    Context: {context}\n",
+    "\n",
+    "    Generate 5 questions that test understanding of the material in the context.\n",
+    "\n",
+    "    Return only a json object with the following format:\n",
+    "    {{\n",
+    "        \"questions\": [\"question1\", \"question2\", \"question3\", \"question4\", \"question5\"]\n",
+    "    }}\n",
+    "\"\"\"\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat_prompt = ChatPromptTemplate.from_messages([\n",
+    "    (\"system\", system_role_prompt),\n",
+    "    (\"user\", user_role_prompt)\n",
+    "])\n",
+    "\n",
+    "openai_chat_model = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0.7)\n",
+    "\n",
+    "retriever = get_vector_db().as_retriever(search_kwargs={\"k\": 2})\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_core.runnables import RunnablePassthrough\n",
+    "from langchain_core.output_parsers import StrOutputParser\n",
+    "\n",
+    "simple_rag = (\n",
+    "    {\"context\": retriever, \"query\": RunnablePassthrough(), \"num_questions\": RunnablePassthrough()}\n",
+    "    | chat_prompt\n",
+    "    | openai_chat_model\n",
+    "    | StrOutputParser()\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "raw_result = simple_rag.invoke(\"RAG\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['What are the two main components of a typical RAG application?',\n",
+       " 'What is the purpose of the indexing component in a RAG application?',\n",
+       " \"What are the steps involved in the 'Load' phase of indexing?\",\n",
+       " 'Why is splitting text into smaller chunks important in the context of RAG applications?',\n",
+       " 'How does the retrieval and generation component of a RAG application process user queries?']"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import json\n",
+    "result = json.loads(raw_result)\n",
+    "result[\"questions\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
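
Two notebook-only quirks: create_retrieval_chain and create_stuff_documents_chain are imported but unused, and the num_questions passthrough key is never referenced by the prompt (the production chain in problem_generator.py drops it). The dict literal at the head of the chain is LangChain's shorthand for RunnableParallel, which fans the same input out to every key; a self-contained sketch with a stand-in retriever (an assumption for illustration):

    from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

    # Stand-in for get_vector_db().as_retriever(search_kwargs={"k": 2}).
    fake_retriever = RunnableLambda(lambda q: [f"doc about {q} #1", f"doc about {q} #2"])

    # Same shape as the notebook's first chain step: one input string fans
    # out to every key of the mapping.
    fan_out = RunnableParallel(context=fake_retriever, query=RunnablePassthrough())

    print(fan_out.invoke("RAG"))
    # -> {'context': ['doc about RAG #1', 'doc about RAG #2'], 'query': 'RAG'}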