"""Phi-based workflows for ML requirements analysis, technical research, and implementation planning."""
import json
import os
from typing import Iterator, List, Optional

from dotenv import load_dotenv
from phi.agent import Agent
from phi.model.openai import OpenAIChat
from phi.storage.agent.sqlite import SqlAgentStorage
from phi.storage.workflow.sqlite import SqlWorkflowStorage
from phi.tools.duckduckgo import DuckDuckGo
from phi.utils.log import logger
from phi.workflow import RunEvent, RunResponse, Workflow
# from phi.memory.db.sqlite import SqliteMemoryDb

from src.models.analysis_models import (
    ComponentSpec,
    ComponentType,
    ConfigParam,
    FunctionSpec,
    ImplementationPlan,
    MLTaskType,
    ModelResponseStatus,
    ParameterSpec,
    RequirementsAnalysis,
    TechnicalResearch,
)
# Load .env configuration up front so the OpenAI key is available when the
# module-level Agent definitions below are evaluated at import time.
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")
class MLAnalysisWorkflow(Workflow):
    """Workflow for analyzing ML business requirements and creating technical specifications.

    Stages:
        1. ``requirements_analyst`` turns the business problem into a
           structured ``RequirementsAnalysis``.
        2. If that analysis has gaps (placeholder fields or explicit search
           queries), ``technical_researcher`` researches them via web search
           and returns a ``TechnicalResearch``.
        3. ``requirements_analyst`` produces a final analysis incorporating
           the research.

    Each stage's structured output is rendered for the caller by the
    markdown-formatting ``writer`` agent.
    """

    # Agent producing the structured requirements analysis.
    requirements_analyst: Agent = Agent(
        name="ML Requirements Analyst",
        model=OpenAIChat(id="gpt-4o", api_key=api_key),
        description="Expert ML Solutions Architect specialized in analyzing business requirements",
        instructions=[
            "Analyze business problems and translate them into technical ML specifications.",
            "1. Understand the core business problem and objectives",
            "2. Identify the type of ML task required",
            "3. Determine data requirements and constraints",
            "4. List unclear points that need clarification",
            "5. Specify areas that need technical research",
            "Be precise in identifying what information is missing or needs validation."
        ],
        response_model=RequirementsAnalysis,
        structured_outputs=True,
        reasoning=True,
        storage=SqlAgentStorage(
            table_name="requirements_sessions",
            db_file="storage/agent_storage.db"
        ),
        debug_mode=True,
        # memory=AgentMemory(memory_db=requirements_db)
    )

    # Agent that validates/augments the analysis with web research.
    technical_researcher: Agent = Agent(
        name="ML Technical Researcher",
        model=OpenAIChat(id="gpt-4o", api_key=api_key),
        description="ML Expert specialized in researching technical implementations",
        tools=[DuckDuckGo(search=True, news=False)],
        instructions=[
            "Research and validate technical aspects of ML solutions.",
            "1. Search for similar ML implementations and best practices",
            "2. Find recommended models and architectures",
            "3. Research typical hyperparameters and evaluation metrics",
            "4. Look for implementation constraints and requirements",
            "5. Validate technical feasibility",
            "Provide sources for all technical information.",
            "Focus on recent and reliable technical sources."
        ],
        response_model=TechnicalResearch,
        structured_outputs=True,
        prevent_hallucination=True,
        reasoning=True,
        storage=SqlAgentStorage(
            table_name="researcher_sessions",
            db_file="storage/agent_storage.db"
        ),
        debug_mode=True,
        # memory=AgentMemory(memory_db=researcher_db)
    )

    # Agent that renders structured JSON payloads as readable markdown.
    writer: Agent = Agent(
        model=OpenAIChat(id="gpt-4o", api_key=api_key),
        instructions=[
            # BUG FIX: adjacent string literals are concatenated by Python,
            # so each fragment needs a trailing space. The original produced
            # "...display thisin a nicely..." and "...linksas they are...".
            "You will be provided with lots of structured outputs. Your work is to display this "
            "in a nicely formatted manner without changing any of the content. Present all the links "
            "as they are, with explicitly mentioned hyperlinks. Do not change any content."
        ],
        markdown=True,
    )

    def validate_model_response(self, response: ModelResponseStatus) -> List[str]:
        """Return names of fields in ``response`` that are missing or incomplete.

        A field counts as incomplete when its value is the literal placeholder
        ``"..."`` (or ``["..."]``) or an empty list.
        """
        logger.info("Checking for missing or incomplete fields in ModelResponseStatus...")
        missing_fields: List[str] = []
        for field, value in response.model_dump().items():
            if value == "..." or value == ["..."]:
                missing_fields.append(field)
            elif isinstance(value, list) and not value:
                missing_fields.append(field)
        return missing_fields

    def analyze_requirements(self, user_query: str) -> Optional[RequirementsAnalysis]:
        """Run the requirements analyst on the raw business problem."""
        logger.info("Analyzing requirements...")
        prompt = f"Analyze this business problem and provide initial technical specifications: {user_query}"
        analyse_stream = self.requirements_analyst.run(prompt)
        return analyse_stream.content

    def conduct_research(self, research_prompt: str) -> Optional[TechnicalResearch]:
        """Run the technical researcher on the assembled research prompt."""
        logger.info("Conducting technical research...")
        conduct_stream = self.technical_researcher.run(research_prompt)
        return conduct_stream.content

    def finalize_analysis(self, final_prompt: str) -> Optional[RequirementsAnalysis]:
        """Run the requirements analyst again to fold research back into the analysis."""
        logger.info("Finalizing analysis...")
        finalise_stream = self.requirements_analyst.run(final_prompt)
        return finalise_stream.content

    def write_requirements_post(self, requirements_results: RequirementsAnalysis) -> Iterator[RunResponse]:
        """Stream a formatted rendering of a requirements analysis.

        :param requirements_results: requirements_analyst response
        :return: iterator of RunResponse chunks from the writer agent
        """
        logger.info("Writing requirements analysis...")
        writer_input = {
            "model_response": requirements_results.model_response.model_dump(),
            "unclear_points": requirements_results.unclear_points,
            "search_queries": requirements_results.search_queries,
            "business_understanding": requirements_results.business_understanding,
        }
        yield from self.writer.run(json.dumps(writer_input, indent=4), stream=True)

    def write_research_post(self, research_results: TechnicalResearch) -> Iterator[RunResponse]:
        """Stream a formatted rendering of the technical research findings.

        :param research_results: research content
        :return: iterator of RunResponse chunks from the writer agent
        """
        logger.info("Writing research findings...")
        writer_input = {
            "research_findings": research_results.research_findings,
            "reference_implementations": research_results.reference_implementations,
            "sources": research_results.sources,
        }
        yield from self.writer.run(json.dumps(writer_input, indent=4), stream=True)

    def run(self, user_query: str) -> Iterator[RunResponse]:
        """
        Run the ML analysis workflow.

        Args:
            user_query: Description of the business problem

        Yields:
            RunResponse chunks from the writer agent, or a terminal error
            response when a stage fails.
        """
        try:
            # Initial requirements analysis
            requirements_result: Optional[RequirementsAnalysis] = self.analyze_requirements(user_query)
            if not requirements_result:
                yield RunResponse(
                    event=RunEvent.workflow_completed,
                    content="Error: Requirements analysis failed to produce valid results."
                )
                return
            logger.info("Writing initial requirements analysis...")
            yield from self.write_requirements_post(requirements_result)

            # Check what needs research
            missing_fields = self.validate_model_response(requirements_result.model_response)
            search_queries = requirements_result.search_queries
            unclear_points = requirements_result.unclear_points
            # BUG FIX: the original logged "... found!" unconditionally,
            # even when nothing was found; log the actual counts instead.
            logger.info(
                "Gaps identified: %d missing fields, %d search queries, %d unclear points",
                len(missing_fields), len(search_queries), len(unclear_points)
            )

            if missing_fields or search_queries:
                # Conduct technical research
                logger.info("Researching technical specifications...")
                research_prompt = (
                    f"Research the following for this ML problem: {user_query}\n"
                    f"Missing information needed for: {', '.join(missing_fields)}\n"
                    f"Specific topics to research: {', '.join(search_queries)}\n"
                    f"Points needing clarification: {', '.join(unclear_points)}\n"
                    f"Current understanding: {requirements_result.business_understanding}"
                )
                logger.info("Conducting research...")
                research_result: Optional[TechnicalResearch] = self.conduct_research(research_prompt)
                # BUG FIX: the original dereferenced research_result without a
                # None check, raising AttributeError on a failed research step.
                if not research_result:
                    yield RunResponse(
                        event=RunEvent.workflow_completed,
                        content="Error: Technical research failed to produce valid results."
                    )
                    return
                logger.info("Sharing research findings...")
                yield from self.write_research_post(research_result)

                final_prompt = (
                    f"Original problem: {user_query}\n"
                    f"Research findings: {research_result.research_findings}\n"
                    "Please provide final technical specifications incorporating this research."
                )
                logger.info("Obtaining final requirements")
                final_result: Optional[RequirementsAnalysis] = self.finalize_analysis(final_prompt)
                # Guard: only render the final analysis if one was produced.
                if final_result:
                    logger.info("Writing final requirements...")
                    yield from self.write_requirements_post(final_result)
        except Exception as e:
            logger.error(f"Workflow error: {str(e)}")
            yield RunResponse(
                event=RunEvent.workflow_completed,
                content=f"Error in analysis workflow: {str(e)}"
            )
