# src/workflows/workflow_discovery.py
"""
This module uses the self-discover workflow together with data representation over an index built from
user responses, and prompt programming from DSPy. The Self-Discover workflow has two stages for any given task:
1. Stage 1:
    a. Select: selects a subset of the reasoning modules.
    b. Adapt: adapts the selected reasoning modules to the task.
    c. Implement: produces a reasoning structure for the task.
2. Stage 2:
    Uses the generated reasoning structure for the task to generate an answer.
"""
import os
import asyncio
from llama_index.core.llms import LLM
from llama_index.core.workflow import Workflow, Context, StartEvent, StopEvent, step
from llama_index.core.settings import Settings
from src.models.discovery_events import GetModulesEvent, RefineModulesEvent, ReasoningStructureEvent
from src.workflows.reasoning_modules import _REASONING_MODULES, REASONING_PROMPT_TEMPLATE
from src.workflows.reasoning_modules import SELECT_PROMPT_TEMPLATE, ADAPT_PROMPT_TEMPLATE, IMPLEMENT_PROMPT_TEMPLATE
from src.workflows.reasoning_modules import JUDGE_REQUIREMENT_PROMPT_TEMPLATE
class SelfDiscoverWorkflow(Workflow):
    """Self-discover workflow.

    Stage 1 runs as three chained steps — SELECT a subset of reasoning
    modules, ADAPT them to the task, IMPLEMENT a reasoning structure —
    and Stage 2 (the final step) uses that structure to produce the answer.
    """

    @step
    async def get_modules(self, context: Context, event: StartEvent) -> GetModulesEvent:
        """
        Select modules required for the task from the defined reasoning modules.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for this step, here Start of the workflow.
        :return: pydantic GetModulesEvent with "task" and selected "modules".
        """
        task = event.get("task")
        llm: LLM = event.get("llm")
        # Stash the LLM in the workflow context so the later steps reuse the same instance.
        await context.set("llm", llm)
        prompt = SELECT_PROMPT_TEMPLATE.format(task=task, reasoning_modules=_REASONING_MODULES)
        # Use the async completion API: the synchronous llm.complete() would
        # block the event loop inside this async step.
        result = await llm.acomplete(prompt)
        return GetModulesEvent(task=task, modules=str(result))

    @step
    async def refine_modules(self, context: Context, event: GetModulesEvent) -> RefineModulesEvent:
        """
        Refines and adapts the subset of given reasoning modules based on the task.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of GetModulesEvent.
        :return: pydantic RefineModulesEvent with "task" and "refined_modules".
        """
        task = event.task
        modules = event.modules
        llm: LLM = await context.get("llm")
        prompt = ADAPT_PROMPT_TEMPLATE.format(task=task, selected_modules=modules)
        # Async completion — do not block the event loop (see get_modules).
        result = await llm.acomplete(prompt)
        return RefineModulesEvent(task=task, refined_modules=str(result))

    @step
    async def create_reasoning_structure(self, context: Context, event: RefineModulesEvent) -> ReasoningStructureEvent:
        """
        Creates a reasoning structure for the task given the adapted reasoning modules.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of RefineModulesEvent.
        :return: pydantic ReasoningStructureEvent with "task" and "reasoning_structure"
        """
        task = event.task
        refined_modules = event.refined_modules
        llm: LLM = await context.get("llm")
        prompt = IMPLEMENT_PROMPT_TEMPLATE.format(task=task, adapted_modules=refined_modules)
        # Async completion — do not block the event loop (see get_modules).
        result = await llm.acomplete(prompt)
        return ReasoningStructureEvent(task=task, reasoning_structure=str(result))

    @step
    async def get_final_result(self, context: Context, event: ReasoningStructureEvent) -> StopEvent:
        """
        Gets the final result from the reasoning structure event.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of ReasoningStructureEvent.
        :return: StopEvent signal, last step of the workflow.
        """
        task = event.task
        reasoning_structure = event.reasoning_structure
        llm: LLM = await context.get("llm")
        prompt = REASONING_PROMPT_TEMPLATE.format(task=task, reasoning_structure=reasoning_structure)
        # Async completion — do not block the event loop (see get_modules).
        result = await llm.acomplete(prompt)
        # Keep the raw completion in context for any downstream consumer of the run.
        await context.set("workflow_result", result)
        return StopEvent(result=str(result))
class JudgeWorkflow(Workflow):
    """Judgement workflow to decide whether further questions are necessary."""

    @step
    async def judge(self, context: Context, event: StartEvent) -> StopEvent:
        """
        Judge whether the provided context requires follow-up questions from the user.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for this step, here Start of the workflow.
        :return: StopEvent whose result is False when the model answers "0"
                 (no follow-up needed), True otherwise.
        """
        judging_context = event.get("judging_context")
        llm: LLM = event.get("llm")
        await context.set("llm", llm)
        prompt = JUDGE_REQUIREMENT_PROMPT_TEMPLATE.format(judging_context=judging_context)
        # Use the async completion API: the synchronous llm.complete() would
        # block the event loop inside this async step.
        raw = str(await llm.acomplete(prompt))
        # Strip surrounding whitespace before comparing — LLMs commonly append
        # a trailing newline, which previously made "0\n" != "0" and flipped
        # the verdict to True by accident.
        return StopEvent(result=raw.strip() != "0")
# runner for the workflow
async def main():
    """Run SelfDiscoverWorkflow once on an example task and print the result."""
    workflow = SelfDiscoverWorkflow()
    # example task
    predefined_task = (
        "The user wants a step-by-step workflow for titanic survival prediction ML problem. "
        "They want to understand whether a person has chances of surviving the titanic accident "
        "depending on their background, ticket, gender and titanic pitfalls. To perform this, they "
        "want to design a machine learning workflow and derive conclusions from their data. The final "
        "model should be able to predict survive/die classes. The data has these features: "
        "survival, ticket class, sex, age, siblings/spouses, parents/children, ticket, fare, cabin, embarked. "
        "In case the problem is not clear at any point and you need more input from the user, share the current "
        "workflow with the user and end with follow-up questions."
    )
    # Use the public Settings.llm property instead of the private Settings._llm
    # attribute: the property lazily resolves the configured/default LLM, while
    # the underscore field may be unset.
    intermediate_result = await workflow.run(task=predefined_task, llm=Settings.llm)
    print(str(intermediate_result))


if __name__ == "__main__":
    asyncio.run(main())