import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from scipy.stats import gaussian_kde
import numpy as np
import polars as pl
import gradio as gr

from math import ceil

from translate import max_pitch_types, jp_pitch_to_en_pitch
from data import (
    df,
    # pitch_stats, rhb_pitch_stats,lhb_pitch_stats,
    # league_pitch_stats, rhb_league_pitch_stats, lhb_league_pitch_stats
    compute_pitch_stats, compute_league_pitch_stats
)


# Upper bound on per-pitch-type location panels: one per entry of the
# JP -> EN pitch-name mapping.
MAX_LOCS = len(jp_pitch_to_en_pitch)
# Panels are arranged in rows of this many.
LOCS_PER_ROW = 4
# Integer ceiling division: a partially filled last row still needs a row.
MAX_ROWS = -(-MAX_LOCS // LOCS_PER_ROW)
# Shown in place of a plot when too few pitches are available; the multi-line
# variant uses a <br> break for Plotly annotation text.
INSUFFICIENT_PITCHES_MSG = 'No visualization: Not enough pitches thrown'
INSUFFICIENT_PITCHES_MSG_MULTI_LINE = 'No visualization:<br>Not enough pitches thrown'

# GRADIO FUNCTIONS

# def clone_if_dataframe(item):
  # if isinstance(item, pl.DataFrame):
    # # print(type(item))
    # return item.clone()
  # else:
    # return item
# 
# def clone_df(fn):
  # def _fn(*args, **kwargs):
    # args = [clone_if_dataframe(arg) for arg in args]
    # kwargs = {k: clone_if_dataframe(arg) for k, arg in kwargs.items()}
    # return fn(*args, **kwargs)
  # return _fn
# 
def copy_dataframe(df, num_copy_to):
  """Return a list holding `num_copy_to` independent clones of `df`.

  Each element comes from a separate `df.clone()` call, so callers can mutate
  one copy without affecting the others or the original.
  """
  clones = []
  for _ in range(num_copy_to):
    clones.append(df.clone())
  return clones

# location maps
def fit_pred_kde(data, X, Y):
  """Fit a Gaussian KDE to `data` and evaluate it on the grid given by X, Y.

  Args:
    data: array of shape (2, n_points), one row per coordinate, as expected
      by `scipy.stats.gaussian_kde`.
    X, Y: equally shaped coordinate arrays (e.g. from `np.meshgrid`).

  Returns:
    Density values with the same shape as `X`.
  """
  estimator = gaussian_kde(data)
  # Flatten the grid into a (2, n_grid_points) array of query points.
  query_points = np.vstack([X.ravel(), Y.ravel()])
  densities = estimator(query_points)
  return densities.reshape(*X.shape)


# Plot canvas: a square plot_s x plot_s area centred on the origin. Pitch
# locations (plate_x / plate_z) are plotted directly in these coordinates
# (units not established here -- TODO confirm).
plot_s = 256
# Outer reference rectangle drawn by plot_loc (presumably the strike zone).
sz_h = 200
sz_w = 160
# Inner reference rectangle: inset 40 vertically and 32 horizontally on each
# side of the outer one. Derived from sz_* so the two cannot drift apart.
h_h = sz_h - 40*2
h_w = sz_w - 32*2

# KDE evaluation grid: unit steps from -plot_s/2 up to (excluding) plot_s/2
# on both axes.
kde_range = np.arange(-plot_s/2, plot_s/2, 1)
X, Y = np.meshgrid(kde_range, kde_range)


def coordinatify(h, w):
  """Return Plotly shape corners for an h-by-w rectangle centred on the origin.

  The result has keys x0, y0, x1, y1, suitable for `fig.add_shape(**...)`.
  """
  half_w = w / 2
  half_h = h / 2
  return {'x0': -half_w, 'y0': -half_h, 'x1': half_w, 'y1': half_h}


# Plotly colorscale for the player density contours: a fully transparent stop
# at 0, followed by the OrRd sequence spread evenly over (0, 1].
colorscale = [[0, 'rgba(0, 0, 0, 0)']]
colorscale += [
    [rank / len(pc.sequential.OrRd), shade]
    for rank, shade in enumerate(pc.sequential.OrRd, start=1)
]


# @clone_df
def plot_loc(df, handedness, league_df=None, min_pitches=3, max_pitches=5000):
  """Plot a pitch-location density map, optionally overlaid with a league contour.

  The pitches in `df` are drawn as a filled KDE contour on the module-level
  `kde_range` grid, together with two reference rectangles centred on the
  origin (`sz_h` x `sz_w` and `h_h` x `h_w`). When `league_df` is given, the
  region where the league density is at or above its 0.9 quantile is outlined
  with a dashed black contour named 'NPB'.

  Args:
    df: frame with `plate_x` / `plate_z` columns for the pitches to plot.
    handedness: batter-handedness selection; for 'Both' the league contour
      starts hidden ('legendonly'), otherwise it is visible.
    league_df: optional frame with the same columns, used as a reference.
    min_pitches: minimum number of located pitches (non-null coordinates)
      needed to fit a KDE. Below it, or when the KDE cannot be fit, a text
      note is shown instead of the player contour.
    max_pitches: the league sample is down-sampled (seed 0) to at most this
      many rows before fitting, presumably to keep the KDE cheap.

  Returns:
    A plotly `go.Figure`.
  """
  fig = go.Figure()

  Z = _fit_loc_kde(df, min_pitches)
  if Z is not None:
    z_max = Z.max()
    fig.add_shape(
        type="rect",
        **coordinatify(sz_h, sz_w),
        line_color='gray',
    )
    fig.add_shape(
        type="rect",
        **coordinatify(h_h, h_w),
        line_color='dimgray',
    )
    fig.add_trace(go.Contour(
        z=Z,
        x=kde_range,
        y=kde_range,
        colorscale=colorscale,
        zmin=1e-5,
        zmax=z_max,
        # Levels start at 1e-5 and are spaced z_max / 5 apart up to the peak.
        contours={
            'start': 1e-5,
            'end': z_max,
            'size': z_max / 5,
        },
        showscale=False,
    ))
  else:
    fig.add_annotation(
        x=0,
        y=0,
        text=INSUFFICIENT_PITCHES_MSG_MULTI_LINE,
        showarrow=False,
    )

  if league_df is not None:
    league_Z = _fit_loc_kde(league_df, min_pitches, max_pitches)
    if league_Z is not None:
      league_max = league_Z.max()
      # A single contour level at the 0.9 quantile of the grid densities
      # outlines the densest 10% of grid cells.
      top_decile = np.quantile(league_Z, 0.9)
      fig.add_trace(go.Contour(
          z=league_Z,
          x=kde_range,
          y=kde_range,
          # Transparent fill: only the dashed contour line is visible.
          colorscale=[
              [0, 'rgba(0, 0, 0, 0)'],
              [1, 'rgba(0, 0, 0, 0)'],
          ],
          zmin=top_decile,
          zmax=league_max,
          contours={
              'start': top_decile,
              'end': league_max,
              'size': league_max - top_decile,
          },
          line={
              'width': 2,
              'color': 'black',
              'dash': 'dash',
          },
          showlegend=True,
          showscale=False,
          visible=True if handedness != 'Both' else 'legendonly',
          name='NPB',
      ))

  fig.update_layout(
    xaxis=dict(range=[-plot_s/2, plot_s/2+1], showticklabels=False),
    yaxis=dict(range=[-plot_s/2, plot_s/2+1], scaleanchor='x', scaleratio=1, showticklabels=False),
    legend=dict(orientation='h', y=0, yanchor='top'),
  )
  return fig


def _fit_loc_kde(frame, min_pitches, max_pitches=None):
  """Fit a KDE to `frame`'s (plate_x, plate_z) and evaluate it on the X/Y grid.

  Returns None when fewer than `min_pitches` located pitches remain, or when
  the KDE cannot be fit.
  """
  # Rows missing either coordinate cannot be placed on the map and would
  # break the KDE, so drop them before counting.
  loc = frame.select(['plate_x', 'plate_z']).drop_nulls()
  if max_pitches is not None and len(loc) > max_pitches:
    loc = loc.sample(max_pitches, seed=0)
  if len(loc) < min_pitches:
    return None
  try:
    return fit_pred_kde(loc.to_numpy().T, X, Y)
  except np.linalg.LinAlgError:
    # gaussian_kde needs a non-singular covariance, which fails when, for
    # example, every pitch shares one location or they all lie on a line.
    return None


# velo distribution
# @clone_df
def plot_velo(df=None, player=None, velos=None, pitch_type=None, pitch_name=None, min_pitches=2):
  """Plot a one-sided violin of release speeds for a single pitch.

  Either pass `velos` directly, or pass `df` + `player` + exactly one of
  `pitch_type` / `pitch_name`, and the speeds are selected from `df`.

  Args:
    df: frame with `name`, `release_speed` and the chosen pitch column;
      required when `velos` is not given.
    player: player name to select from `df` (mutually exclusive with `velos`).
    velos: release speeds to plot directly. Assumed to be a Series-like
      object with `.median()` (e.g. a polars Series).
    pitch_type / pitch_name: exactly one selects the pitch when using `df`.
    min_pitches: minimum number of speeds needed to draw the violin.

  Returns:
    A plotly `go.Figure`. When fewer than `min_pitches` speeds are available,
    a text note is placed in the middle of the default axis range instead.
  """
  assert (velos is None) != (player is None), 'exactly one of `player` or `velos` must be specified'

  if velos is None and player is not None:
    assert (pitch_type is None) != (pitch_name is None), 'exactly one of `pitch_type` or `pitch_name` must be specified'
    assert df is not None, '`df` must be provided if `velos` not provided'
    # Decide by which argument was supplied, not by truthiness, so a falsy but
    # valid value (e.g. '' or 0) still selects the matching column.
    if pitch_type is not None:
      pitch_col, pitch_val = 'pitch_type', pitch_type
    else:
      pitch_col, pitch_val = 'pitch_name', pitch_name
    # Drop missing speeds so they neither count toward `min_pitches` nor
    # enter the violin.
    velos = (
        df.filter((pl.col('name') == player) & (pl.col(pitch_col) == pitch_val))['release_speed']
        .drop_nulls()
    )

  fig = go.Figure()
  if len(velos) >= min_pitches:
    fig.add_trace(go.Violin(x=velos, side='positive', hoveron='points', points=False, meanline_visible=True, name='Velocity Distribution'))
    # Fixed 50-wide window centred on the median speed.
    median = velos.median()
    x_range = [median - 25, median + 25]
  else:
    fig.add_annotation(
        x=(170 + 125) / 2,
        y=0.3 / 2,
        text=INSUFFICIENT_PITCHES_MSG_MULTI_LINE,
        showarrow=False,
    )
    x_range = [125, 170]
  fig.update_layout(
    xaxis=dict(
        title='Velocity',
        range=x_range,
        scaleratio=2
    ),
    yaxis=dict(
        title='Frequency',
        range=[0, 0.3],
        scaleanchor='x',
        scaleratio=1,
        tickvals=np.linspace(0, 0.3, 3),
        ticktext=np.linspace(0, 0.3, 3),
    ),
    autosize=True,
    modebar_remove=['zoom', 'autoScale', 'resetScale'],
  )
  return fig

# @clone_df
def plot_velo_summary(df, league_df, player, min_pitches=2):
  """Plot per-pitch-type velocity violins for a player against the league.

  Pitch types are taken from `df` and processed least-thrown first. For each
  one, a gray 'NPB' violin of the league's speeds for that pitch type is drawn,
  followed by either the player's own violin or, when the player has fewer
  than `min_pitches` recorded speeds for it, a text note.

  Args:
    df: the player's pitches (`pitch_name`, `release_speed` columns).
    league_df: league pitches with the same columns.
    player: currently unused; kept for call compatibility.
    min_pitches: minimum recorded speeds needed to draw the player's violin
      (defaults to the previously hard-coded 2).

  Returns:
    A plotly `go.Figure`.
  """
  # Only pitches with a recorded speed contribute to counts or violins.
  player_df = df.filter(pl.col('release_speed').is_not_null())
  pitch_counts = player_df['pitch_name'].value_counts().sort('count')
  n_types = len(pitch_counts)

  league_df = league_df.filter(pl.col('release_speed').is_not_null())

  fig = go.Figure()

  # Fall back to a nominal 130-160 span when the player has no recorded speeds.
  min_velo = player_df['release_speed'].min() if len(player_df) else 130
  max_velo = player_df['release_speed'].max() if len(player_df) else 160
  velo_center = (min_velo + max_velo) / 2

  for i, (pitch_name, count) in enumerate(pitch_counts.iter_rows()):
    velos = player_df.filter(pl.col('pitch_name') == pitch_name)['release_speed']
    league_velos = league_df.filter(pl.col('pitch_name') == pitch_name)['release_speed']
    player_rank = n_types - i

    # All NPB violins share one legend group; only the first shows an entry.
    fig.add_trace(go.Violin(
        x=league_velos,
        y=[pitch_name]*len(league_velos),
        line_color='gray',
        side='positive',
        orientation='h',
        meanline_visible=True,
        points=False,
        legendgroup='NPB',
        legendrank=1,
        showlegend=i == 0,
        name='NPB',
    ))
    if count >= min_pitches:
      fig.add_trace(go.Violin(
          x=velos,
          y=[pitch_name]*len(velos),
          side='positive',
          orientation='h',
          meanline_visible=True,
          points=False,
          legendgroup=pitch_name,
          legendrank=player_rank,
          name=pitch_name
      ))
    else:
      # Too few speeds for a violin: place a text note on this pitch's row.
      fig.add_trace(go.Scatter(
          x=[velo_center],
          y=[pitch_name],
          text=[INSUFFICIENT_PITCHES_MSG],
          textposition='top center',
          hovertext=False,
          mode="lines+text",
          legendgroup=pitch_name,
          legendrank=player_rank,
          name=pitch_name,
      ))

  fig.update_xaxes(title='Velocity', range=[min_velo - 2, max_velo + 2])
  fig.update_yaxes(range=[0, n_types - 0.25], visible=False)
  fig.update_layout(
      violingap=0,
      violingroupgap=0,
      legend=dict(orientation='h', y=-0.15, yanchor='top'),
      modebar_remove=['zoom', 'select2d', 'lasso2d', 'pan', 'autoScale'],
      dragmode=False
  )

  return fig


def update_dfs(player, handedness, start_date, end_date, df):
  """Slice `df` for one player and for the league, and compute pitch stats.

  Both slices are restricted to `game_date` between `start_date` and
  `end_date` (inclusive) and to the batter side selected by `handedness`.
  The player slice is further restricted to rows whose `name` equals `player`.

  Args:
    player: value matched against the `name` column.
    handedness: 'Both', 'Right' or 'Left'; filters the `stand` column to
      'R'/'L', 'R', or 'L' respectively.
    start_date, end_date: inclusive bounds compared against `game_date`.
    df: the full pitch-level frame.

  Returns:
    (player_df, league_df, compute_pitch_stats(player_df),
     compute_league_pitch_stats(league_df))

  Raises:
    ValueError: if `handedness` is not one of the three accepted values.
  """
  date_filter = (pl.col('game_date') >= start_date) & (pl.col('game_date') <= end_date)
  if handedness == 'Both':
    handedness_filter = pl.col('stand').is_in(['R', 'L'])
  elif handedness == 'Right':
    handedness_filter = pl.col('stand') == 'R'
  elif handedness == 'Left':
    handedness_filter = pl.col('stand') == 'L'
  else:
    # Without this, the unbound `handedness_filter` surfaces as an opaque
    # UnboundLocalError below.
    raise ValueError(f"handedness must be 'Both', 'Right' or 'Left', got {handedness!r}")

  league_filter = handedness_filter & date_filter
  _df = df.filter((pl.col('name') == player) & league_filter)
  _league_df = df.filter(league_filter)

  return (
    _df,
    _league_df,
    compute_pitch_stats(_df),
    compute_league_pitch_stats(_league_df),
  )

def create_set_download_file_fn(filepath):
  """Build a callback that writes a DataFrame as CSV to a fixed path.

  The returned function takes a frame (anything with `write_csv(path)`),
  writes it to `filepath`, and returns `filepath`.
  """
  target_path = filepath

  def set_download_file(df):
    df.write_csv(target_path)
    return target_path

  return set_download_file

def preview_df(df):
  """Return `df.head()` (its default number of leading rows) for a quick preview."""
  leading_rows = df.head()
  return leading_rows

# @clone_df
def plot_usage(df, player):
  """Pie chart of the pitch mix in `df`, one slice per `pitch_name` value.

  Slice labels show the percentage; the hover text names `player` and gives
  the pitch, its share, and its pitch count.
  """
  usage_fig = px.pie(df.select('pitch_name'), names='pitch_name')
  hover_header = f'<b>{player}</b><br>'
  hover_body = 'threw a <b>%{label}</b><br><b>%{percent:.1%}</b> of the time (<b>%{value}</b> pitches)'
  usage_fig.update_traces(texttemplate='%{percent:.1%}', hovertemplate=hover_header + hover_body)
  return usage_fig

# @clone_df
def plot_pitch_cards(df, league_df, pitch_stats, handedness):
  """Build the gradio updates for the per-pitch-type cards.

  One card is filled per pitch type in `df`, most-thrown first. Each card gets
  a title, its Whiff%/CSW% row from `pitch_stats`, a velocity violin, and a
  location map compared against `league_df`. Remaining slots up to
  `max_pitch_types` are hidden. Card rows beyond those needed are hidden,
  padding to at least `MAX_ROWS`.

  Returns:
    A flat list of `gr.update`s concatenated as rows, groups, names, infos,
    velos, locs. The order presumably matches the output components this
    function is wired to.
  """
  pitch_counts = df['pitch_name'].value_counts().sort('count', descending=True)
  # Only `max_pitch_types` slots exist per component kind (see the padding
  # below). Extra pitch types would produce more updates than there are
  # outputs, so keep the most-thrown ones.
  pitch_counts = pitch_counts.head(max_pitch_types)

  n_visible_rows = ceil(len(pitch_counts) / LOCS_PER_ROW)
  pitch_rows = [
      gr.update(visible=row < n_visible_rows)
      for row in range(max(n_visible_rows, MAX_ROWS))
  ]

  pitch_groups = []
  pitch_names = []
  pitch_infos = []
  pitch_velos = []
  pitch_locs = []

  for pitch_name, _count in pitch_counts.iter_rows():
    # Filter once per pitch type; both plots reuse this slice.
    pitch_df = df.filter(pl.col('pitch_name') == pitch_name)
    pitch_groups.append(gr.update(visible=True))
    pitch_names.append(gr.update(value=f'### {pitch_name}', visible=True))
    pitch_infos.append(gr.update(
        value=pitch_stats.filter(pl.col('pitch_name') == pitch_name).select(['Whiff%', 'CSW%']),
        visible=True
    ))
    pitch_velos.append(gr.update(
        value=plot_velo(velos=pitch_df.filter(pl.col('release_speed').is_not_null())['release_speed']),
        visible=True
    ))
    pitch_locs.append(gr.update(
        value=plot_loc(
            df=pitch_df,
            handedness=handedness,
            league_df=league_df.filter(pl.col('pitch_name') == pitch_name)
        ),
        label='Pitch location',
        visible=True
    ))

  # Hide and clear the unused card slots.
  for _ in range(max_pitch_types - len(pitch_names)):
    pitch_groups.append(gr.update(visible=False))
    pitch_names.append(gr.update(value=None, visible=False))
    pitch_infos.append(gr.update(value=None, visible=False))
    pitch_velos.append(gr.update(value=None, visible=False))
    pitch_locs.append(gr.update(value=None, visible=False))

  return pitch_rows + pitch_groups + pitch_names + pitch_infos + pitch_velos + pitch_locs

# @clone_df
def update_velo_stats(pitch_stats, league_pitch_stats):
  """Compare a player's average velocity per pitch with the league average.

  Returns a frame with columns Pitch, Avg. Velo and League Avg. Velo. It keeps
  only pitches present in both inputs and orders rows by the player's `Count`,
  highest first.
  """
  player_velo = pitch_stats.select(
      pl.col('pitch_name').alias('Pitch'),
      pl.col('Velocity').alias('Avg. Velo'),
      pl.col('Count'),
  )
  league_velo = league_pitch_stats.select(
      pl.col('pitch_name').alias('Pitch'),
      pl.col('Velocity').alias('League Avg. Velo'),
  )
  combined = player_velo.join(league_velo, on='Pitch', how='inner')
  # Sort by usage, then drop the helper column from the displayed table.
  by_usage = combined.sort('Count', descending=True)
  return by_usage.drop('Count')