# HR handbook Q&A assistant — RAG over a Pinecone index of handbook .txt files,
# answered with OpenAI chat completions and served through a Gradio UI.
# Standard library
import glob
import os
from pathlib import Path
from typing import List

# Third-party
import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI
from pinecone import Pinecone, ServerlessSpec
# Load the .env that sits next to this script, so the app works no matter
# what the current working directory is.
dotenv_file = Path(__file__).resolve().parent / ".env"
print("Loading .env from:", dotenv_file)
load_dotenv(dotenv_path=dotenv_file)

# Configuration pulled from the environment.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENV = os.getenv("PINECONE_ENV")
INDEX_NAME = os.getenv("PINECONE_INDEX", "hr-handbook")

# Only a 6-character prefix of each secret is ever printed.
print("Loaded PINECONE_API_KEY:", PINECONE_API_KEY[:6] + "..." if PINECONE_API_KEY else "NOT FOUND")
print("Loaded OPENAI_API_KEY:", OPENAI_API_KEY[:6] + "..." if OPENAI_API_KEY else "NOT FOUND")

# Single OpenAI client shared by embedding and chat-completion calls.
client = OpenAI(api_key=OPENAI_API_KEY)
# Suggested questions shown in each tab's dropdown. Keys (except "all")
# double as UI tab names; "all" is the fallback for tabs with no entry.
recommended = {
    # Generic starters shown when a tab has no dedicated list.
    "all": [
        "Where can I find benefits information?",
        "How do I request time off?",
        "What is Made Tech’s mission",
        "How does Made Tech support learning and mentoring?",
    ],
    # Pensions, insurance, and employee perks.
    "benefits": [
        "What does private medical insurance cover?",
        "How do I join the pension scheme?",
        "What is the maximum amount I can apply for under the Cycle to Work scheme?",
        "How do I request a flexible working pattern?",
        "How do I apply for Help to Buy Tech through TechScheme?",
        "Can I increase or decrease my pension contributions?",
        "When are Winter and Summer company parties held?",
    ],
    # Mission, values, and onboarding basics.
    "company": [
        "What is the company's mission?",
        "What values guide the work and culture at Made Tech?",
        "What is Made Tech's purpose?",
        "What is the role of a peer buddy in the onboarding process?",
        "What policies should I read as a new employee?",
    ],
    # How-to guides and day-to-day process questions.
    "guides": [
        "How do I submit an expense?",
        "Where is the hiring policy?",
        "What is chalet time?",
        "What accounts and tools are introduced during onboarding?",
        "Will I receive a laptop before my first day?",
        "What is Chalet Time? What are the priorities for using Chalet Time?",
        "What should I do if I am planning to relocate?",
        "How can I contribute to the handbook?",
    ],
    # Role descriptions and career progression.
    "roles": [
        "What does a data scientist do?",
        "What is the duration of the Software Engineering Academy at Made Tech?",
        "How do career levels work?",
        "What types of needs does an Associate Product Manager explore in their role?",
        "How are success criteria and measurable outcomes defined?",
        "What are the responsibilities of a Delivery Support Analyst in PMO?",
        "How do Delivery Directors contribute to Made Tech's commercial growth?",
        "What are some key outcomes expected from a Delivery Director?",
    ],
    # Communities of practice and internal clubs.
    "communities-of-practice": [
        "How can I join a community of practice?",
        "What is the purpose of the Book Club at Made Tech?",
        "How often does the Book Club meet?",
        "What is the EDGE approach to digital transformation?",
        "How can I join the Book Club meetings?",
        "When do CoPs meet?",
    ],
}
# Initialize Pinecone
def init_pinecone(index_name: str):
    """Return a handle to the named Pinecone index, creating it if missing.

    New indexes are serverless on AWS us-east-1, with 1536 dimensions
    (the output size of text-embedding-ada-002) and cosine similarity.
    NOTE(review): region is hard-coded; PINECONE_ENV is loaded but unused —
    confirm whether it should drive the region here.
    """
    pinecone_client = Pinecone(api_key=PINECONE_API_KEY)
    existing_indexes = pinecone_client.list_indexes().names()
    if index_name not in existing_indexes:
        pinecone_client.create_index(
            name=index_name,
            dimension=1536,  # must match the embedding model's output size
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )
    return pinecone_client.Index(index_name)
# Load text files
def load_documents(root_dir: str) -> List[dict]:
    """Recursively collect every .txt file under *root_dir*.

    Returns a list of dicts with keys:
      - "id": the file path as a string (used as the Pinecone vector id),
      - "text": the file contents,
      - "category": the first directory component of the path *relative to
        root_dir* (e.g. "benefits/pension.txt" -> "benefits"), or "general"
        for files sitting directly in root_dir.

    Bug fix: the previous code used ``path.parts[1]``, which for
    ``root_dir="."`` is the *filename* (and for absolute roots an arbitrary
    ancestor directory), disagreeing with the ``parts[0]`` category
    derivation used to build the UI tabs. Deriving the category from the
    path relative to the root makes both consistent.
    """
    root = Path(root_dir)
    docs: List[dict] = []
    for path in root.rglob("*.txt"):
        rel_parts = path.relative_to(root).parts
        category = rel_parts[0] if len(rel_parts) > 1 else "general"
        # errors="ignore" keeps indexing alive on the odd badly-encoded file.
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
        docs.append({"id": str(path), "text": content, "category": category})
    return docs
