Spaces:
Sleeping
Sleeping
""" | |
This module uses the self-discover workflow hand-in-hand with data representation on an index from the user | |
responses and prompt programming from DsPy. Self-Discover Workflow has two stages for any given task. | |
1. Stage 1: | |
a. Select: selects subset of reasoning modules. | |
b. Adapt: adapts selected reasoning modules to the task. | |
c. Implement: gives reasoning structure for the task. | |
2. Stage 2: | |
Uses the generated reasoning structure for the task to generate an answer. | |
""" | |
import os | |
import asyncio | |
from llama_index.core.llms import LLM | |
from llama_index.core.workflow import Workflow, Context, StartEvent, StopEvent, step | |
from llama_index.core.settings import Settings | |
from src.models.discovery_events import GetModulesEvent, RefineModulesEvent, ReasoningStructureEvent | |
from src.workflows.reasoning_modules import _REASONING_MODULES, REASONING_PROMPT_TEMPLATE | |
from src.workflows.reasoning_modules import SELECT_PROMPT_TEMPLATE, ADAPT_PROMPT_TEMPLATE, IMPLEMENT_PROMPT_TEMPLATE | |
from src.workflows.reasoning_modules import JUDGE_REQUIREMENT_PROMPT_TEMPLATE | |
class SelfDiscoverWorkflow(Workflow):
    """
    Self-Discover workflow.

    Four steps run in sequence, each triggered by the event the previous one emits:
    StartEvent -> GetModulesEvent -> RefineModulesEvent -> ReasoningStructureEvent -> StopEvent.
    The LLM supplied in the StartEvent is stored in the context so that later steps can reuse it.
    """

    @step
    async def get_modules(self, context: Context, event: StartEvent) -> GetModulesEvent:
        """
        Select modules required for the task from the defined reasoning modules (Stage 1, Select).
        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for this step, here Start of the workflow; expected to carry
            "task" and "llm".
        :return: pydantic GetModulesEvent with "task" and selected "modules".
        """
        task = event.get("task")
        llm: LLM = event.get("llm")
        # Later steps only receive their own event, so the LLM is shared through the context.
        await context.set("llm", llm)
        prompt = SELECT_PROMPT_TEMPLATE.format(task=task, reasoning_modules=_REASONING_MODULES)
        # acomplete() is awaited; the synchronous complete() would block the event loop.
        result = await llm.acomplete(prompt)
        return GetModulesEvent(task=task, modules=str(result))

    @step
    async def refine_modules(self, context: Context, event: GetModulesEvent) -> RefineModulesEvent:
        """
        Refine and adapt the selected reasoning modules to the task (Stage 1, Adapt).
        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of GetModulesEvent.
        :return: pydantic RefineModulesEvent with "task" and "refined_modules".
        """
        task = event.task
        modules = event.modules
        llm: LLM = await context.get("llm")
        prompt = ADAPT_PROMPT_TEMPLATE.format(task=task, selected_modules=modules)
        result = await llm.acomplete(prompt)
        return RefineModulesEvent(task=task, refined_modules=str(result))

    @step
    async def create_reasoning_structure(self, context: Context, event: RefineModulesEvent) -> ReasoningStructureEvent:
        """
        Create a reasoning structure for the task from the adapted reasoning modules (Stage 1, Implement).
        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of RefineModulesEvent.
        :return: pydantic ReasoningStructureEvent with "task" and "reasoning_structure".
        """
        task = event.task
        refined_modules = event.refined_modules
        llm: LLM = await context.get("llm")
        prompt = IMPLEMENT_PROMPT_TEMPLATE.format(task=task, adapted_modules=refined_modules)
        result = await llm.acomplete(prompt)
        return ReasoningStructureEvent(task=task, reasoning_structure=str(result))

    @step
    async def get_final_result(self, context: Context, event: ReasoningStructureEvent) -> StopEvent:
        """
        Solve the task by following the generated reasoning structure (Stage 2).
        :param context: global context maintained for the user until StopEvent is emitted;
            the raw LLM response is also stored under "workflow_result".
        :param event: trigger event for the step, here completion of ReasoningStructureEvent.
        :return: StopEvent signal carrying the answer as a string, last step of the workflow.
        """
        task = event.task
        reasoning_structure = event.reasoning_structure
        llm: LLM = await context.get("llm")
        prompt = REASONING_PROMPT_TEMPLATE.format(task=task, reasoning_structure=reasoning_structure)
        result = await llm.acomplete(prompt)
        await context.set("workflow_result", result)
        return StopEvent(result=str(result))
class JudgeWorkflow(Workflow):
    """Judgement workflow to decide whether further questions to the user are necessary."""

    @step
    async def judge(self, context: Context, event: StartEvent) -> StopEvent:
        """
        Ask the LLM whether further questions are necessary for the given judging context.
        :param context: global context maintained for the user until StopEvent is emitted;
            the LLM is stored in it under "llm".
        :param event: trigger event for this step, here Start of the workflow; expected to carry
            "judging_context" and "llm".
        :return: StopEvent signal whose result is False when the LLM answers "0" and True
            otherwise, last step of the workflow.
        """
        judging_context = event.get("judging_context")
        llm: LLM = event.get("llm")
        await context.set("llm", llm)
        prompt = JUDGE_REQUIREMENT_PROMPT_TEMPLATE.format(judging_context=judging_context)
        # acomplete() is awaited; the synchronous complete() would block the event loop.
        response = await llm.acomplete(prompt)
        # Completions frequently carry surrounding whitespace/newlines; strip them so that e.g.
        # "0\n" is still read as "0". NOTE(review): assumes JUDGE_REQUIREMENT_PROMPT_TEMPLATE
        # asks the model to answer "0" when no further questions are needed — confirm.
        needs_more_questions = str(response).strip() != "0"
        return StopEvent(result=needs_more_questions)
# Runner for the workflow: executes SelfDiscoverWorkflow on an example task and prints the answer.
async def main():
    """Run SelfDiscoverWorkflow on a predefined Titanic-survival example task and print the result."""
    # The workflow makes four sequential LLM calls, which can easily exceed the default
    # llama_index Workflow timeout of 10 seconds; allow considerably more time.
    workflow = SelfDiscoverWorkflow(timeout=300.0)
    # example task
    predefined_task = (
        "The user wants a step-by-step workflow for titanic survival prediction ML problem. "
        "They want to understand whether a person has chances of surviving the titanic accident "
        "depending on their background, ticket, gender and titanic pitfalls. To perform this, they "
        "want to design a machine learning workflow and derive conclusions from their data. The final "
        "model should be able to predict survive/die classes. The data has these features: "
        "survival, ticket class, sex, age, siblings/spouses, parents/children, ticket, fare, cabin, embarked. "
        "In case the problem is not clear at any point and you need more input from the user, share the current "
        "workflow with the user and end with follow-up questions."
    )
    # Settings.llm is the public accessor: unlike the private Settings._llm (None until an LLM is
    # assigned), it resolves the default LLM instead of handing None to the workflow.
    intermediate_result = await workflow.run(task=predefined_task, llm=Settings.llm)
    print(intermediate_result)


if __name__ == "__main__":
    asyncio.run(main())