class MLImplementationPlanner(Workflow):
    """Workflow for creating detailed ML implementation plans.

    Takes a completed ``RequirementsAnalysis`` (and optionally a
    ``TechnicalResearch``), has the ``architect`` agent produce a structured
    ``ImplementationPlan``, then asks it to validate the component
    interfaces. Both results are rendered by the ``writer`` agent.
    """

    # Agent producing and validating the structured implementation plan.
    architect: Agent = Agent(
        name="ML System Architect",
        model=OpenAIChat(id="gpt-4o", api_key=api_key),
        description="Expert ML System Architect specialized in detailed implementation planning",
        instructions=[
            "Create detailed technical implementation plans for ML systems.",
            "1. Break down the system into logical components",
            "2. Define detailed function specifications for each component",
            "3. Specify clear interfaces between components",
            "4. Consider error handling and edge cases",
            "5. Plan testing and deployment strategies",
            "Be extremely specific about function signatures and component interactions.",
            "Focus on maintainability and scalability in the design."
        ],
        response_model=ImplementationPlan,
        structured_outputs=True,
        reasoning=True,
        storage=SqlAgentStorage(
            table_name="architect_sessions",
            db_file="storage/agent_storage.db"
        ),
        debug_mode=True,
        # memory=AgentMemory(memory_db=architect_db)
    )

    # Agent that renders structured JSON payloads as readable markdown.
    writer: Agent = Agent(
        model=OpenAIChat(id="gpt-4o", api_key=api_key),
        instructions=[
            # BUG FIX: adjacent string literals are concatenated by Python,
            # so the first fragment needs a trailing space (the original
            # produced "...display thisin a nicely...").
            "You will be provided with lots of structured outputs. Your work is to display this "
            "in a nicely formatted manner without changing any of the content."
        ],
        markdown=True,
    )

    def create_implementation_plan(self, planning_prompt: str) -> Optional[ImplementationPlan]:
        """Run the architect to produce the initial implementation plan."""
        logger.info("Creating implementation plan...")
        planning_stream = self.architect.run(planning_prompt)
        return planning_stream.content

    def validate_interfaces(self, validation_prompt: str) -> Optional[ImplementationPlan]:
        """Run the architect to validate component interfaces and dependencies."""
        logger.info("Validating interfaces...")
        architect_stream = self.architect.run(validation_prompt)
        return architect_stream.content

    def write_implementation_post(self, implementation_results: ImplementationPlan) -> Iterator[RunResponse]:
        """Stream a formatted rendering of an implementation plan.

        :param implementation_results: implementation plan results
        :return: iterator of RunResponse chunks from the writer agent
        """
        logger.info("Writing implementation plan...")
        writer_input = {
            "components": [comp.model_dump() for comp in implementation_results.components],
            "system_requirements": implementation_results.system_requirements,
            "deployment_notes": implementation_results.deployment_notes,
            "testing_strategy": implementation_results.testing_strategy,
            "implementation_order": implementation_results.implementation_order,
        }
        yield from self.writer.run(json.dumps(writer_input, indent=4), stream=True)

    def run(
        self,
        requirements_analysis: RequirementsAnalysis,
        technical_research: Optional[TechnicalResearch] = None
    ) -> Iterator[RunResponse]:
        """
        Create implementation plan based on requirements analysis and research.

        Args:
            requirements_analysis: Results from requirements analysis
            technical_research: Optional results from technical research

        Yields:
            RunResponse chunks from the writer agent, or a terminal error
            response when a stage fails.
        """
        try:
            logger.info("Starting planning workflow...")
            # Prepare comprehensive prompt for the architect
            planning_prompt = (
                f"Create a detailed implementation plan for this ML system.\n\n"
                f"Business Understanding:\n{requirements_analysis.business_understanding}\n\n"
                f"Technical Specifications:\n"
                f"- Task Type: {requirements_analysis.model_response.task}\n"
                f"- Models: {', '.join(requirements_analysis.model_response.models)}\n"
                f"- Data Requirements: {requirements_analysis.model_response.data_source}\n"
                f"- Technical Requirements: {requirements_analysis.model_response.technical_requirements}\n"
            )
            if technical_research:
                logger.info("Technical Research found! Modifying context...")
                planning_prompt += (
                    f"\nResearch Findings:\n{technical_research.research_findings}\n"
                    f"Reference Implementations:\n"
                    f"{chr(10).join(technical_research.reference_implementations)}"
                )

            logger.info("generating implementation plan...")
            plan_result: Optional[ImplementationPlan] = self.create_implementation_plan(planning_prompt)
            # BUG FIX: the original rendered plan_result BEFORE the None
            # check, raising AttributeError when planning failed.
            if not plan_result:
                yield RunResponse(
                    event=RunEvent.workflow_completed,
                    content="Error: Implementation planning failed to produce valid results."
                )
                return
            logger.info("writing implementation plan...")
            yield from self.write_implementation_post(plan_result)

            validation_prompt = (
                "Validate the interfaces between these components "
                "and ensure all dependencies are properly specified:\n"
                f"{plan_result.components}"
            )
            logger.info("validating results...")
            validate_result: Optional[ImplementationPlan] = self.validate_interfaces(validation_prompt)
            # Guard: only render the validated plan if one was produced.
            if validate_result:
                logger.info("writing validated implementation plan...")
                yield from self.write_implementation_post(validate_result)
        except Exception as e:
            # BUG FIX: the original used "...".format(e) with no placeholder,
            # silently dropping the exception text, and left the error
            # RunResponse commented out (inconsistent with MLAnalysisWorkflow).
            logger.error(f"Error in planning workflow: {str(e)}")
            yield RunResponse(
                event=RunEvent.workflow_completed,
                content=f"Error in planning workflow: {str(e)}"
            )