# Embed and upsert
def index_documents(index, docs: List[dict]):
    """Embed *docs* in batches of 100 and upsert them into *index*.

    Each vector is (id, embedding, metadata) where metadata carries the
    document's category AND its full text.

    Bug fix: the text is now stored in metadata. ``retrieve`` builds the
    LLM context from ``metadata["text"]``, so previously every query came
    back with an empty context and the assistant could never answer.

    NOTE(review): Pinecone limits metadata size per vector (~40 KB) — very
    large handbook files may need chunking; confirm typical file sizes.
    """
    batch_size = 100
    for start in range(0, len(docs), batch_size):
        batch = docs[start:start + batch_size]
        texts = [doc["text"] for doc in batch]
        embeddings = client.embeddings.create(input=texts, model="text-embedding-ada-002")
        vectors = [
            (doc["id"], emb.embedding, {"category": doc["category"], "text": doc["text"]})
            for doc, emb in zip(batch, embeddings.data)
        ]
        index.upsert(vectors)
# Query Pinecone
def retrieve(query: str, index, category: str = None, k: int = 5) -> List[str]:
    """Embed *query* and return the text of the top-*k* matching documents.

    When *category* is given, results are filtered to vectors whose
    metadata ``category`` equals it. Matches lacking a non-empty ``text``
    metadata field are skipped.
    """
    embedding_response = client.embeddings.create(
        input=[query], model="text-embedding-ada-002"
    )
    query_vector = embedding_response.data[0].embedding

    query_args = {"top_k": k, "include_metadata": True}
    if category:
        query_args["filter"] = {"category": {"$eq": category}}

    result = index.query(vector=query_vector, **query_args)

    texts: List[str] = []
    for match in result["matches"]:
        if "metadata" not in match:
            continue
        text = match["metadata"].get("text", "")
        if text:
            texts.append(text)
    return texts
# Generate answer
def generate_answer(query: str, docs: List[str]) -> str:
    """Ask gpt-3.5-turbo to answer *query* using *docs* as the only context.

    The system prompt instructs the model to admit when the context does
    not contain the answer, keeping responses grounded in the handbook.
    """
    system_prompt = (
        "You are a helpful HR assistant. Use the provided context to answer the question.\n"
        "If the answer is not contained in the context, reply that you don't know."
    )
    joined_context = "\n\n".join(docs)
    user_prompt = f"Context:\n{joined_context}\n\nQuestion: {query}"
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    return response.choices[0].message.content.strip()
# Gradio logic
def answer_question(query: str, category: str):
    """Gradio callback: fetch documents matching *query* within *category*
    from the global ``pinecone_index`` and generate an answer from them."""
    retrieved_docs = retrieve(query, pinecone_index, category)
    return generate_answer(query, retrieved_docs)
# Main logic
if __name__ == "__main__":
    pinecone_index = init_pinecone(INDEX_NAME)

    # Set SKIP_INDEXING=1 to reuse an already-populated index and skip the
    # (slow, paid) embedding pass.
    if not int(os.getenv("SKIP_INDEXING", "0")):
        documents = load_documents(".")
        index_documents(pinecone_index, documents)

    # One tab per top-level directory containing .txt files, plus every
    # recommended-question section except the generic "all" fallback.
    tab_names = set(Path(p).parts[0] for p in glob.glob("*/*.txt"))
    tab_names |= set(recommended.keys()) - {"all"}
    categories = sorted(tab_names)

    with gr.Blocks() as demo:
        # Inline CSS: centered banner image and gradient-styled tab buttons.
        gr.HTML("""
        <style>
        #banner-img {
            display: flex;
            justify-content: center;
            margin-bottom: 20px;
        }
        #banner-img img {
            max-width: 800px;
            width: 100%;
            height: auto;
            border-radius: 10px;
        }
        .gradio-container .gr-tabnav button {
            background: linear-gradient(to right, #36d1dc, #5b86e5) !important;
            color: white !important;
            border: none !important;
            border-radius: 8px !important;
            padding: 10px 16px;
            margin: 0 4px;
            font-weight: bold;
            transition: 0.3s;
        }
        .gradio-container .gr-tabnav button:hover {
            background: linear-gradient(to right, #5b86e5, #36d1dc) !important;
            transform: scale(1.05);
        }
        .gradio-container .gr-tabnav button[aria-selected="true"] {
            background: #1e3c72 !important;
            font-weight: bold;
        }
        </style>
        """)

        with gr.Row():
            # Show the banner only if the image file ships with the app.
            banner_path = Path(__file__).resolve().parent / "bannerhr.png"
            banner_value = str(banner_path) if banner_path.exists() else None
            gr.Image(
                value=banner_value,
                show_label=False,
                show_download_button=False,
                elem_id="banner-img",
            )

        with gr.Tabs():
            for cat in categories:
                with gr.Tab(cat.capitalize()):
                    with gr.Row():
                        with gr.Column():
                            example_choices = recommended.get(cat, recommended["all"])
                            example_value = example_choices[0] if example_choices else ""
                            examples = gr.Dropdown(
                                choices=example_choices,
                                label="Recommended questions",
                                value=example_value,
                            )
                        with gr.Column():
                            query = gr.Textbox(
                                label="Ask a question",
                                value=example_value,
                            )
                            submit = gr.Button("Submit")
                            answer = gr.Textbox(label="Answer")
                    # Selecting a suggestion copies it into the question box.
                    examples.change(lambda q: q, inputs=examples, outputs=query)
                    # Bind this tab's category via a default argument to dodge
                    # Python's late-binding closure pitfall in the loop.
                    submit.click(
                        lambda q, cat=cat: answer_question(q, cat),
                        inputs=[query],
                        outputs=answer,
                    )

    demo.launch()