sid_racha committed
Commit · 24fdbf8
Parent(s): 5940614
modified dev
- .gitignore +5 -0
- Dockerfile +19 -0
- app.py +5 -0
- app/callbacks.py +24 -0
- app/chains.py +53 -0
- app/crud.py +23 -0
- app/data_indexing.py +150 -0
- app/database.py +12 -0
- app/main.py +87 -0
- app/models.py +28 -0
- app/prompts.py +51 -0
- app/schemas.py +19 -0
- requirements.txt +3 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+__pycache__/
+*/__pycache__/
+**/__pycache__/
+Test-LLM-Endpoint/
+app/set_env_vars.sh
Dockerfile
ADDED
@@ -0,0 +1,19 @@
+FROM python:3.12
+# Create a new user named 'user' with user ID 1000 and create their home directory
+RUN useradd -m -u 1000 user
+# Switch to the newly created user
+USER user
+# Add the user's local bin directory to the PATH
+ENV PATH="/home/user/.local/bin:$PATH"
+# Set the working directory in the container to /app
+WORKDIR /app
+# Copy the requirements.txt file from the host to the container
+# The --chown=user flag ensures the copied file is owned by our 'user'
+COPY --chown=user ./requirements.txt requirements.txt
+# Install the Python dependencies listed in requirements.txt
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+# Copy the rest of the application code from the host to the container
+# Again, ensure the copied files are owned by 'user'
+COPY --chown=user . /app
+# Specify the command to run when the container starts
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py
ADDED
@@ -0,0 +1,5 @@
+from fastapi import FastAPI
+app = FastAPI()
+@app.get("/")
+def greet_json():
+    return {"Hello": "World!"}
app/callbacks.py
ADDED
@@ -0,0 +1,24 @@
+from typing import Dict, Any, List
+from langchain_core.callbacks import BaseCallbackHandler
+import schemas
+import crud
+
+
+class LogResponseCallback(BaseCallbackHandler):
+
+    def __init__(self, user_request: schemas.UserRequest, db):
+        super().__init__()
+        self.user_request = user_request
+        self.db = db
+
+    def on_llm_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
+        """Run when the LLM ends running."""
+        # TODO: on_llm_end is called when the LLM stops sending the response.
+        # Use the crud.add_message function to capture that response.
+        raise NotImplementedError
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> Any:
+        for prompt in prompts:
+            print(prompt)
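The on_llm_end TODO leaves open how the response actually gets logged. Below is a minimal sketch of one possible body for it, assuming LangChain hands the callback an LLMResult and that schemas.MessageBase ends up with message, type and timestamp fields (that schema is still a TODO in app/schemas.py); crud.add_message is used the way its signature in app/crud.py suggests.

```python
# Sketch only -- not the assignment's reference solution.
from datetime import datetime
from typing import Any

from langchain_core.outputs import LLMResult

import crud
import schemas


def on_llm_end(self, outputs: LLMResult, **kwargs: Any) -> Any:
    """Run when the LLM ends running: persist the generated answer for this user."""
    # The generated text lives in the first generation of the LLMResult.
    generated_text = outputs.generations[0][0].text
    ai_message = schemas.MessageBase(  # assumed fields, see the app/schemas.py TODO
        message=generated_text,
        type="AI",
        timestamp=datetime.now(),
    )
    crud.add_message(self.db, ai_message, self.user_request.username)
```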
app/chains.py
ADDED
@@ -0,0 +1,53 @@
+import os
+from langchain_huggingface import HuggingFaceEndpoint
+from langchain_core.runnables import RunnablePassthrough
+import schemas
+from prompts import (
+    raw_prompt,
+    raw_prompt_formatted,
+    format_context,
+    # tokenizer
+)
+from data_indexing import DataIndexer
+
+data_indexer = DataIndexer()
+
+llm = HuggingFaceEndpoint(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    huggingfacehub_api_token=os.environ['HF_TOKEN'],
+    max_new_tokens=512,
+    # stop_sequences=[tokenizer.eos_token],
+    streaming=True,
+)
+
+simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
+
+# TODO: create formatted_chain by piping raw_prompt_formatted and the LLM endpoint.
+formatted_chain = (raw_prompt_formatted | llm).with_types(input_type=schemas.UserQuestion)
+
+# # TODO: use history_prompt_formatted and HistoryInput to create the history_chain
+# history_chain = None
+
+# # TODO: Let's construct the standalone_chain by piping standalone_prompt_formatted with the LLM
+# standalone_chain = None
+
+# input_1 = RunnablePassthrough.assign(new_question=standalone_chain)
+# input_2 = {
+#     'context': lambda x: format_context(data_indexer.search(x['new_question'])),
+#     'standalone_question': lambda x: x['new_question']
+# }
+# input_to_rag_chain = input_1 | input_2
+
+# # TODO: use input_to_rag_chain, rag_prompt_formatted,
+# # HistoryInput and the LLM to build the rag_chain.
+# rag_chain = None
+
+# # TODO: Implement the filtered_rag_chain. It should be the
+# # same as the rag_chain but with hybrid_search = True.
+# filtered_rag_chain = None
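The commented-out chains could be wired roughly as sketched below. This assumes the pieces that are still TODOs elsewhere in this commit: history_prompt_formatted, standalone_prompt_formatted and rag_prompt_formatted in app/prompts.py, and schemas.HistoryInput in app/schemas.py.

```python
# Sketch of the remaining chains, assuming the prompts and HistoryInput schema exist.
from langchain_core.runnables import RunnablePassthrough

history_chain = (history_prompt_formatted | llm).with_types(
    input_type=schemas.HistoryInput
)

standalone_chain = (standalone_prompt_formatted | llm).with_types(
    input_type=schemas.HistoryInput
)

input_1 = RunnablePassthrough.assign(new_question=standalone_chain)
input_2 = {
    'context': lambda x: format_context(data_indexer.search(x['new_question'])),
    'standalone_question': lambda x: x['new_question'],
}
input_to_rag_chain = input_1 | input_2

rag_chain = (input_to_rag_chain | rag_prompt_formatted | llm).with_types(
    input_type=schemas.HistoryInput
)

# Same pipeline, but asking DataIndexer.search to apply its source filter.
input_2_filtered = {
    'context': lambda x: format_context(
        data_indexer.search(x['new_question'], hybrid_search=True)
    ),
    'standalone_question': lambda x: x['new_question'],
}
filtered_rag_chain = (input_1 | input_2_filtered | rag_prompt_formatted | llm).with_types(
    input_type=schemas.HistoryInput
)
```

Passing hybrid_search=True is what makes the filtered variant use the Chroma-based source filter inside DataIndexer.search.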
app/crud.py
ADDED
@@ -0,0 +1,23 @@
+from sqlalchemy.orm import Session
+import models, schemas
+
+
+def get_or_create_user(db: Session, username: str):
+    user = db.query(models.User).filter(models.User.username == username).first()
+    if not user:
+        user = models.User(username=username)
+        db.add(user)
+        db.commit()
+        db.refresh(user)
+    return user
+
+def add_message(db: Session, message: schemas.MessageBase, username: str):
+    # TODO: Implement the add_message function. It should:
+    # - get or create the user with the username
+    # - create a models.Message instance
+    # - pass the retrieved user to the message instance
+    # - save the message instance to the database
+    raise NotImplementedError
+
+def get_user_chat_history(db: Session, username: str):
+    raise NotImplementedError
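The two remaining functions could look roughly like the sketch below, which leans on get_or_create_user above and on the models.Message columns defined in app/models.py; the MessageBase field names are assumptions until app/schemas.py is filled in.

```python
# Sketch of add_message / get_user_chat_history -- one possible implementation.
from sqlalchemy.orm import Session

import models
import schemas


def add_message(db: Session, message: schemas.MessageBase, username: str):
    user = get_or_create_user(db, username)
    db_message = models.Message(
        message=message.message,   # assumed MessageBase fields
        type=message.type,
        timestamp=message.timestamp,
        user=user,
    )
    db.add(db_message)
    db.commit()
    db.refresh(db_message)
    return db_message


def get_user_chat_history(db: Session, username: str):
    user = db.query(models.User).filter(models.User.username == username).first()
    if not user:
        return []
    return (
        db.query(models.Message)
        .filter(models.Message.user_id == user.id)
        .order_by(models.Message.timestamp)
        .all()
    )
```

Ordering by timestamp lets format_chat_history in app/prompts.py render the conversation in order.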
app/data_indexing.py
ADDED
@@ -0,0 +1,150 @@
+import os
+import uuid
+from pathlib import Path
+from pinecone.grpc import PineconeGRPC as Pinecone
+from pinecone import ServerlessSpec
+from langchain_community.vectorstores import Chroma
+from langchain_openai import OpenAIEmbeddings
+
+current_dir = Path(__file__).resolve().parent
+
+
+class DataIndexer:
+
+    source_file = os.path.join(current_dir, 'sources.txt')
+
+    def __init__(self, index_name='langchain-repo') -> None:
+
+        # TODO: choose your embedding model
+        # self.embedding_client = InferenceClient(
+        #     "dunzhang/stella_en_1.5B_v5",
+        #     token=os.environ['HF_TOKEN'],
+        # )
+        self.embedding_client = OpenAIEmbeddings()
+        self.index_name = index_name
+        self.pinecone_client = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))
+
+        if index_name not in self.pinecone_client.list_indexes().names():
+            # TODO: create your index if it doesn't exist. Use the create_index function.
+            # Make sure to choose the dimension that corresponds to your embedding model
+            pass
+
+        self.index = self.pinecone_client.Index(self.index_name)
+        # TODO: make sure to build the index.
+        self.source_index = None
+
+    def get_source_index(self):
+        if not os.path.isfile(self.source_file):
+            print('No source file')
+            return None
+
+        print('create source index')
+
+        with open(self.source_file, 'r') as file:
+            sources = file.readlines()
+
+        sources = [s.rstrip('\n') for s in sources]
+        vectorstore = Chroma.from_texts(
+            sources, embedding=self.embedding_client
+        )
+        return vectorstore
+
+    def index_data(self, docs, batch_size=32):
+
+        with open(self.source_file, 'a') as file:
+            for doc in docs:
+                file.writelines(doc.metadata['source'] + '\n')
+
+        for i in range(0, len(docs), batch_size):
+            batch = docs[i: i + batch_size]
+
+            # TODO: create a list of the vector representations of each text data in the batch
+            # TODO: choose your embedding model
+            # values = self.embedding_client.embed_documents([
+            #     doc.page_content for doc in batch
+            # ])
+
+            # values = self.embedding_client.feature_extraction([
+            #     doc.page_content for doc in batch
+            # ])
+            values = None
+
+            # TODO: create a list of unique identifiers for each element in the batch with the uuid package.
+            vector_ids = None
+
+            # TODO: create a list of dictionaries representing the metadata. Capture the text data
+            # with the "text" key, and make sure to capture the rest of the doc.metadata.
+            metadatas = None
+
+            # create a list of dictionaries with keys "id" (the unique identifiers), "values"
+            # (the vector representation), and "metadata" (the metadata).
+            vectors = [{
+                'id': vector_id,
+                'values': value,
+                'metadata': metadata
+            } for vector_id, value, metadata in zip(vector_ids, values, metadatas)]
+
+            try:
+                # TODO: Use the function upsert to upload the data to the database.
+                upsert_response = None
+                print(upsert_response)
+            except Exception as e:
+                print(e)
+
+    def search(self, text_query, top_k=5, hybrid_search=False):
+
+        filter = None
+        if hybrid_search and self.source_index:
+            # I implemented the filtering process to pull the 50 most relevant file names
+            # to the question. Make sure to adjust this number as you see fit.
+            source_docs = self.source_index.similarity_search(text_query, 50)
+            filter = {"source": {"$in": [doc.page_content for doc in source_docs]}}
+
+        # TODO: embed the text_query by using the embedding model
+        # TODO: choose your embedding model
+        # vector = self.embedding_client.feature_extraction(text_query)
+        # vector = self.embedding_client.embed_query(text_query)
+        vector = None
+
+        # TODO: use the vector representation of the text_query to
+        # search the database by using the query function.
+        result = None
+
+        docs = []
+        for res in result["matches"]:
+            # TODO: From the result's metadata, extract the "text" element.
+            pass
+
+        return docs
+
+
+if __name__ == '__main__':
+
+    from langchain_community.document_loaders import GitLoader
+    from langchain_text_splitters import (
+        Language,
+        RecursiveCharacterTextSplitter,
+    )
+
+    loader = GitLoader(
+        clone_url="https://github.com/langchain-ai/langchain",
+        repo_path="./code_data/langchain_repo/",
+        branch="master",
+    )
+
+    python_splitter = RecursiveCharacterTextSplitter.from_language(
+        language=Language.PYTHON, chunk_size=10000, chunk_overlap=100
+    )
+
+    docs = loader.load()
+    docs = [doc for doc in docs if doc.metadata['file_type'] in ['.py', '.md']]
+    docs = [doc for doc in docs if len(doc.page_content) < 50000]
+    docs = python_splitter.split_documents(docs)
+    for doc in docs:
+        doc.page_content = '# {}\n\n'.format(doc.metadata['source']) + doc.page_content
+
+    indexer = DataIndexer()
+    with open('/app/sources.txt', 'a') as file:
+        for doc in docs:
+            file.writelines(doc.metadata['source'] + '\n')
+    indexer.index_data(docs)
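The embedding, upsert and query TODOs could be completed along these lines, written here as free functions over a DataIndexer instance rather than as edits to the class. The sketch assumes the OpenAIEmbeddings client chosen in __init__ and a Pinecone index created with a matching dimension (1536 for the default OpenAI embedding model).

```python
# Sketch of the indexing and search steps -- not the assignment's reference solution.
import uuid


def index_batch(indexer, batch):
    """Embed a batch of LangChain documents and upsert them into Pinecone."""
    values = indexer.embedding_client.embed_documents(
        [doc.page_content for doc in batch]
    )
    vector_ids = [str(uuid.uuid4()) for _ in batch]
    metadatas = [{"text": doc.page_content, **doc.metadata} for doc in batch]
    vectors = [
        {"id": vector_id, "values": value, "metadata": metadata}
        for vector_id, value, metadata in zip(vector_ids, values, metadatas)
    ]
    return indexer.index.upsert(vectors=vectors)


def search(indexer, text_query, top_k=5, filter=None):
    """Embed the query and return the stored text of the top_k matches."""
    vector = indexer.embedding_client.embed_query(text_query)
    result = indexer.index.query(
        vector=vector,
        filter=filter,
        top_k=top_k,
        include_metadata=True,
    )
    return [match["metadata"]["text"] for match in result["matches"]]
```

Creating the index in __init__ would then be a single pinecone_client.create_index call with a dimension matching the embedding model (plus a ServerlessSpec for a serverless index), and self.source_index would be set with self.get_source_index() so that hybrid search has something to filter against.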
app/database.py
ADDED
@@ -0,0 +1,12 @@
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
+
+engine = create_engine(
+    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
+)
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+Base = declarative_base()
app/main.py
ADDED
@@ -0,0 +1,87 @@
+from langchain_core.runnables import Runnable
+from langchain_core.callbacks import BaseCallbackHandler
+from fastapi import FastAPI, Request, Depends
+from sse_starlette.sse import EventSourceResponse
+from langserve.serialization import WellKnownLCSerializer
+from typing import List
+from sqlalchemy.orm import Session
+
+import schemas
+from chains import simple_chain
+import crud, models, schemas
+from database import SessionLocal, engine
+from callbacks import LogResponseCallback
+
+
+models.Base.metadata.create_all(bind=engine)
+
+app = FastAPI()
+
+def get_db():
+    db = SessionLocal()
+    try:
+        yield db
+    finally:
+        db.close()
+
+
+async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
+    for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
+        data = WellKnownLCSerializer().dumps(output).decode("utf-8")
+        yield {'data': data, "event": "data"}
+    yield {"event": "end"}
+
+
+@app.post("/simple/stream")
+async def simple_stream(request: Request):
+    data = await request.json()
+    user_question = schemas.UserQuestion(**data['input'])
+    return EventSourceResponse(generate_stream(user_question, simple_chain))
+
+
+@app.post("/formatted/stream")
+async def formatted_stream(request: Request):
+    # TODO: use the formatted_chain to implement the "/formatted/stream" endpoint.
+    raise NotImplementedError
+
+
+@app.post("/history/stream")
+async def history_stream(request: Request, db: Session = Depends(get_db)):
+    # TODO: Let's implement the "/history/stream" endpoint. The endpoint should follow these steps:
+    # - The endpoint receives the request
+    # - The request is parsed into a user request
+    # - The user request is used to pull the chat history of the user
+    # - We add the current question to the user history by using add_message.
+    # - We create an instance of HistoryInput by using format_chat_history.
+    # - We use the history input within the history chain.
+    raise NotImplementedError
+
+
+@app.post("/rag/stream")
+async def rag_stream(request: Request, db: Session = Depends(get_db)):
+    # TODO: Let's implement the "/rag/stream" endpoint. The endpoint should follow these steps:
+    # - The endpoint receives the request
+    # - The request is parsed into a user request
+    # - The user request is used to pull the chat history of the user
+    # - We add the current question to the user history by using add_message.
+    # - We create an instance of HistoryInput by using format_chat_history.
+    # - We use the history input within the rag chain.
+    raise NotImplementedError
+
+
+@app.post("/filtered_rag/stream")
+async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
+    # TODO: Let's implement the "/filtered_rag/stream" endpoint. The endpoint should follow these steps:
+    # - The endpoint receives the request
+    # - The request is parsed into a user request
+    # - The user request is used to pull the chat history of the user
+    # - We add the current question to the user history by using add_message.
+    # - We create an instance of HistoryInput by using format_chat_history.
+    # - We use the history input within the filtered rag chain.
+    raise NotImplementedError
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("main:app", host="localhost", reload=True, port=8000)
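A sketch of how the "/formatted/stream" and "/history/stream" endpoints could be implemented follows, assuming main.py also imports formatted_chain and history_chain from chains and format_chat_history from prompts, and that the HistoryInput / UserRequest / MessageBase schemas get filled in as their TODOs describe.

```python
# Sketch of two of the TODO endpoints -- the schema shapes are assumptions.
from datetime import datetime

from chains import formatted_chain, history_chain  # assumed to exist by then
from prompts import format_chat_history


@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    data = await request.json()
    user_question = schemas.UserQuestion(**data['input'])
    return EventSourceResponse(generate_stream(user_question, formatted_chain))


@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
    data = await request.json()
    user_request = schemas.UserRequest(**data['input'])
    # Pull the existing history, then log the new question as a "Human" message.
    chat_history = crud.get_user_chat_history(db, user_request.username)
    crud.add_message(
        db,
        schemas.MessageBase(
            message=user_request.question,
            type="Human",
            timestamp=datetime.now(),
        ),
        user_request.username,
    )
    history_input = schemas.HistoryInput(
        chat_history=format_chat_history(chat_history),
        question=user_request.question,
    )
    # The callback logs the AI answer once streaming finishes.
    return EventSourceResponse(generate_stream(
        history_input,
        history_chain,
        [LogResponseCallback(user_request, db)],
    ))
```

The /rag/stream and /filtered_rag/stream endpoints would follow the same pattern with rag_chain and filtered_rag_chain.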
app/models.py
ADDED
@@ -0,0 +1,28 @@
+from sqlalchemy import Column, ForeignKey, Integer, String, DateTime
+from sqlalchemy.orm import relationship
+
+from database import Base
+
+class User(Base):
+    __tablename__ = "users"
+
+    id = Column(Integer, primary_key=True, index=True)
+    username = Column(String, unique=True, index=True)
+    messages = relationship("Message", back_populates="user")
+
+# TODO: Implement the Message SQLAlchemy model. Message should have a primary key,
+# a message attribute to store the content of messages, a type ("AI" or "Human")
+# depending on whether it is a user question or an AI response, a timestamp to
+# order messages by time, and a user attribute to get the user instance associated
+# with the message. We also need a user_id that uses the User.id
+# attribute as a foreign key.
+class Message(Base):
+    __tablename__ = "messages"
+
+    id = Column(Integer, primary_key=True, index=True)
+    message = Column(String, index=True)
+    type = Column(String)  # "AI" or "Human"
+    timestamp = Column(DateTime, index=True)
+    user_id = Column(Integer, ForeignKey("users.id"))
+
+    user = relationship("User", back_populates="messages")
app/prompts.py
ADDED
@@ -0,0 +1,51 @@
+from langchain_core.prompts import PromptTemplate
+from typing import List
+import models
+
+
+def format_prompt(prompt) -> PromptTemplate:
+    # TODO: format the input prompt by using the model-specific instruction template
+    # TODO: return a langchain PromptTemplate
+    return PromptTemplate.from_template(prompt)
+
+def format_chat_history(messages: List[models.Message]):
+    # TODO: implement format_chat_history to format
+    # the list of Message into a text of chat history.
+    raise NotImplementedError
+
+def format_context(docs: List[str]):
+    # TODO: the output of DataIndexer.search is a list of strings, so we need
+    # to concatenate that list into a single text that can fit into the
+    # rag_prompt_formatted. Implement format_context so that it takes a
+    # list of strings and returns the context as one string.
+    raise NotImplementedError
+
+raw_prompt = "{question}"
+
+# TODO: Create the history_prompt prompt that will capture the question and the conversation history.
+# The history_prompt needs a {chat_history} placeholder and a {question} placeholder.
+history_prompt: str = None
+
+# TODO: Create the standalone_prompt prompt that will capture the question and the chat history
+# to generate a standalone question. It needs a {chat_history} placeholder and a {question} placeholder.
+standalone_prompt: str = None
+
+# TODO: Create the rag_prompt that will capture the context and the standalone question to generate
+# a final answer to the question.
+rag_prompt: str = None
+
+# TODO: create raw_prompt_formatted by using format_prompt
+raw_prompt_formatted = None
+raw_prompt = PromptTemplate.from_template(raw_prompt)
+
+# TODO: use format_prompt to create history_prompt_formatted
+history_prompt_formatted: PromptTemplate = None
+# TODO: use format_prompt to create standalone_prompt_formatted
+standalone_prompt_formatted: PromptTemplate = None
+# TODO: use format_prompt to create rag_prompt_formatted
+rag_prompt_formatted: PromptTemplate = None
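One way the prompt TODOs could be filled in is sketched below. The Llama 3.1 instruct template is an assumption based on the model chosen in app/chains.py and should be checked against the model's tokenizer; the prompt wordings are illustrative only.

```python
# Sketch of possible completions for app/prompts.py -- wording and template are assumptions.
from typing import List

from langchain_core.prompts import PromptTemplate

import models

# Assumed chat template for meta-llama/Llama-3.1-8B-Instruct; verify against the tokenizer.
LLAMA_INSTRUCT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "{instruction}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)


def format_prompt(prompt: str) -> PromptTemplate:
    """Wrap a raw prompt in the model-specific instruction template."""
    return PromptTemplate.from_template(
        LLAMA_INSTRUCT_TEMPLATE.format(instruction=prompt)
    )


def format_chat_history(messages: List[models.Message]) -> str:
    """Flatten stored messages into a readable transcript."""
    return "\n".join(
        "{}: {}".format(message.type, message.message) for message in messages
    )


def format_context(docs: List[str]) -> str:
    """Concatenate retrieved chunks into a single context string."""
    return "\n\n".join(docs)


history_prompt = (
    "Given the following conversation and a follow up question, answer the question.\n\n"
    "{chat_history}\n\nQuestion: {question}"
)

standalone_prompt = (
    "Given the following conversation and a follow up question, rephrase the "
    "follow up question to be a standalone question.\n\n"
    "Chat history: {chat_history}\n\nFollow up question: {question}\n\n"
    "Standalone question:"
)

rag_prompt = (
    "Answer the question based only on the following context:\n\n"
    "{context}\n\nQuestion: {standalone_question}"
)

# In prompts.py these would then be wired up before raw_prompt is converted:
# raw_prompt_formatted = format_prompt(raw_prompt)
# history_prompt_formatted = format_prompt(history_prompt)
# standalone_prompt_formatted = format_prompt(standalone_prompt)
# rag_prompt_formatted = format_prompt(rag_prompt)
```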
app/schemas.py
ADDED
@@ -0,0 +1,19 @@
+from pydantic.v1 import BaseModel
+
+
+class UserQuestion(BaseModel):
+    question: str
+
+# TODO: create a HistoryInput data model with chat_history and question attributes.
+class HistoryInput(BaseModel):
+    pass
+
+# TODO: let's create a UserRequest data model with question and username attributes.
+# This will be used to parse the input request.
+class UserRequest(BaseModel):
+    username: str
+
+# TODO: implement MessageBase as a schema mapping from the database model to the
+# FastAPI data model. Basically MessageBase should have the same attributes as models.Message.
+class MessageBase(BaseModel):
+    pass
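The three TODO schemas could end up looking like the sketch below, mirroring the models.Message columns from app/models.py; the exact field names are assumptions until the rest of the app settles on them.

```python
# Sketch of the TODO schemas -- field names are assumptions.
from datetime import datetime

from pydantic.v1 import BaseModel


class HistoryInput(BaseModel):
    chat_history: str
    question: str


class UserRequest(BaseModel):
    username: str
    question: str


class MessageBase(BaseModel):
    message: str
    type: str          # "AI" or "Human"
    timestamp: datetime
```

A MessageBase shaped like this is what both crud.add_message and LogResponseCallback would rely on.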
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+fastapi
+uvicorn[standard]
+langchain-huggingface==0.2.0