import math
import json
import regex
import inspect
import guidance
from ast import literal_eval
from transformers import Tool
from ortools.linear_solver import pywraplp

# Sets guidance's default LLM at import time; structure_program below passes no
# llm= argument, so it presumably runs against this backend (legacy guidance API).
guidance.llm = guidance.llms.OpenAI("gpt-4")

# Few-shot prompt asking the LLM to extract a tool's input arguments from a
# free-form data blob. Template variables:
#   description - tool description text
#   code        - source code that consumes the extracted arguments
#   examples    - list of {'input': ..., 'output': ...} few-shot pairs
#   data_blob   - the new blob to structure
# The assistant's raw reply is captured by {{gen 'output'}} and read back by
# callers as program(...)['output'].
# NOTE(review): example outputs that are Python dicts are presumably rendered via
# str(), i.e. as a Python repr rather than JSON, even though the text says "JSON".
structure_program = guidance(
'''
{{#user~}}
{{description}}
Help me extract args from the data blob to apply the following algorithm:
{{code}}

----

{{~#each examples}}
Data Blob: {{this.input}}
Result: {{this.output}}
---
{{~/each}}

Please help me extract the input values from a given data blob into a JSON.
Data Blob: {{data_blob}}
Result: 
{{~/user}}

{{#assistant~}}
{{gen 'output'}}
{{~/assistant}}
''')

class IntegerProgrammingTool(Tool):
    """Tool that extracts an integer program from a data blob via an LLM and solves it.

    The LLM (``structure_program``) is asked to produce a dict of the form::

        {"objective_coeffs": [[a_0, ..., a_k], ...],  # one row per decision variable
         "constraints": [limit_0, ...],               # one limit per leading column
         "bounds": [[low, high], ...],                # one pair per variable; None = unbounded
         "goal": "max" or "min"}

    Column ``c`` of ``objective_coeffs`` defines constraint ``c`` as
    ``0 <= sum_i objective_coeffs[i][c] * x_i <= constraints[c]``. The last
    column supplies the objective coefficients. Variables are integers named
    ``x0``, ``x1``, ... in row order and are solved with OR-Tools' CBC backend.
    """

    name = "integer_programming_tool"
    description = """
    This tool solves an integer programming problem.
    Input is data_blob
    Output is the optimal solution as a string.
    """
    inputs = ["text"]
    outputs = ["text"]
    # Few-shot examples rendered into the LLM prompt. Each output row must mirror
    # the matching (Space, Price) data row of its input, or the LLM learns the
    # wrong extraction.
    examples = [
        {'input': '''
        ,Space,Price
        Puzzle,1,2
        BoardGame,5,12

        Constraint,100
        ''',
         'output': {
             "objective_coeffs": [
                 [1, 2],
                 [5, 12],
             ],
             "constraints": [100],
             "bounds": [
                 [0, None],
                 [0, None],
             ],
             "goal": "max",
         },
        },
        {'input': '''
        ,Space,Price
        Puzzle,3,12
        BoardGame,15,120

        Constraint,300
        ''',
         'output': {
             "objective_coeffs": [
                 [3, 12],
                 [15, 120],
             ],
             "constraints": [300],
             "bounds": [
                 [0, None],
                 [0, None],
             ],
             "goal": "max",
         },
        },
    ]

    # Matches a brace-delimited span that may contain nested braces. The (?R)
    # recursion requires the third-party `regex` module; stdlib `re` lacks it.
    _DICT_PATTERN = regex.compile(r"\{(?:[^{}]|(?R))*\}")

    def __call__(self, data_blob):
        """Structure *data_blob* with the LLM, solve the resulting program, return JSON.

        Returns:
            A JSON string with ``"Objective value"`` and one ``"x<i>"`` entry per
            variable, or a JSON object with an ``"error"`` key when the solver
            reports neither an optimal nor a feasible solution.

        Raises:
            ValueError: the LLM reply contains no parsable dict, or the dict is
                inconsistent (no coefficients, too few bounds, too many constraints).
            RuntimeError: the OR-Tools CBC backend is unavailable.
        """
        args = self._extract_args(data_blob)
        print(args)
        return self._solve(args)

    def _extract_args(self, data_blob):
        """Prompt the LLM with this tool's context and parse the dict it replies with."""
        reply = structure_program(
            description=self.description,
            # getsource yields plain text. getsourcelines would yield a
            # (lines, lineno) tuple, and its repr would end up in the prompt.
            code=inspect.getsource(self._solve),
            examples=self.examples,
            data_blob=data_blob,
        )['output']
        return self._parse_args(reply)

    @classmethod
    def _parse_args(cls, reply):
        """Parse the first brace-delimited block of *reply* as JSON or a Python literal."""
        matches = cls._DICT_PATTERN.findall(reply)
        if not matches:
            raise ValueError(f"no dict found in LLM reply: {reply!r}")
        text = matches[0]
        try:
            # The prompt asks for JSON, which may use null/true/false ...
            return json.loads(text)
        except json.JSONDecodeError:
            # ... but the few-shot outputs are Python dicts, so the model may
            # mimic a Python repr (None, single quotes). literal_eval accepts
            # only literals, so it is safe on untrusted model output.
            return literal_eval(text)

    @staticmethod
    def _solve(args):
        """Build and solve the integer program described by *args*; return a JSON string."""
        rows = args['objective_coeffs']  # rows[i] = coefficients of variable x_i
        constraints = args['constraints']
        bounds = args['bounds']
        goal = args['goal']

        columns = list(zip(*rows))  # columns[c][i] = coefficient c of variable x_i
        if not columns:
            raise ValueError("objective_coeffs must contain at least one non-empty row")
        if len(bounds) < len(rows):
            raise ValueError("bounds needs one [low, high] pair per row of objective_coeffs")
        if len(constraints) > len(columns):
            raise ValueError("more constraints than objective_coeffs columns")

        solver = pywraplp.Solver.CreateSolver("CBC")
        if solver is None:
            raise RuntimeError("OR-Tools CBC solver backend is not available")
        inf = solver.infinity()

        # One integer variable per row (item), independent of the column count.
        variables = [
            solver.IntVar(
                int(low) if low is not None else -inf,
                int(high) if high is not None else inf,
                f"x{i}",
            )
            for i, (low, high) in enumerate(bounds[:len(rows)])
        ]

        # The last column holds the objective coefficients.
        objective = solver.Objective()
        for var, coeff in zip(variables, columns[-1]):
            objective.SetCoefficient(var, coeff)
        if goal == 'max':
            objective.SetMaximization()
        else:
            objective.SetMinimization()

        # Column c bounds sum(coeff * x) to [0, constraints[c]]. None means
        # "no limit"; 0 is a real limit and must still be enforced.
        for limit, column in zip(constraints, columns):
            if limit is None:
                continue
            constraint = solver.RowConstraint(0, limit)
            for var, coeff in zip(variables, column):
                constraint.SetCoefficient(var, coeff)

        status = solver.Solve()
        if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
            # Infeasible/unbounded/abnormal: objective and variable values are meaningless.
            return json.dumps({"error": f"solver found no solution (status {status})"})

        solution = {"Objective value": objective.Value()}
        for i, var in enumerate(variables):
            solution[f"x{i}"] = var.solution_value()
        return json.dumps(solution)