# import pandas as pd
import polars as pl
import numpy as np
from gradio_client import Client
from tqdm.auto import tqdm

import os
import re

from translate import (
    translate_pa_outcome, translate_pitch_outcome,
    jp_pitch_to_en_pitch, jp_pitch_to_pitch_code,
    max_pitch_types
)

# load game data
# Drop exact duplicate rows. maintain_order keeps the CSV row order so the
# per-game loading below (and therefore pa_df's row order) is deterministic.
game_df = pl.read_csv('game.csv').unique(maintain_order=True)
# After exact-duplicate removal, a repeated game_pk means two rows disagree
# about the same game. Raise explicitly rather than `assert`, which is
# stripped when Python runs with -O.
if game_df['game_pk'].n_unique() != len(game_df):
    raise ValueError('game.csv contains conflicting rows for the same game_pk')

def _read_per_game_csvs(folder, schema_overrides):
    """Read ``<folder>/<game_pk>.csv`` for every game in game_df and stack them."""
    frames = [
        pl.read_csv(os.path.join(folder, f'{pk}.csv'), schema_overrides=schema_overrides)
        for pk in tqdm(game_df['game_pk'])
    ]
    return pl.concat(frames)


# load pa data
pa_df = _read_per_game_csvs('pa', {'pa_pk': str})

# load pitch data
pitch_df = _read_per_game_csvs(
    'pitch',
    {'pa_pk': str, 'on_1b': pl.Int64, 'on_2b': pl.Int64, 'on_3b': pl.Int64},
)

# load player data
player_df = pl.read_csv('player.csv')

# translate pa data

# Hit-type codes for each batted-ball type, as half-open ``range``s (or an
# explicit tuple of codes). Earlier entries take precedence where they overlap.
# NOTE(review): code 247 is in both the fly-ball range (247-250) and the
# pop-up range (242-247); it resolves to 'fly_ball', matching the original
# if/elif order. Confirm whether the pop-up range should end at 246.
_BB_TYPE_CODE_RANGES = (
    ('ground_ball', (range(1, 10), range(40, 49))),
    ('line_drive', (range(58, 67), range(201, 209))),
    ('fly_ball', (range(28, 31), range(55, 58), range(107, 110), range(247, 251))),
    ('pop_up', (range(49, 55), range(103, 107), range(242, 248))),
    (None, ((31, 32),)),  # known codes that map to no batted-ball type
)


def _build_bb_type_lookup():
    """Flatten ``_BB_TYPE_CODE_RANGES`` into a code -> bb_type dict (first match wins)."""
    lookup = {}
    for bb_type, code_groups in _BB_TYPE_CODE_RANGES:
        for codes in code_groups:
            for code in codes:
                lookup.setdefault(code, bb_type)
    return lookup


# Built once at import time instead of rebuilding the range lists per call
# (this function runs once per row via map_elements).
_BB_TYPE_BY_HIT_TYPE = _build_bb_type_lookup()


def identify_bb_type(hit_type):
    """Map an integer hit-type code to its batted-ball type.

    Returns 'ground_ball', 'line_drive', 'fly_ball' or 'pop_up', or None for
    codes 31 and 32.

    Raises:
        ValueError: if ``hit_type`` is not a known code (a subclass of
            ``Exception``, so existing broad handlers still catch it).
    """
    try:
        return _BB_TYPE_BY_HIT_TYPE[hit_type]
    except (KeyError, TypeError):  # TypeError: unhashable input
        raise ValueError(f'Unexpected hit_type {hit_type}') from None

