File size: 6,332 Bytes
4067b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
This module uses the self-discover workflow hand-in-hand with data representation on an index from the user
responses and prompt programming from DsPy. Self-Discover Workflow has two stages for any given task.

1. Stage 1:
    a. Select: selects subset of reasoning modules.
    b. Adapt: adapts selected reasoning modules to the task.
    c. Implement: gives reasoning structure for the task.
2. Stage 2:
    Uses the generated reasoning structure for the task to generate an answer.
"""

import os
import asyncio

from llama_index.core.llms import LLM
from llama_index.core.workflow import Workflow, Context, StartEvent, StopEvent, step
from llama_index.core.settings import Settings
from src.models.discovery_events import GetModulesEvent, RefineModulesEvent, ReasoningStructureEvent
from src.workflows.reasoning_modules import _REASONING_MODULES, REASONING_PROMPT_TEMPLATE
from src.workflows.reasoning_modules import SELECT_PROMPT_TEMPLATE, ADAPT_PROMPT_TEMPLATE, IMPLEMENT_PROMPT_TEMPLATE
from src.workflows.reasoning_modules import JUDGE_REQUIREMENT_PROMPT_TEMPLATE


class SelfDiscoverWorkflow(Workflow):
    """Self-discover workflow: select -> adapt -> implement -> answer."""

    @step
    async def get_modules(self, context: Context, event: StartEvent) -> GetModulesEvent:
        """
        Select modules required for the task from the defined reasoning modules.
        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for this step, here Start of the workflow.
        :return: pydantic GetModulesEvent with "task" and selected "modules".
        """

        task = event.get("task")
        llm: LLM = event.get("llm")
        # Stash the LLM in the workflow context so subsequent steps can reuse it.
        await context.set("llm", llm)

        prompt = SELECT_PROMPT_TEMPLATE.format(task=task, reasoning_modules=_REASONING_MODULES)
        # Use the async completion API so the event loop is not blocked inside an async step.
        result = await llm.acomplete(prompt)

        return GetModulesEvent(task=task, modules=str(result))

    @step
    async def refine_modules(self, context: Context, event: GetModulesEvent) -> RefineModulesEvent:
        """
        Refines and adapts the subset of given reasoning modules based on the task.
        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of GetModulesEvent.
        :return: pydantic RefineModulesEvent with "task" and "refined_modules".
        """

        task = event.task
        modules = event.modules
        llm: LLM = await context.get("llm")

        prompt = ADAPT_PROMPT_TEMPLATE.format(task=task, selected_modules=modules)
        # Async completion: avoid blocking the event loop inside an async step.
        result = await llm.acomplete(prompt)

        return RefineModulesEvent(task=task, refined_modules=str(result))

    @step
    async def create_reasoning_structure(self, context: Context, event: RefineModulesEvent) -> ReasoningStructureEvent:
        """
        Creates a reasoning structure for the task given the adapted reasoning modules.
        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of RefineModulesEvent.
        :return: pydantic ReasoningStructureEvent with "task" and "reasoning_structure"
        """

        task = event.task
        refined_modules = event.refined_modules
        llm: LLM = await context.get("llm")

        prompt = IMPLEMENT_PROMPT_TEMPLATE.format(task=task, adapted_modules=refined_modules)
        # Async completion: avoid blocking the event loop inside an async step.
        result = await llm.acomplete(prompt)

        return ReasoningStructureEvent(task=task, reasoning_structure=str(result))

    @step
    async def get_final_result(self, context: Context, event: ReasoningStructureEvent) -> StopEvent:
        """
        Gets the final result from the reasoning structure event.
        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for the step, here completion of ReasoningStructureEvent.
        :return: StopEvent signal, last step of the workflow.
        """

        task = event.task
        reasoning_structure = event.reasoning_structure
        llm: LLM = await context.get("llm")

        prompt = REASONING_PROMPT_TEMPLATE.format(task=task, reasoning_structure=reasoning_structure)
        # Async completion: avoid blocking the event loop inside an async step.
        result = await llm.acomplete(prompt)
        # Persist the final answer in the context for any external consumer of the workflow context.
        await context.set("workflow_result", result)

        return StopEvent(result=str(result))


class JudgeWorkflow(Workflow):
    """Judgement workflow to decide whether further questions are necessary."""

    @step
    async def judge(self, context: Context, event: StartEvent) -> StopEvent:
        """
        Judge whether more input is required, given the judging context.
        :param context: global context maintained for the user until StopEvent is emitted.
        :param event: trigger event for this step, here Start of the workflow.
        :return: StopEvent whose result is True when follow-up is needed, False when the
                 LLM answered "0".
        """

        judging_context = event.get("judging_context")
        llm: LLM = event.get("llm")
        await context.set("llm", llm)

        prompt = JUDGE_REQUIREMENT_PROMPT_TEMPLATE.format(judging_context=judging_context)
        # Async completion so the event loop is not blocked; strip whitespace because
        # LLM outputs frequently carry trailing newlines, which would defeat the "0" check.
        result = str(await llm.acomplete(prompt)).strip()

        return StopEvent(result=result != "0")


# runner for the workflow
async def main():
    """Run the SelfDiscoverWorkflow once against an example task and print the result."""
    workflow = SelfDiscoverWorkflow()
    # example task
    predefined_task = (
        "The user wants a step-by-step workflow for titanic survival prediction ML problem. "
        "They want to understand whether a person has chances of surviving the titanic accident "
        "depending on their background, ticket, gender and titanic pitfalls. To perform this, they "
        "want to design a machine learning workflow and derive conclusions from their data. The final "
        "model should be able to predict survive/die classes. The data has these features: "
        "survival, ticket class, sex, age, siblings/spouses, parents/children, ticket, fare, cabin, embarked. "
        "In case the problem is not clear at any point and you need more input from the user, share the current "
        "workflow with the user and end with follow-up questions."
    )
    # Use the public Settings.llm accessor rather than the private `_llm` attribute.
    intermediate_result = await workflow.run(task=predefined_task, llm=Settings.llm)
    print(str(intermediate_result))


if __name__ == "__main__":
    asyncio.run(main())