import gradio as gr
import vllm
import torch
from collections import Counter

# Build the generation engine and its tokenizer once, as a module-level side
# effect of importing this file.
# NOTE(review): the parallelism / memory / context-length values look tuned
# for one specific multi-GPU host — confirm against the deployment hardware
# before changing them.
llm = vllm.LLM(
    model="Qwen/Qwen2.5-32B-Instruct-AWQ",
    # Checkpoint format and numeric precision.
    quantization="AWQ",
    dtype="half",
    trust_remote_code=True,
    # Execution and resource limits.
    enforce_eager=True,
    tensor_parallel_size=2,
    gpu_memory_utilization=0.95,
    max_model_len=10500,
)
# The agents below render their chat prompts with this tokenizer's
# apply_chat_template, so it is taken from the engine itself.
tokenizer = llm.get_tokenizer()

# Helper Functions
def extract_answer(text):
    r"""Return the contents of the last ``\boxed{...}`` group in *text*.

    Only the final occurrence of ``\boxed`` is considered. The opening brace
    must follow the marker immediately, and nested braces inside the box are
    matched so that e.g. ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}``.

    Args:
        text: The string to search (typically a model response).

    Returns:
        The text between the outermost braces of the last ``\boxed{...}``,
        without the braces themselves (it may be an empty string), or
        ``None`` when there is no ``\boxed``, the marker is not directly
        followed by ``{``, or the braces never balance.
    """
    marker = "\\boxed"
    opener = marker + "{"

    start = text.rfind(marker)
    if start < 0:
        return None

    # Validate the shape with an explicit check rather than ``assert`` so the
    # result does not change when Python runs with assertions disabled (-O).
    if not text.startswith(opener, start):
        return None

    body_start = start + len(opener)
    depth = 0
    # Begin on the opening brace itself so that it is counted as depth 1.
    for pos in range(body_start - 1, len(text)):
        char = text[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                # Found the brace that closes the box.
                return text[body_start:pos]

    # Ran off the end of the text without closing the box.
    return None

def majority_vote(answers):
    """Return the most frequent non-None entry of *answers*.

    ``None`` entries are ignored. When several answers share the highest
    count, the one that appeared first wins. Returns ``None`` if nothing
    remains after dropping the ``None`` entries.
    """
    tally = Counter(candidate for candidate in answers if candidate is not None)
    if not tally:
        return None
    # Counter keeps first-seen order and max() returns the first maximum,
    # so ties resolve to the earliest answer.
    return max(tally, key=tally.__getitem__)

class TIRAgent:
    """One sampled attempt at simplifying a boolean expression.

    The agent owns a chat history that starts with a single user prompt
    built from the problem text. It is marked complete as soon as the first
    model response has been recorded. ``depth`` starts at 1 and grows by one
    per recorded response; ``max_depth`` is stored but not consulted by this
    class.
    """

    def __init__(self, problem_id, id, problem, tokenizer, max_depth, log):
        """Set up the conversation for one sample.

        Args:
            problem_id: Identifier of the problem; written to the log row.
            id: Identifier of this sample; written to the log row.
            problem: The boolean expression, interpolated into the prompt.
            tokenizer: Object providing ``apply_chat_template``.
            max_depth: Stored on the instance as-is.
            log: Optional object with a ``writerow`` method; falsy disables
                logging.
        """
        # Identity and collaborators.
        self.problem_id = problem_id
        self.id = id
        self.problem = problem
        self.tokenizer = tokenizer
        self.log = log

        # Turn counters.
        self.depth = 1
        self.max_depth = max_depth

        # Opening user turn of the conversation.
        prompt = (
            "Here is a boolean expression to simplify:\n"
            f"{problem}\n"
            "\n"
            "Show the step by step simplification using Boolean algebra laws. For each step:\n"
            "1. Write the current expression\n"
            "2. Name the rule applied\n"
            "3. Explain the transformation clearly\n"
            "\n"
            "Put your final simplified answer in a LaTeX box \\boxed{}."
        )
        self.messages = [{"role": "user", "content": prompt}]

        # Conversation results.
        self.answers = []
        self.last_response = None
        self.next_prompt = None
        self.is_complete = False

    def complete(self):
        """Return True once a response has been recorded."""
        return self.is_complete

    def add_response(self, response):
        """Record a model response and finish the agent.

        The response is appended to the chat history as the assistant turn.
        If it contains a boxed answer, that answer is added to ``answers``.
        """
        self.messages.append({"role": "assistant", "content": response})
        self.last_response = response
        self.depth += 1

        boxed = extract_answer(response)
        if boxed is not None:
            self.answers.append(boxed)

        # Single-turn agent: the first response always ends the conversation.
        self.is_complete = True

    def next_message(self):
        """Render the chat history as prompt text for the next generation.

        Must only be called while the agent is not complete.
        """
        assert not self.is_complete
        return self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def final_answer(self):
        """Return the latest extracted answer, or None if there is none.

        When a log was supplied, each call also writes one row of
        ``[problem_id, id, answer]`` to it.
        """
        result = self.answers[-1] if self.answers else None
        if self.log:
            self.log.writerow([self.problem_id, self.id, result])
        return result

class SCTIRAgent:
    """A group of TIRAgent samples for one problem, combined by majority vote."""

    def __init__(self, problem_id, problem, tokenizer, samples, max_depth, log):
        """Create ``samples`` TIRAgent instances, numbered from 0.

        Every agent receives the same problem, tokenizer, max_depth and log.
        """
        self.problem_id = problem_id
        self.problem = problem
        self.tokenizer = tokenizer
        self.samples = samples
        self.max_depth = max_depth
        self.log = log

        members = []
        for sample_index in range(samples):
            members.append(
                TIRAgent(problem_id, sample_index, problem, tokenizer, max_depth, log)
            )
        self.agents = members

    def complete(self):
        """Return True when every agent is complete (also True with no agents)."""
        for member in self.agents:
            if not member.complete():
                return False
        return True

    def get_ready_agents(self):
        """Return the agents that still need a response, in their original order."""
        pending = []
        for member in self.agents:
            if member.complete():
                continue
            pending.append(member)
        return pending

    def final_answer(self):
        """Return the majority answer across all agents, or None if none answered.

        Must only be called once every agent is complete.
        """
        assert self.complete()
        votes = [member.final_answer() for member in self.agents]
        return majority_vote(votes)

# Decoding settings passed to every llm.generate call in this module.
sampling_params = vllm.SamplingParams(temperature=0.7, top_p=0.9, max_tokens=512)

def simplify_boolean_expression(expression):
    """Ask the model to simplify *expression* and return its boxed answer.

    A single-sample SCTIRAgent is driven until it reports completion: each
    round renders the pending agents' prompts, generates one batch with the
    module-level ``llm`` and ``sampling_params``, and feeds every agent the
    text of its first output.

    Returns:
        The voted answer string, or None if no boxed answer was extracted.
    """
    ensemble = SCTIRAgent(0, expression, tokenizer, samples=1, max_depth=1, log=None)

    while not ensemble.complete():
        pending = ensemble.get_ready_agents()
        prompts = [member.next_message() for member in pending]

        generations = llm.generate(prompts, sampling_params)

        # Outputs are paired positionally with the prompts that produced them.
        for member, generation in zip(pending, generations):
            member.add_response(generation.outputs[0].text)

    return ensemble.final_answer()

# Gradio Interface
def interface(boolean_expr):
    """Gradio callback: return the simplified form of *boolean_expr*.

    Delegates to simplify_boolean_expression and passes its result through
    unchanged.
    """
    return simplify_boolean_expression(boolean_expr)

# Web UI: one text box in, one text box out, wired to `interface`.
app = gr.Interface(
    title="Boolean Expression Simplifier",
    description="Input a Boolean expression, and the model will provide the final simplified result.",
    fn=interface,
    inputs=gr.Textbox(
        label="Enter Boolean Expression",
        placeholder="e.g., (B.C' + A'.D).(A.B' + C.D')",
    ),
    outputs=gr.Textbox(
        label="Final Simplified Expression",
    ),
)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()