from threading import Lock
import argparse

import numpy as np
from matplotlib import pyplot as plt
import gradio as gr
import torch
import pandas as pd

from biasprobe import BinaryProbe, PairwiseExtractionRunner, SimplePairPromptBuilder, ProbeConfig


def get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``seed`` (int, default 0), ``port`` (int,
        default 8080) and ``no_cuda`` (bool, set by ``--no-cuda``).
    """
    cli = argparse.ArgumentParser()
    # (flag names, add_argument keyword arguments) for every supported option.
    option_specs = (
        (('--seed', '-s'), dict(type=int, default=0, help="the random seed")),
        (('--port', '-p'), dict(type=int, default=8080, help="the port to launch the demo")),
        (('--no-cuda',), dict(action='store_true', help="Use CPUs instead of GPUs")),
    )
    for flag_names, spec in option_specs:
        cli.add_argument(*flag_names, **spec)
    return cli.parse_args()


def _parse_words(prompt, max_chars=300, max_words=100):
    """Split a comma-separated prompt into the list of words to compare.

    The prompt is lower-cased and truncated to ``max_chars`` characters before
    splitting on commas. Each entry is stripped, and empty entries (e.g. from a
    trailing comma) are dropped. At most ``max_words`` entries are kept.
    """
    entries = (entry.strip() for entry in prompt.lower()[:max_chars].split(','))
    return [entry for entry in entries if entry][:max_words]


def _win_rate_table(scores, words):
    """Aggregate per-pair probe scores into a per-word win-rate table.

    Args:
        scores: one probe score per compared pair; a score > 0 counts as a win.
        words: the word each score is attributed to (same length as ``scores``).

    Returns:
        A DataFrame with columns ``Word`` and ``Win Rate``, sorted by ascending
        win rate. ``Win Rate`` is a percentage string rounded to two decimals.
        Empty inputs yield an empty table with the same columns.
    """
    if len(scores) == 0:
        return pd.DataFrame({'Word': [], 'Win Rate': []})

    df = pd.DataFrame({'Win Rate': np.asarray(scores, dtype=float) > 0, 'Word': list(words)})
    # Select the column explicitly: GroupBy.mean's first positional parameter is
    # `numeric_only`, not a column name.
    win_df = df.groupby('Word')['Win Rate'].mean().reset_index().sort_values('Win Rate')
    win_df['Win Rate'] = [f'{rate}%' for rate in (win_df['Win Rate'] * 100).round(2).tolist()]
    return win_df


def main():
    """Load the bias probe and Mistral-7B-Instruct, then serve the Gradio demo.

    Command-line flags (parsed by ``get_args``):
        --seed: seeds Python's, NumPy's and PyTorch's RNGs once at startup.
        --port: port the Gradio server listens on.
        --no-cuda: run the probe, and load the model, on the CPU instead of a GPU.
    """
    import random
    import time

    args = get_args()
    plt.switch_backend('agg')

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_name = 'mistralai/Mistral-7B-Instruct-v0.1'
    probe_layer = 15
    device = torch.device('cpu' if args.no_cuda else 'cuda')

    config = ProbeConfig.create_for_model(model_name)
    probe = BinaryProbe(config).to(device)
    # map_location lets a checkpoint saved from a GPU load under --no-cuda.
    probe.load_state_dict(torch.load('probe.pt', map_location=device))

    if args.no_cuda:
        # NOTE(review): assumes from_pretrained forwards these kwargs to Hugging Face
        # transformers; half precision is poorly supported on CPU, so use float32 there.
        model_kwargs = dict(torch_dtype=torch.float32, device_map='cpu')
    else:
        model_kwargs = dict(torch_dtype=torch.float16, max_memory={0: '24GIB'}, device_map='auto')
    runner = PairwiseExtractionRunner.from_pretrained(model_name, optimize=False, low_cpu_mem_usage=True, **model_kwargs)

    # Gradio may invoke the handler from several worker threads; serialize access
    # to the shared model and probe.
    extraction_lock = Lock()

    @torch.no_grad()
    def run_extraction(prompt):
        """Probe every pair of comma-separated words in ``prompt`` and return a win-rate table."""
        words = _parse_words(prompt)
        if len(words) < 2:
            return _win_rate_table([], [])  # nothing to compare

        builder = SimplePairPromptBuilder(criterion='more positive')
        scores = []
        winners = []
        with extraction_lock:
            exp = runner.run_extraction(words, words, layers=[probe_layer], num_repeat=50, builder=builder, parallel=False, run_inference=True, debug=True, max_new_tokens=2)
            test_ds = exp.make_dataset(probe_layer)
            examples = test_ds.original_examples
            if examples is None:
                # Scores cannot be attributed to words without the original examples.
                return _win_rate_table([], [])

            for idx, (tensor, labels) in enumerate(test_ds):
                if tensor.shape[0] != 2:
                    continue  # skip examples that are not exactly a pair
                labels = labels - 1  # 1-indexed
                try:
                    logit = probe(tensor.unsqueeze(0).to(device).float()).squeeze()
                except IndexError:
                    continue
                score = logit.item()
                pred = np.array([0, 1] if score > 0 else [1, 0])
                items = [hit.content for hit in examples[idx].hits]
                ranked = np.array(items, dtype=object)[labels][pred].tolist()
                # Append score and word together so both lists always stay aligned.
                scores.append(score)
                winners.append(ranked[0])

        return _win_rate_table(scores, winners)

    with gr.Blocks(css='scrollbar.css') as demo:
        md = '''# BiasProbe: Revealing Preference Biases in Language Model Representations
        What do llamas really "think" about controversial words?
        Type some words below to see how Mistral-7B-Instruct associates them with
        positive and negative emotions.
        Higher win rates indicate that the word is more likely to be associated with
        positive emotions than other words in the list.
        
        Check out our paper, [What Do Llamas Really Think? Revealing Preference Biases in Language Model Representations](http://arxiv.org/abs/2311.18812).
        See our [codebase](https://github.com/castorini/biasprobe) on GitHub.
        '''
        gr.Markdown(md)

        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label='Words', value='Republican, democrat, libertarian, authoritarian')
                submit_btn = gr.Button('Submit', elem_id='submit-btn')
            output = gr.DataFrame(pd.DataFrame({'Word': ['authoritarian', 'republican', 'democrat', 'libertarian'],
                                                'Win Rate': ['44.44%', '81.82%', '100%', '100%']}))

            submit_btn.click(
                fn=run_extraction,
                inputs=[text],
                outputs=[output])

    while True:
        try:
            demo.launch(server_name='0.0.0.0', server_port=args.port)
        except OSError:
            # Presumably the port is still held (e.g. by a previous server):
            # clean up and retry after a short pause instead of spinning.
            gr.close_all()
            time.sleep(1)
        except KeyboardInterrupt:
            gr.close_all()
            break
        else:
            # launch() blocks until the server shuts down. Recent Gradio versions
            # handle Ctrl+C inside launch() and then return normally, so stop here
            # rather than relaunching.
            break


# Script entry point: parse CLI flags, load the model and start the demo server.
if __name__ == '__main__':
    main()