pa_df = (
    pa_df
    # Strip surrounding whitespace; `_des` keeps a copy of the stripped
    # description before `des` is rewritten below.
    .with_columns(
        pl.col('des').str.strip_chars().alias('_des'),
        pl.col('des').str.strip_chars(),
        pl.col('des_more').str.strip_chars(),
    )
    # Fall back to the secondary description when the primary one is null.
    .with_columns(
        pl.col('des').fill_null(pl.col('des_more'))
    )
    # When the description has several space-separated tokens and contains a
    # "+<digits>点" marker, keep only the first token. The '+' must be escaped:
    # `str.contains` treats the pattern as a regex, and a leading bare '+' is
    # an invalid (dangling) repetition operator.
    .with_columns(
        pl.when(
            (pl.col('des').str.split(' ').list.len() > 1)
            & pl.col('des').str.contains(r'\+\d+点')
        )
        .then(pl.col('des').str.split(' ').list.first())
        .otherwise(pl.col('des'))
        .alias('des')
    )
    # For pitch-level descriptions (ボール ball, 見逃し called, 空振り swinging,
    # "...塁けん制" pickoff throw) use `des_more` instead — presumably because
    # those describe the last pitch rather than the PA result.
    .with_columns(
        pl.when(
            pl.col('des').is_in(['ボール', '見逃し', '空振り'])
            | pl.col('des').str.ends_with('塁けん制')
        )
        .then(pl.col('des_more'))
        .otherwise(pl.col('des'))
        .alias('des')
    )
    .with_columns(
        pl.col('des').map_elements(translate_pa_outcome, return_dtype=str)
    )
    # `bb_type` holds "dakyu<N>" codes: parse N into the integer `hit_type`,
    # then derive the batted-ball category from it.
    .with_columns(
        pl.col('bb_type').str.strip_prefix('dakyu').cast(int).alias('hit_type')
    )
    .with_columns(
        pl.col('hit_type').map_elements(identify_bb_type, return_dtype=str).alias('bb_type')
    )
)

# translate pitch data
pitch_df = (
    pitch_df
    # Rows without a pitch name are dropped before any lookup.
    .filter(pl.col('pitch_name').is_not_null())
    # Keep the Japanese pitch name in its own column before `pitch_name` is
    # overwritten with the English one.
    .with_columns(jp_pitch_name=pl.col('pitch_name'))
    .with_columns(
        # Lookups into the `translate` tables (presumably dicts; an unknown
        # name raises).
        pitch_name=pl.col('jp_pitch_name').map_elements(
            lambda jp_name: jp_pitch_to_en_pitch[jp_name], return_dtype=str
        ),
        pitch_type=pl.col('jp_pitch_name').map_elements(
            lambda jp_name: jp_pitch_to_pitch_code[jp_name], return_dtype=str
        ),
        # Only the first space-separated token of the description is translated.
        description=(
            pl.col('description')
            .str.split(' ')
            .list.first()
            .map_elements(translate_pitch_outcome, return_dtype=str)
        ),
        # '-' means no recorded speed (-> null); otherwise drop the 'km/h' suffix.
        release_speed=(
            pl.when(pl.col('release_speed') != '-')
            .then(pl.col('release_speed').str.strip_suffix('km/h'))
            .otherwise(None)
        ),
        # NOTE(review): fixed offsets shift plate_x and shift+flip plate_z;
        # the meaning of 13/80/200/100 is not documented here — confirm
        # against the source coordinate system.
        plate_x=(pl.col('plate_x') + 13) - 80,
        plate_z=200 - (pl.col('plate_z') + 13) - 100,
    )
    # Cast in a separate step: polars evaluates the `then` branch above over
    # every row, so a strict cast there would presumably fail on '-' values.
    .with_columns(pl.col('release_speed').cast(int))
)

# translate player data
# player_df already holds player.csv (loaded above and not modified since),
# so it is not re-read here.
register = (
    pl.read_csv('register.csv')
    .select(
        # Remove every comma from the English name; `str.replace` would only
        # remove the first occurrence.
        pl.col('en_name').str.replace_all(',', '', literal=True),
        pl.col('jp_team').alias('team'),
        pl.col('jp_name').alias('name'),
    )
)
# Replace the Japanese `name` with the English name from the register.
# NOTE: the inner join drops players with no (name, team) match in
# register.csv, and duplicate (name, team) rows there would duplicate players
# (and every pitch joined to them downstream).
player_df = (
    player_df
    .join(register, on=['name', 'team'], how='inner')
    .with_columns(pl.col('en_name').alias('name'))
    .drop('en_name')
)

