"""
This module uses the self-discover workflow hand-in-hand with data representation on an index from the user
responses and prompt programming from DsPy. Self-Discover Workflow has two stages for any given task.
1. Stage 1:
a. Select: selects subset of reasoning modules.
b. Adapt: adapts selected reasoning modules to the task.
c. Implement: gives reasoning structure for the task.
2. Stage 2:
Uses the generated reasoning structure for the task to generate an answer.
"""
import os
import asyncio
from llama_index.core.llms import LLM
from llama_index.core.workflow import Workflow, Context, StartEvent, StopEvent, step
from llama_index.core.settings import Settings
from src.models.discovery_events import GetModulesEvent, RefineModulesEvent, ReasoningStructureEvent
from src.workflows.reasoning_modules import _REASONING_MODULES, REASONING_PROMPT_TEMPLATE
from src.workflows.reasoning_modules import SELECT_PROMPT_TEMPLATE, ADAPT_PROMPT_TEMPLATE, IMPLEMENT_PROMPT_TEMPLATE
from src.workflows.reasoning_modules import JUDGE_REQUIREMENT_PROMPT_TEMPLATE
class SelfDiscoverWorkflow(Workflow):
    """Self-Discover workflow.

    Stage 1 runs as three chained steps — select reasoning modules, adapt
    them to the task, implement a reasoning structure — and stage 2
    (``get_final_result``) applies that structure to answer the task.
    """

    @step
    async def get_modules(self, context: Context, event: StartEvent) -> GetModulesEvent:
        """
        Select modules required for the task from the defined reasoning modules.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for this step, here Start of the workflow.
        :return: pydantic GetModulesEvent with "task" and selected "modules".
        """
        task = event.get("task")
        llm: LLM = event.get("llm")
        # Store the LLM once so every downstream step can fetch it from context.
        await context.set("llm", llm)
        prompt = SELECT_PROMPT_TEMPLATE.format(task=task, reasoning_modules=_REASONING_MODULES)
        # Await the async completion API instead of the blocking ``complete``
        # so the workflow's event loop is not stalled during the LLM call.
        result = await llm.acomplete(prompt)
        return GetModulesEvent(task=task, modules=str(result))

    @step
    async def refine_modules(self, context: Context, event: GetModulesEvent) -> RefineModulesEvent:
        """
        Refines and adapts the subset of given reasoning modules based on the task.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of GetModulesEvent.
        :return: pydantic RefineModulesEvent with "task" and "refined_modules".
        """
        task = event.task
        modules = event.modules
        llm: LLM = await context.get("llm")
        prompt = ADAPT_PROMPT_TEMPLATE.format(task=task, selected_modules=modules)
        # Non-blocking LLM call (see get_modules).
        result = await llm.acomplete(prompt)
        return RefineModulesEvent(task=task, refined_modules=str(result))

    @step
    async def create_reasoning_structure(self, context: Context, event: RefineModulesEvent) -> ReasoningStructureEvent:
        """
        Creates a reasoning structure for the task given the adapted reasoning modules.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of RefineModulesEvent.
        :return: pydantic ReasoningStructureEvent with "task" and "reasoning_structure".
        """
        task = event.task
        refined_modules = event.refined_modules
        llm: LLM = await context.get("llm")
        prompt = IMPLEMENT_PROMPT_TEMPLATE.format(task=task, adapted_modules=refined_modules)
        # Non-blocking LLM call (see get_modules).
        result = await llm.acomplete(prompt)
        return ReasoningStructureEvent(task=task, reasoning_structure=str(result))

    @step
    async def get_final_result(self, context: Context, event: ReasoningStructureEvent) -> StopEvent:
        """
        Gets the final result from the reasoning structure event.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of ReasoningStructureEvent.
        :return: StopEvent signal, last step of the workflow.
        """
        task = event.task
        reasoning_structure = event.reasoning_structure
        llm: LLM = await context.get("llm")
        prompt = REASONING_PROMPT_TEMPLATE.format(task=task, reasoning_structure=reasoning_structure)
        # Non-blocking LLM call (see get_modules).
        result = await llm.acomplete(prompt)
        # Persist the final answer so callers can inspect it from context
        # after the workflow stops, in addition to the StopEvent result.
        await context.set("workflow_result", result)
        return StopEvent(result=str(result))
class JudgeWorkflow(Workflow):
    """Judgement workflow to decide whether further questions are necessary."""

    @step
    async def judge(self, context: Context, event: StartEvent) -> StopEvent:
        """
        Judge, from the accumulated conversation context, whether more user input is needed.

        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for this step, here Start of the workflow.
        :return: StopEvent whose result is False when the model answers "0", True otherwise.
        """
        judging_context = event.get("judging_context")
        llm: LLM = event.get("llm")
        await context.set("llm", llm)
        prompt = JUDGE_REQUIREMENT_PROMPT_TEMPLATE.format(judging_context=judging_context)
        # Await the async completion API instead of the blocking ``complete``
        # so the workflow's event loop is not stalled during the LLM call.
        result = str(await llm.acomplete(prompt))
        # Strip whitespace before comparing: LLMs routinely pad the answer
        # with newlines, which would otherwise defeat the exact "0" check.
        return StopEvent(result=result.strip() != "0")
# runner for the workflow
async def main():
    """Run SelfDiscoverWorkflow on a predefined example task and print the result."""
    workflow = SelfDiscoverWorkflow()
    # example task
    predefined_task = (
        "The user wants a step-by-step workflow for titanic survival prediction ML problem. "
        "They want to understand whether a person has chances of surviving the titanic accident "
        "depending on their background, ticket, gender and titanic pitfalls. To perform this, they "
        "want to design a machine learning workflow and derive conclusions from their data. The final "
        "model should be able to predict survive/die classes. The data has these features: "
        "survival, ticket class, sex, age, siblings/spouses, parents/children, ticket, fare, cabin, embarked. "
        "In case the problem is not clear at any point and you need more input from the user, share the current "
        "workflow with the user and end with follow-up questions."
    )
    # Use the public ``Settings.llm`` accessor instead of the private
    # ``Settings._llm`` attribute; the property also lazily resolves the
    # configured (or default) LLM when none was set explicitly.
    intermediate_result = await workflow.run(task=predefined_task, llm=Settings.llm)
    print(str(intermediate_result))


if __name__ == "__main__":
    asyncio.run(main())