import os
import getpass
import json
import operator
import requests
import shutil
import tempfile
import traceback
from operator import itemgetter
from typing import Annotated, Dict, List, TypedDict
# LangChain, LangGraph
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langchain_core.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_core.tools import Tool, tool
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.vectorstores import Qdrant
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
import chainlit as cl
# helper imports
from code_analysis import *
from tools import search_pypi, write_to_docx
from prompts import main_prompt, documenter_prompt, code_description_prompt
# Global variables to store processed data
processed_file_path = None
document_file_path = None
vectorstore = None
main_chain = None
qdrant_client = None
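# NOTE: module-level globals are shared across all connected sessions; for a
# multi-user deployment, per-session state via cl.user_session.set()/get()
# would likely be a safer fit (a suggestion, not part of the current flow).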

@cl.on_chat_start
async def on_chat_start():
    await cl.Message(content="Welcome to the Python Code Documentation Assistant! Please upload a Python file to get started.").send()

@cl.on_message
async def on_message(message: cl.Message):
    global processed_file_path, document_file_path, vectorstore, main_chain, qdrant_client
    if message.elements and any(el.type == "file" for el in message.elements):
        file_elements = [el for el in message.elements if el.type == "file"]
        file_element = file_elements[0]
        is_python_file = (
            file_element.mime.startswith("text/x-python") or
            file_element.name.endswith(".py") or
            file_element.mime == "text/plain"  # some systems identify .py as text/plain
        )
        if is_python_file:
            # Send processing message
            msg = cl.Message(content="Processing your Python file...")
            await msg.send()
            print(f'file element \n {file_element} \n')
            # Copy the uploaded file into a temporary directory
            temp_dir = tempfile.mkdtemp()
            file_path = os.path.join(temp_dir, file_element.name)
            with open(file_element.path, "rb") as source_file:
                file_content_bytes = source_file.read()
            with open(file_path, "wb") as destination_file:
                destination_file.write(file_content_bytes)
            processed_file_path = file_path
            try:
                # Read the file and (optionally) extract its imports
                file_content = read_python_file(file_path)
                # imports = extract_imports(file_content, file_path)
                print('Done reading file')
                # Tools for the (currently disabled) imports-description agent
                search_packages_tools = [search_pypi]
                ##################################### DESCRIBE CODE AGENT ####################################
                describe_code_llm = ChatOpenAI(model="gpt-4o-mini")
                describe_code_prompt = ChatPromptTemplate.from_messages([
                    ("system", code_description_prompt),
                    ("human", "{code}")
                ])
                describe_code_chain = (
                    {"code_language": itemgetter("code_language"), "code": itemgetter("code")}
                    | describe_code_prompt | describe_code_llm | StrOutputParser()
                )
                print('done defining describe-code chain')
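                # The chain above is plain LCEL: the leading dict maps the input payload
                # onto the prompt variables via itemgetter, the prompt renders the
                # system + human messages, the LLM generates, and StrOutputParser
                # unwraps the response to a string.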
                # Graph node: describe the uploaded code
                def describe_code(state):
                    last_message = state["messages"][-1]
                    # The human message carries a JSON payload with the language and code
                    content = json.loads(last_message.content)
                    chain_input = {"code_language": content['code_language'],
                                   "code": content['code']}
                    response = describe_code_chain.invoke(chain_input)
                    return {"messages": [AIMessage(content=response)]}
                ######################################## DOCUMENT WRITER AGENT ########################################
                documenter_llm = ChatOpenAI(model="gpt-4o-mini")
                documenter_llm_prompt = ChatPromptTemplate.from_messages([
                    ("system", documenter_prompt),
                    ("human", "{content}")
                ])
                documenter_chain = (
                    {"content": itemgetter("content")}
                    | documenter_llm_prompt
                    | documenter_llm
                    | StrOutputParser()
                )
                # Graph node: turn the code description into document content
                def write_document_content(state):
                    print(state)
                    json_content = state['messages'][-1].content
                    # Slice from the first '{' to the last '}' to strip any prose or
                    # markdown fences the model may have wrapped around the JSON
                    json_content = json_content[json_content.find("{"):json_content.rfind("}")+1].strip()
                    json_content = json.loads(json_content)
                    document_response = documenter_chain.invoke({"content": json_content})
                    return {"messages": [AIMessage(content=document_response)]}
                ########################################## CONSTRUCT GRAPH ##########################################
                class AgentState(TypedDict):
                    messages: Annotated[list, add_messages]
                uncompiled_code_graph = StateGraph(AgentState)
                uncompiled_code_graph.add_node("code_agent", describe_code)
                uncompiled_code_graph.add_node("write_content_agent", write_document_content)
                uncompiled_code_graph.add_node("write_document", write_to_docx)
                uncompiled_code_graph.set_entry_point("code_agent")
                uncompiled_code_graph.add_edge("code_agent", "write_content_agent")
                uncompiled_code_graph.add_edge("write_content_agent", "write_document")
                compiled_code_graph = uncompiled_code_graph.compile()
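                # Resulting topology is a straight pipeline:
                #   code_agent -> write_content_agent -> write_document
                # Execution stops after write_document since it has no outgoing edge;
                # an explicit add_edge("write_document", END) would make that intent clearer.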
                initial_state = {
                    "messages": [{
                        "role": "human",
                        "content": json.dumps({
                            "code_language": "python",
                            "code": file_content
                        })
                    }]
                }
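                # The code payload is JSON-encoded inside the message content so a single
                # messages channel can carry structured fields (language + code) that
                # describe_code later recovers with json.loads.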
                # An earlier variant ran a separate imports-description agent wired to a
                # PyPI search ToolNode, looping agent -> action via conditional edges
                # until no tool calls remained; it is disabled in favor of the single
                # describe-code pipeline above.
                msg.content = "Analyzing your code and generating documentation..."
                await msg.update()
                documenter_result = compiled_code_graph.invoke(initial_state)
                ############################## SAVE DESCRIPTION CHUNKS IN VECTOR STORE ##############################
                qdrant_client = QdrantClient(":memory:")
                embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
                embedding_dim = 1536  # output dimension of text-embedding-3-small
                qdrant_client.create_collection(
                    collection_name="description_rag_data",
                    vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
                )
                vectorstore = Qdrant(qdrant_client, collection_name="description_rag_data", embeddings=embedding_model)
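                # ":memory:" spins up an ephemeral, in-process Qdrant instance: nothing
                # persists across restarts, which suits this one-file-per-session flow;
                # a server-backed QdrantClient(url=...) would be the persistent option.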
                # Parse the description JSON emitted by the describe-code agent
                chunks = documenter_result['messages'][1].content
                chunks = chunks[chunks.find("{"):chunks.rfind("}")+1].strip()
                chunks = json.loads(chunks)
                print(f'################################### raw chunks \n {chunks} \n ######################## \n')
                # Flatten the parsed structure into a list of chunk dicts
                chunks_list = []
                for key in chunks:
                    if isinstance(chunks[key], dict):
                        chunks_list.append(chunks[key])
                    elif isinstance(chunks[key], list):
                        for value in chunks[key]:
                            chunks_list.append(value)
                print(f'################################### chunks_list \n {chunks_list} \n ######################## \n')
                docs = [
                    Document(
                        page_content=f"{chunk.get('type', '')} - {chunk.get('name', '')} - {chunk.get('description', '')}",  # text the embedder sees
                        metadata={**chunk}  # keep the full chunk as metadata (no embeddings here)
                    )
                    for chunk in chunks_list
                ]
                vectorstore.add_documents(docs)
                qdrant_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
                print('done adding docs to DB')
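                # k=3 returns the three nearest chunks per query; tune upward if answers
                # need more surrounding context, at the cost of a longer prompt.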
                # (An earlier variant re-ran a documenter chain over chunks pulled back
                # out of Qdrant via qdrant_client.scroll(); the write_document graph node
                # now produces the .docx directly.)
                print('done invoking documenter chain and will write in docx')
                # write_to_docx (the write_document node) is assumed to save its output
                # under this fixed filename
                document_file_path = 'generated_documentation.docx'
                print('done writing docx file')
                # Set up the main RAG chain for follow-up chat
                main_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
                main_llm_prompt = ChatPromptTemplate.from_messages([
                    ("system", main_prompt),
                    ("human", "{query}")
                ])
                main_chain = (
                    {"context": itemgetter("query") | qdrant_retriever,
                     "code_language": itemgetter("code_language"),
                     "query": itemgetter("query")}
                    | main_llm_prompt
                    | main_llm
                    | StrOutputParser()
                )
                print('done defining main chain')
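                # In the input map, itemgetter("query") | qdrant_retriever pipes the raw
                # query string into the retriever, so "context" arrives as the retrieved
                # Document list while "query" passes through unchanged for the prompt.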
                # Present a download button for the generated document
                elements = [
                    cl.File(
                        name="documentation.docx",
                        path=document_file_path,
                        display="inline"
                    )
                ]
                print('done defining elements')
                msg.content = "✅ Your Python file has been processed! You can download the documentation file below. How can I help you with your code?"
                msg.elements = elements
                await msg.update()
            except Exception as e:
                error_traceback = traceback.format_exc()
                print(error_traceback)
                msg.content = f"❌ Error processing file: {str(e)}"
                await msg.update()
        else:
            await cl.Message(content="Please upload a Python (.py) file.").send()
    # Handle chat messages once a file has been processed
    elif processed_file_path and main_chain:
        user_input = message.content
        # Send thinking message
        msg = cl.Message(content="Thinking...")
        await msg.send()
        try:
            # Answer the query with the retrieval-augmented main chain
            inputs = {
                'code_language': 'Python',
                'query': user_input
            }
            response = main_chain.invoke(inputs)
            msg.content = response
            await msg.update()
        except Exception as e:
            msg.content = f"❌ Error processing your question: {str(e)}"
            await msg.update()
    else:
        await cl.Message(content="Please upload a Python file first before asking questions.").send()

@cl.on_stop
def on_stop():
    global processed_file_path
    # Clean up temporary files
    if processed_file_path and os.path.exists(os.path.dirname(processed_file_path)):
        shutil.rmtree(os.path.dirname(processed_file_path))
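
# To run locally (assuming this file is named app.py and OPENAI_API_KEY is set):
#   chainlit run app.py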