# merge pitch and pa data

# Pitch-outcome codes (as produced by translate_pitch_outcome) used to flag
# each pitch.
_WHIFF_CODES = ['SS', 'K']
_CSW_CODES = ['SS', 'K', 'LS', 'inv_K']
# Events not counted as a normal pitch (the author marked this list as a guess).
_NON_PITCH_CODES = ['obstruction', 'illegal_pitch', 'defensive_interference']
# Outcomes not counted as a swing.
_NO_SWING_CODES = [
    'B', 'BB', 'LS', 'inv_K', 'bunt_K', 'HBP', 'SH', 'SH E', 'SH FC',
    *_NON_PITCH_CODES,
]

# One row per pitch, joined to its PA and to the pitcher's player record.
df = (
    pitch_df
    .join(pa_df, on=['game_pk', 'pa_pk'], how='inner')
    .join(player_df.rename({'player_id': 'pitcher'}), on='pitcher', how='inner')
    .with_columns(
        whiff=pl.col('description').is_in(_WHIFF_CODES),
        swing=pl.col('description').is_in(_NO_SWING_CODES).not_(),
        csw=pl.col('description').is_in(_CSW_CODES),
        normal_pitch=pl.col('description').is_in(_NON_PITCH_CODES).not_(),
    )
    .sort(['game_pk', 'pa_pk', 'pitch_id'])
)

# add players to pa_df
# Some PAs have no rows in the pitch data, so the inner-joined `df` above
# omits them. Attaching pitcher info to pa_df directly keeps those PAs
# available (presumably for PA-level stats; the original note was cut off).
# NOTE: the inner join drops PAs whose pitcher is not in player_df.
_pitcher_info = player_df.rename({'player_id': 'pitcher'})
pa_df = pa_df.join(_pitcher_info, on='pitcher', how='inner')

# Pitch-level frames for all batters, batter stand 'R' (rhb_* outputs) and
# batter stand 'L' (lhb_* outputs). Built once and shared by both summaries.
_batter_splits = (
    df,
    df.filter(pl.col('stand') == 'R'),
    df.filter(pl.col('stand') == 'L'),
)


def _pitcher_pitch_summary(frame):
    """Summarize each (pitcher name, pitch_name) pair in *frame*.

    Whiff% = whiffs / swings and CSW% = csw / normal pitches, both as
    percentages rounded to 1 decimal; Velocity = mean release speed;
    Count = number of rows. Sorted by name, then most-thrown pitch first.
    """
    whiff_pct = (pl.col('whiff').sum() / pl.col('swing').sum() * 100).round(1)
    csw_pct = (pl.col('csw').sum() / pl.col('normal_pitch').sum() * 100).round(1)
    return (
        frame
        .group_by(['name', 'pitch_name'])
        .agg(
            whiff_pct.alias('Whiff%'),
            csw_pct.alias('CSW%'),
            pl.col('release_speed').mean().round(1).alias('Velocity'),
            pl.len().alias('Count'),
        )
        .sort(['name', 'Count'], descending=[False, True])
    )


def _league_pitch_velocity(frame):
    """Mean release speed per pitch_name across all pitchers in *frame*."""
    return frame.group_by('pitch_name').agg(
        pl.col('release_speed').mean().round(1).alias('Velocity')
    )


pitch_stats, rhb_pitch_stats, lhb_pitch_stats = (
    _pitcher_pitch_summary(split) for split in _batter_splits
)
league_pitch_stats, rhb_league_pitch_stats, lhb_league_pitch_stats = (
    _league_pitch_velocity(split) for split in _batter_splits
)