import logging
import os
from datetime import datetime
from decimal import Decimal
from typing import List

import boto3
from boto3.dynamodb.conditions import Attr, Key
from datasets import Dataset

# Configure logging at import time; the level is read from the LOG_LEVEL
# environment variable and falls back to "INFO" when it is unset.
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))

# Create a DynamoDB *resource* (boto3's table-oriented API, not the low-level
# client) pinned to us-east-1. The helpers below obtain their Table handles from it.
dynamodb = boto3.resource('dynamodb', region_name='us-east-1')


def _create_arena_table():
    """Create the ``oaaic_chatbot_arena`` DynamoDB table.

    The table uses ``arena_battle_id`` (string) as its hash key and carries one
    global secondary index, ``TimestampIndex``, keyed on ``arena_battle_id``
    (hash) and ``timestamp`` (range) with every attribute projected. The table
    and the index are each provisioned at 5 read / 5 write capacity units.

    The value returned by ``create_table`` is discarded, and nothing here waits
    for table creation to finish.
    """
    battle_id_hash = {'AttributeName': 'arena_battle_id', 'KeyType': 'HASH'}
    timestamp_range = {'AttributeName': 'timestamp', 'KeyType': 'RANGE'}

    timestamp_index = {
        'IndexName': 'TimestampIndex',
        # Separate copy of the hash-key element so the table and the index
        # do not share one dict object.
        'KeySchema': [dict(battle_id_hash), timestamp_range],
        'Projection': {'ProjectionType': 'ALL'},
        'ProvisionedThroughput': {'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5},
    }

    dynamodb.create_table(
        TableName='oaaic_chatbot_arena',
        KeySchema=[battle_id_hash],
        # Every attribute used in a key schema (table or index) must be declared here.
        AttributeDefinitions=[
            {'AttributeName': 'arena_battle_id', 'AttributeType': 'S'},
            {'AttributeName': 'timestamp', 'AttributeType': 'S'},
        ],
        ProvisionedThroughput={'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5},
        GlobalSecondaryIndexes=[timestamp_index],
    )

def _create_elo_scores_table():
    """Create the ``elo_scores`` DynamoDB table.

    The table has a single string hash key, ``chatbot_name``, no sort key and
    no secondary indexes, and is provisioned at 5 read / 5 write capacity
    units. The ``create_table`` result is discarded.
    """
    table_spec = {
        'TableName': 'elo_scores',
        # Partition key only: one item per chatbot name.
        'KeySchema': [{'AttributeName': 'chatbot_name', 'KeyType': 'HASH'}],
        'AttributeDefinitions': [{'AttributeName': 'chatbot_name', 'AttributeType': 'S'}],
        'ProvisionedThroughput': {'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5},
    }
    dynamodb.create_table(**table_spec)


def _create_elo_logs_table():
    """Create the ``elo_logs`` DynamoDB table.

    Primary key: ``arena_battle_id`` (hash, string) plus ``battle_timestamp``
    (range, string). One global secondary index, ``AllTimestampIndex``, is
    keyed on ``all`` (hash) and ``battle_timestamp`` (range) with every
    attribute projected. ``log_elo_update`` writes the constant ``'ALL'`` into
    ``all``, and ``get_last_processed_timestamp`` queries this index with
    ``all == 'ALL'`` to read the newest log entry.

    The table and the index are each provisioned at 10 read / 10 write
    capacity units. The ``create_table`` result is discarded.
    """
    def string_attr(name):
        # All key attributes of this table are DynamoDB strings ('S').
        return {'AttributeName': name, 'AttributeType': 'S'}

    battle_timestamp_range = {'AttributeName': 'battle_timestamp', 'KeyType': 'RANGE'}

    all_timestamp_index = {
        'IndexName': 'AllTimestampIndex',
        'KeySchema': [
            {'AttributeName': 'all', 'KeyType': 'HASH'},  # partition key for the GSI
            dict(battle_timestamp_range),                  # sort key for the GSI
        ],
        'Projection': {'ProjectionType': 'ALL'},
        'ProvisionedThroughput': {'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
    }

    dynamodb.create_table(
        TableName='elo_logs',
        KeySchema=[
            {'AttributeName': 'arena_battle_id', 'KeyType': 'HASH'},  # partition key
            battle_timestamp_range,                                    # sort key
        ],
        AttributeDefinitions=[
            string_attr('arena_battle_id'),
            string_attr('battle_timestamp'),
            string_attr('all'),
        ],
        ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
        GlobalSecondaryIndexes=[all_timestamp_index],
    )


def get_unprocessed_battles(last_processed_timestamp):
    """Return every arena battle whose ``timestamp`` is greater than the given one.

    This is a filtered *scan* of ``oaaic_chatbot_arena`` (not a query), so the
    items come back in no particular order; callers that need temporal order
    must sort the result themselves.

    A single ``scan`` call returns only one page of results, and the filter is
    applied after that page has been read. The scan is therefore repeated,
    following ``LastEvaluatedKey``, until the whole table has been covered;
    otherwise battles beyond the first page would be silently skipped.

    Args:
        last_processed_timestamp: Lower bound (exclusive) compared against each
            item's ``timestamp`` attribute.

    Returns:
        A list of the matching items (dicts) from all pages, possibly empty.
    """
    table = dynamodb.Table('oaaic_chatbot_arena')

    scan_kwargs = {
        'FilterExpression': Attr('timestamp').gt(last_processed_timestamp),
    }

    battles = []
    while True:
        response = table.scan(**scan_kwargs)
        battles.extend(response.get('Items', []))

        # A missing LastEvaluatedKey means the final page has been read.
        last_key = response.get('LastEvaluatedKey')
        if not last_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_key

    return battles


def calculate_elo(rating1, rating2, result, K=32):
    """Compute both players' updated Elo ratings after one game.

    Args:
        rating1: Current rating of player 1 (anything ``float()`` accepts).
        rating2: Current rating of player 2 (anything ``float()`` accepts).
        result: Score of player 1: 1 for a win, 0.5 for a draw, 0 for a loss.
        K: Maximum rating change per game (the Elo K-factor).

    Returns:
        A ``(new_rating1, new_rating2)`` tuple of ``Decimal`` values, each
        quantized to two decimal places.
    """
    two_places = Decimal('0.00')
    r1, r2 = float(rating1), float(rating2)

    # Expected score of player 1 on the logistic Elo curve; player 2 gets the complement.
    expected1 = 1.0 / (1.0 + 10.0 ** ((r2 - r1) / 400.0))
    expected2 = 1.0 - expected1

    # Each rating moves by K times (actual score - expected score).
    updated = (
        r1 + K * (result - expected1),
        r2 + K * ((1.0 - result) - expected2),
    )
    return tuple(Decimal(value).quantize(two_places) for value in updated)


def get_last_processed_timestamp():
    """Return the ``battle_timestamp`` of the newest entry in ``elo_logs``.

    Queries the ``AllTimestampIndex`` index for items whose ``all`` attribute
    equals ``'ALL'``, in descending sort-key order, limited to one item.

    Returns:
        The ``battle_timestamp`` of that item, or ``'1970-01-01T00:00:00'``
        when the query returns no items.
    """
    newest = dynamodb.Table('elo_logs').query(
        IndexName='AllTimestampIndex',
        KeyConditionExpression=Key('all').eq('ALL'),
        ScanIndexForward=False,  # descending by battle_timestamp
        Limit=1,
    )['Items']

    if newest:
        return newest[0]['battle_timestamp']

    # Nothing logged yet: fall back to the default timestamp.
    return '1970-01-01T00:00:00'


def log_elo_update(arena_battle_id, battle_timestamp, new_rating1, new_rating2):
    """Write one entry to the ``elo_logs`` table for a processed battle.

    Args:
        arena_battle_id: Stored as the item's ``arena_battle_id``.
        battle_timestamp: Timestamp of the battle itself, stored as
            ``battle_timestamp``.
        new_rating1: Stored as ``new_rating1``.
        new_rating2: Stored as ``new_rating2``.

    The item also records ``log_timestamp`` (local ``datetime.now()`` in ISO
    format) and the constant ``'all': 'ALL'``.
    """
    logs_table = dynamodb.Table('elo_logs')

    entry = {
        'arena_battle_id': arena_battle_id,
        'battle_timestamp': battle_timestamp,
        # When the log entry was written, as opposed to when the battle happened.
        'log_timestamp': datetime.now().isoformat(),
        'new_rating1': new_rating1,
        'new_rating2': new_rating2,
        'all': 'ALL',
    }
    logs_table.put_item(Item=entry)


def get_elo_score(chatbot_name, elo_scores):
    """Look up a chatbot's current Elo score.

    Args:
        chatbot_name: Name of the chatbot; also the key of the ``elo_scores`` table.
        elo_scores: Mapping of already-known scores, consulted first.

    Returns:
        The value from ``elo_scores`` when the name is present there; otherwise
        the ``elo_score`` stored in the ``elo_scores`` table, or ``1500`` when
        the table has no item for this chatbot.
    """
    if chatbot_name not in elo_scores:
        stored = dynamodb.Table('elo_scores').get_item(Key={'chatbot_name': chatbot_name})
        # No 'Item' in the response means the chatbot has no stored score yet.
        return stored['Item']['elo_score'] if 'Item' in stored else 1500

    return elo_scores[chatbot_name]


def update_elo_score(chatbot_name, new_elo_score):
    """Store a chatbot's Elo score in the ``elo_scores`` table.

    ``put_item`` replaces any existing item for ``chatbot_name`` and creates
    one when none exists.

    Args:
        chatbot_name: Key of the item to write.
        new_elo_score: New score; converted via ``str`` to ``Decimal`` before
            being written.
    """
    score = Decimal(str(new_elo_score))
    record = {'chatbot_name': chatbot_name, 'elo_score': score}
    dynamodb.Table('elo_scores').put_item(Item=record)


def get_elo_scores():
    """Return every item stored in the ``elo_scores`` table.

    A single ``scan`` call returns only one page of results, so the scan is
    repeated, following ``LastEvaluatedKey``, until the whole table has been
    read.

    Returns:
        A list of item dicts from all pages, possibly empty.
    """
    table = dynamodb.Table('elo_scores')

    scan_kwargs = {}
    data = []
    while True:
        response = table.scan(**scan_kwargs)
        data.extend(response.get('Items', []))

        # A missing LastEvaluatedKey means the final page has been read.
        last_key = response.get('LastEvaluatedKey')
        if not last_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_key

    return data


def _backfill_logs():
    """Set ``all = 'ALL'`` on every existing item in the ``elo_logs`` table.

    ``all`` is the hash key of the ``AllTimestampIndex`` index, so items
    without it cannot be found through that index.

    A single ``scan`` call returns only one page of results; the scan is
    repeated, following ``LastEvaluatedKey``, so that items beyond the first
    page are backfilled too. Each page is updated before the next is fetched.
    """
    table = dynamodb.Table('elo_logs')

    scan_kwargs = {}
    while True:
        response = table.scan(**scan_kwargs)

        for item in response.get('Items', []):
            table.update_item(
                Key={
                    'arena_battle_id': item['arena_battle_id'],
                    'battle_timestamp': item['battle_timestamp'],
                },
                # The attribute name is passed through a '#all' placeholder
                # rather than written directly into the expression.
                UpdateExpression="SET #all = :value",
                ExpressionAttributeNames={'#all': 'all'},
                ExpressionAttributeValues={':value': 'ALL'},
            )

        # A missing LastEvaluatedKey means the final page has been read.
        last_key = response.get('LastEvaluatedKey')
        if not last_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_key

def main():
    """Apply Elo updates for all unprocessed battles, then publish the scores.

    Steps:
      1. Fetch battles newer than the last logged battle timestamp and sort
         them by ``timestamp``.
      2. For each battle whose ``label`` is one of -1, 0, 1, 2, compute the new
         ratings, write a log entry, and store both chatbots' new scores.
         Battles with any other label are printed and skipped.
      3. Re-read all scores, convert each ``elo_score`` to ``float``, print
         them, and, only if at least one battle was fetched, push them to the
         Hugging Face Hub dataset.
    """
    last_processed_timestamp = get_last_processed_timestamp()
    battles: List[dict] = sorted(
        get_unprocessed_battles(last_processed_timestamp),
        key=lambda x: x['timestamp'],
    )

    # Per-run cache of current ratings, so each chatbot is read from the table at most once.
    elo_scores = {}

    for battle in battles:
        print(repr(battle))

        outcome = battle['label']
        if outcome not in {-1, 0, 1, 2}:
            continue

        for chatbot_name in (battle['choice1_name'], battle['choice2_name']):
            if chatbot_name not in elo_scores:
                elo_scores[chatbot_name] = get_elo_score(chatbot_name, elo_scores)

        # Map the label onto the first chatbot's score:
        #   0 or -1 -> 0.5 (draw), 1 -> 1 (first chatbot won), 2 -> 0 (first chatbot lost).
        elo_result = 0.5 if outcome in (0, -1) else (1 if outcome == 1 else 0)

        new_rating1, new_rating2 = calculate_elo(
            elo_scores[battle['choice1_name']],
            elo_scores[battle['choice2_name']],
            elo_result,
        )
        # Logged before the cache is overwritten so the old ratings are still visible.
        logging.info(f"{battle['choice1_name']}: {elo_scores[battle['choice1_name']]} -> {new_rating1} | {battle['choice2_name']}: {elo_scores[battle['choice2_name']]} -> {new_rating2}")

        log_elo_update(battle['arena_battle_id'], battle['timestamp'], new_rating1, new_rating2)

        updates = (
            (battle['choice1_name'], new_rating1),
            (battle['choice2_name'], new_rating2),
        )
        for chatbot_name, rating in updates:
            update_elo_score(chatbot_name, rating)
            elo_scores[chatbot_name] = rating

    leaderboard = get_elo_scores()
    for row in leaderboard:
        row["elo_score"] = float(row["elo_score"])
    print(leaderboard)

    if not battles:
        return

    # Convert the data into a format suitable for Hugging Face Dataset
    elo_dataset = Dataset.from_list(leaderboard)
    elo_dataset.push_to_hub("openaccess-ai-collective/chatbot-arena-elo-scores", private=False)


# Run the Elo update only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()