import React, { useState, useEffect, useRef } from 'react';
import * as d3 from 'd3';
import { useTheme } from '../context/themeContext';

import MODELS from '../utils/models';
import DEVICES from '../utils/devices';

// Weight precisions compared by the calculator. These literals are used as the
// keys of the color table and of the bytes-per-parameter tables below.
type Precision = '32-bit' | '16-bit' | '8-bit' | '4-bit';

// Props for ModelSizeBarChart.
interface ModelSizeBarChartProps {
  modelSize: number; // in GB; drawn as the filled, precision-colored segment of the bar
  largestModelSize: number; // upper bound of the bar's x-scale and the "Out of Memory" threshold; callers in this file pass the device memory
  modelPrecision: Precision; // selects the segment/outline color via chooseColor
  deviceMemorySet: boolean; // when true, the unused remainder of the bar is drawn as an outlined rectangle
  activationMemorySize?: number; // optional second segment drawn right after the model segment; the component defaults it to 0
}

// Props for InferenceRuntimeLineChart.
interface InferenceRuntimeLineChartProps {
  availableMemory: AvailableMemory; // in GB, one budget per precision; callers in this file pass device memory minus the model footprint
  memoryPerInput: number; // in GB; divisor used when converting a memory budget into a batch size
}

// One point of a plotted curve in InferenceRuntimeLineChart.
interface LineChartData {
  seqLength: number; // x-coordinate (sequence length)
  batchSize: number; // y-coordinate; computed as memory / (seqLength * memoryPerInput)
}

// Memory budget per precision. The line chart iterates these entries and draws
// one curve per key, colored by precision.
interface AvailableMemory {
  '4-bit': number;
  '8-bit': number;
  '16-bit': number;
  '32-bit': number;
}

/**
 * Looks up the chart color associated with a precision.
 *
 * @param precision - Precision label to resolve.
 * @returns The hex color registered for the label, or `'gray'` when the
 *   lookup yields nothing usable.
 */
function chooseColor(precision: Precision) {
  const palette = {
    '32-bit': '#e45f5b',
    '16-bit': '#ffc068',
    '8-bit': '#71cce9',
    '4-bit': '#383d95',
  };

  const hex = palette[precision];
  return hex ? hex : 'gray';
}

/**
 * Weight-only footprint of a model: parameter count times bytes per parameter.
 *
 * The form labels `modelParams` as "in billions", which is why the product is
 * treated as GB. The result is also written to the console on every call.
 *
 * @param modelParams - Parameter count (billions, per the input label).
 * @param precision - Precision the weights are stored in.
 * @returns The footprint (GB).
 */
function calculateStandardMemory(modelParams: number, precision: Precision): number {
  // Bytes occupied by a single parameter at each precision.
  const bytesPerParam = { '32-bit': 4, '16-bit': 2, '8-bit': 1, '4-bit': 0.5 };

  const footprint = modelParams * bytesPerParam[precision]; // GB
  const label = precision.toUpperCase();
  console.log(`[Standard] ${label} Memory:`, footprint);

  return footprint;
}

/**
 * Footprint used by the prefill chunking calculator: model weights plus one
 * chunk's activation buffer plus the per-input term.
 *
 * The three components, all in GB, are:
 *  - weights:     `modelParams * bytesPerParam[precision]`
 *  - activations: `maxChunkSize * 2 * max(2 * intermediateSize, 4 * hiddenSize) / 1e9`
 *  - per input:   `4 * hiddenSize * numLayers / 1e9`
 *
 * Only the weights term depends on `precision`. Each component is also
 * written to the console on every call.
 *
 * @param modelParams - Parameter count (billions, per the input label).
 * @param hiddenSize - Model hidden size.
 * @param numLayers - Number of layers.
 * @param intermediateSize - Model intermediate size.
 * @param precision - Precision the weights are stored in.
 * @param maxChunkSize - Chunk size used for the activation term. Defaults to
 *   512, the value that used to be hard-coded, so existing callers are unaffected.
 * @returns The combined footprint (GB).
 */
function calculatePrefillMemory(
  modelParams: number,
  hiddenSize: number,
  numLayers: number,
  intermediateSize: number,
  precision: Precision,
  maxChunkSize = 512
): number {
  // Bytes occupied by a single parameter at each precision.
  const precisionFactor = {
    '32-bit': 4,
    '16-bit': 2,
    '8-bit': 1,
    '4-bit': 0.5,
  };

  // Calculate each memory component
  const modelMemorySize = modelParams * precisionFactor[precision]; // GB
  const activationMemorySize = (maxChunkSize * 2 * Math.max(2 * intermediateSize, 4 * hiddenSize)) / 1_000_000_000; // GB
  const memoryPerInput = (4 * hiddenSize * numLayers) / 1_000_000_000; // GB

  // Combine all components
  const totalMemory = modelMemorySize + activationMemorySize + memoryPerInput;

  console.log(`[Prefill] ${precision.toUpperCase()} Memory:`, totalMemory);
  console.log(`[Prefill] Activation Memory:`, activationMemorySize);
  console.log(`[Prefill] Memory Per Input:`, memoryPerInput);

  return totalMemory;
}

/**
 * Horizontal bar showing how much of a memory budget a model occupies
 * (shared by both the standard and prefill chunking calculators).
 *
 * The bar is scaled so that `largestModelSize` spans the full width. When
 * `modelSize + activationMemorySize` exceeds that budget, a dashed
 * "Out of Memory" placeholder is drawn instead. Otherwise the model segment,
 * the optional activation segment and (when `deviceMemorySet` is true) an
 * outline for the unused remainder are drawn left to right.
 *
 * Nothing is drawn unless both `modelSize` and `largestModelSize` are positive.
 */
function ModelSizeBarChart({
  modelSize,
  largestModelSize,
  modelPrecision,
  deviceMemorySet,
  activationMemorySize = 0,
}: ModelSizeBarChartProps) {
  const { theme } = useTheme();
  const chartRef = useRef<SVGSVGElement>(null);

  const width = 600;
  const height = 50;

  useEffect(() => {
    const svg = d3.select(chartRef.current);

    // Clear unconditionally. Previously the old bar was only removed when the
    // new inputs were valid, so an invalid update (e.g. a negative size) left
    // the previous, now-stale bar on screen.
    svg.selectAll('*').remove();

    if (!(modelSize > 0 && largestModelSize > 0)) {
      return;
    }

    const foreground = theme === 'dark' ? '#f9fafb' : '#181f26';
    const precisionColor = chooseColor(modelPrecision);
    const usedMemory = modelSize + activationMemorySize;

    svg
      .attr('width', width)
      .attr('height', height)
      .style('animation', 'fadeIn 0.3s ease-in-out') // Inline animation
      .style('transition', 'transform 0.3s ease-in-out') // Hover effect
      .on('mouseover', function () {
        d3.select(this).style('transform', 'scale(1.02)');
      })
      .on('mouseout', function () {
        d3.select(this).style('transform', 'scale(1)');
      });

    const xScale = d3.scaleLinear().domain([0, largestModelSize]).range([0, width]);

    if (usedMemory > largestModelSize) {
      // Over budget: dashed frame with a centered label instead of segments.
      svg
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', width)
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', foreground)
        .style('stroke-dasharray', '4, 4')
        .style('stroke-width', '2px');
      svg
        .append('text')
        .attr('x', width / 2)
        .attr('y', height / 2)
        .attr('text-anchor', 'middle')
        // `dominant-baseline` is the baseline attribute defined for <text>;
        // `alignment-baseline` targets inline children such as <tspan>.
        .attr('dominant-baseline', 'middle')
        .attr('fill', foreground)
        .text('Out of Memory');
      return;
    }

    // Model segment.
    svg
      .append('rect')
      .attr('x', 0)
      .attr('y', 0)
      .attr('width', xScale(modelSize))
      .attr('height', height)
      .attr('fill', precisionColor);

    // Activation segment, placed directly after the model segment.
    if (activationMemorySize > 0) {
      svg
        .append('rect')
        .attr('x', xScale(modelSize))
        .attr('y', 0)
        .attr('width', xScale(activationMemorySize))
        .attr('height', height)
        .attr('fill', '#a4b8e0');
    }

    // Outline for whatever part of the budget is still free.
    if (deviceMemorySet) {
      svg
        .append('rect')
        .attr('x', xScale(usedMemory))
        .attr('y', 0)
        .attr('width', xScale(largestModelSize - usedMemory))
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', precisionColor)
        .style('stroke-width', '2px');
    }
  }, [modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize, theme]);

  return <svg ref={chartRef}></svg>;
}

/**
 * Line chart of the largest batch size that fits at each sequence length, one
 * curve per precision (shared by both the standard and prefill chunking
 * calculators).
 *
 * For every precision the curve is `memory / (seqLength * memoryPerInput)`,
 * sampled at integer sequence lengths and restricted to points where the batch
 * size is greater than 1 and at most 128. The plot area supports zoom and pan,
 * and hovering a curve shows the sequence length / batch size under the
 * pointer in a tooltip.
 *
 * Nothing is drawn unless `memoryPerInput` is positive and at least one
 * precision has a positive memory budget.
 */
function InferenceRuntimeLineChart({ availableMemory, memoryPerInput }: InferenceRuntimeLineChartProps) {
  const { theme } = useTheme();
  const chartRef = useRef<SVGSVGElement>(null);
  const tooltipRef = useRef<HTMLDivElement>(null); // Ref for the tooltip
  const maxSeqLength = 4096;
  const maxBatchSize = 128;

  // Unpack to primitives so the effect below is keyed on the values rather
  // than on object identity. Callers pass a fresh object literal on every
  // render, which previously forced a full redraw (and would reset any
  // zoom/pan state) on each parent render.
  const memory4Bit = availableMemory['4-bit'];
  const memory8Bit = availableMemory['8-bit'];
  const memory16Bit = availableMemory['16-bit'];
  const memory32Bit = availableMemory['32-bit'];

  useEffect(() => {
    const svg = d3.select(chartRef.current);
    const tooltip = d3.select(tooltipRef.current);

    // Always start from an empty canvas so an update with unusable inputs
    // does not leave the previous chart (or a visible tooltip) behind.
    svg.selectAll('*').remove();
    tooltip.style('opacity', 0);

    // Same order in which callers list the keys; it determines draw order.
    const memoryByPrecision: [Precision, number][] = [
      ['4-bit', memory4Bit],
      ['8-bit', memory8Bit],
      ['16-bit', memory16Bit],
      ['32-bit', memory32Bit],
    ];

    if (!(memoryPerInput > 0 && memoryByPrecision.some(([, memory]) => memory > 0))) {
      return;
    }

    const margin = { top: 20, right: 20, bottom: 50, left: 50 };
    const width = 600 - margin.left - margin.right;
    const height = 400 - margin.top - margin.bottom;
    const foreground = theme === 'dark' ? '#f9fafb' : '#181f26';
    const clipId = 'inference-runtime-plot-clip';

    const xScale = d3.scaleLinear().domain([0, maxSeqLength]).range([0, width]);
    const yScale = d3.scaleLinear().domain([0, maxBatchSize]).range([height, 0]);

    const xAxis = d3.axisBottom(xScale);
    const yAxis = d3.axisLeft(yScale);

    svg
      .attr('width', width + margin.left + margin.right)
      .attr('height', height + margin.top + margin.bottom);

    // Clip rectangle that keeps zoomed curves inside the plot area.
    svg
      .append('defs')
      .append('clipPath')
      .attr('id', clipId)
      .append('rect')
      .attr('width', width)
      .attr('height', height);

    const xAxisGroup = svg
      .append('g')
      .attr('class', 'x-axis')
      .attr('transform', `translate(${margin.left}, ${height + margin.top})`)
      .call(xAxis);

    const yAxisGroup = svg
      .append('g')
      .attr('class', 'y-axis')
      .attr('transform', `translate(${margin.left}, ${margin.top})`)
      .call(yAxis);

    svg.append('text')
      .attr('transform', `translate(${width / 2 + margin.left}, ${height + margin.top + 40})`)
      .style('text-anchor', 'middle')
      .attr('fill', foreground)
      .text('Sequence Length');

    svg.append('text')
      .attr('transform', `rotate(-90)`)
      .attr('y', 0)
      .attr('x', 0 - height / 2 - margin.top)
      .attr('dy', '1em')
      .style('text-anchor', 'middle')
      .attr('fill', foreground)
      .text('Batch Size');

    // Adding legend for precisions (colors come from chooseColor so the
    // legend cannot drift from the curves).
    const legendOrder: Precision[] = ['32-bit', '16-bit', '8-bit', '4-bit'];

    const legend = svg
      .append('g')
      .attr('class', 'legend')
      .attr('transform', `translate(${width - 20}, 20)`);

    legendOrder.forEach((precision, index) => {
      const legendItem = legend.append('g').attr('transform', `translate(0, ${index * 30})`);

      legendItem.append('rect')
        .attr('x', 10)
        .attr('y', 10)
        .attr('width', 10)
        .attr('height', 10)
        .style('fill', chooseColor(precision));

      legendItem.append('text')
        .attr('x', 30)
        .attr('y', 16)
        .text(precision)
        .style('font-size', '16px')
        .style('fill', foreground)
        // `dominant-baseline` is the baseline attribute defined for <text>.
        .attr('dominant-baseline', 'middle');
    });

    legend.append('rect')
      .attr('class', 'legend-box')
      .attr('width', 80)
      .attr('height', legendOrder.length * 30)
      .style('fill', 'none')
      .style('stroke-width', '1px')
      .style('stroke', foreground);

    tooltip
      .style('position', 'absolute')
      .style('padding', '8px')
      .style('border-radius', '4px')
      .style('pointer-events', 'none')
      .style('transition', 'opacity 0.3s ease-in-out, transform 0.3s ease-in-out')
      .style('background-color', 'rgba(0, 0, 0, 0.75)')
      .style('color', 'white')
      .style('font-size', '14px');

    // Plot group: owns the zoom behavior and contains the curves. It is
    // appended last so the curves stay above the legend, as before.
    const plot = svg
      .append('g')
      .attr('transform', `translate(${margin.left}, ${margin.top})`);

    // A <g> has no geometry of its own and only sees pointer events that hit
    // one of its rendered children. The zoom behavior used to be attached to
    // an empty <g>, so it never fired; this invisible rectangle gives the
    // group a hit area covering the whole plot.
    plot.append('rect')
      .attr('width', width)
      .attr('height', height)
      .attr('fill', 'none')
      .style('pointer-events', 'all');

    const lineLayer = plot.append('g').attr('clip-path', `url(#${clipId})`);

    const zoom = d3.zoom<SVGGElement, unknown>()
      .scaleExtent([0.5, 10])
      .extent([[0, 0], [width, height]])
      .translateExtent([[-width, -height], [2 * width, 2 * height]])
      .on('zoom', (event: d3.D3ZoomEvent<SVGGElement, unknown>) => {
        const { transform } = event;
        xAxisGroup.call(xAxis.scale(transform.rescaleX(xScale)));
        yAxisGroup.call(yAxis.scale(transform.rescaleY(yScale)));
        // Only the data curves are transformed; selecting every <path> in the
        // SVG would also move the axis domain lines.
        lineLayer.selectAll('path').attr('transform', transform.toString());
      });

    plot.call(zoom);

    // The generator does not depend on the precision, so build it once.
    const line = d3.line<LineChartData>()
      .x((d) => xScale(d.seqLength))
      .y((d) => yScale(d.batchSize))
      .curve(d3.curveBasis);

    for (const [precision, memory] of memoryByPrecision) {
      const sequenceLengths = d3.range(1, maxSeqLength, 1)
        .map((seqLength) => ({
          seqLength,
          batchSize: memory / (seqLength * memoryPerInput),
        }))
        .filter((d) => d.batchSize <= maxBatchSize && d.batchSize > 1 && d.seqLength > 1);

      lineLayer.append('path')
        .datum(sequenceLengths)
        .attr('fill', 'none')
        .attr('stroke', chooseColor(precision))
        .attr('stroke-width', 4)
        // Keep the stroke at 4px regardless of the zoom scale.
        .attr('vector-effect', 'non-scaling-stroke')
        .attr('d', line)
        .on('mouseover', () => {
          tooltip.style('opacity', 1)
            .style('transform', 'translateY(-10px)');
        })
        .on('mousemove', (event: MouseEvent) => {
          // d3.pointer reports coordinates in the path's own coordinate
          // system, so the unzoomed scales invert them correctly.
          const [x, y] = d3.pointer(event);
          const xValue = xScale.invert(x);
          const yValue = yScale.invert(y);
          tooltip.html(`Sequence Length: ${xValue.toFixed(0)}<br/>Batch Size: ${yValue.toFixed(0)}`)
            .style('left', event.pageX + 10 + 'px')
            .style('top', event.pageY + 10 + 'px');
        })
        .on('mouseout', () => {
          tooltip.style('opacity', 0);
        });
    }
  }, [memory4Bit, memory8Bit, memory16Bit, memory32Bit, memoryPerInput, theme]);

  return (
    <>
      <div id="tooltip" ref={tooltipRef}></div>
      <svg ref={chartRef} width={600} height={400} />
    </>
  );
}

/**
 * Prefill chunking view: one footprint bar per precision plus the batch size /
 * sequence length chart, both based on `calculatePrefillMemory`.
 *
 * Renders nothing until every input is a non-zero number.
 */
function PrefillChunkingCalculator({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
  intermediateSize,
}: {
  deviceMemory: number;
  modelParams: number;
  hiddenSize: number;
  numLayers: number;
  intermediateSize: number;
}) {
  if (!deviceMemory || !modelParams || !hiddenSize || !numLayers || !intermediateSize) {
    return null;
  }

  // Activation term for one chunk. 512 mirrors the default chunk size used by
  // calculatePrefillMemory, so this is the activation share already contained
  // in each total below.
  const activationMemorySize = (512 * 2 * (Math.max(2 * intermediateSize, 4 * hiddenSize))) / 1_000_000_000;

  // Compute each precision's total once and reuse it for the bar, the label
  // and the available-memory budget (it used to be recomputed for each use).
  // Key order is kept as before because the line chart draws in this order.
  const totalMemoryByPrecision: Record<Precision, number> = {
    '4-bit': calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, '4-bit'),
    '8-bit': calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, '8-bit'),
    '16-bit': calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, '16-bit'),
    '32-bit': calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, '32-bit'),
  };

  const availableMemory: AvailableMemory = {
    '4-bit': deviceMemory - totalMemoryByPrecision['4-bit'],
    '8-bit': deviceMemory - totalMemoryByPrecision['8-bit'],
    '16-bit': deviceMemory - totalMemoryByPrecision['16-bit'],
    '32-bit': deviceMemory - totalMemoryByPrecision['32-bit'],
  };

  const precisions: Precision[] = ['32-bit', '16-bit', '8-bit', '4-bit'];

  return (
    <>
      {/* Model Footprint with Prefill Chunking */}
      <div className="chart">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-2xl text-center mb-4">Model Footprint with Prefill Chunking</div>
        <div className="space-y-8">
          {precisions.map((precision) => {
            const totalMemory = totalMemoryByPrecision[precision];
            return (
              <div key={precision} style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="chart-row">
                <div className="chart-row-title">{precision.toUpperCase()}</div>
                {/*
                  The total already includes the activation term and the chart
                  adds activationMemorySize on top of modelSize, so the
                  activation share is subtracted here. Passing the full total
                  counted it twice, making the bar and its out-of-memory check
                  disagree with the number printed next to it.
                */}
                <ModelSizeBarChart
                  modelSize={totalMemory - activationMemorySize}
                  largestModelSize={deviceMemory}
                  modelPrecision={precision}
                  deviceMemorySet={deviceMemory > 0}
                  activationMemorySize={activationMemorySize}
                />
                <div className="chart-row-size ml-8">
                  {totalMemory.toFixed(2)} / {deviceMemory} GB
                </div>
              </div>
            );
          })}
        </div>
      </div>

      {/* Inference Runtime with Prefill Chunking */}
      <div className="chart">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-2xl text-center mb-4">
          Maximum Batch Size / Sequence Length with Prefill Chunking
        </div>
        <InferenceRuntimeLineChart
          availableMemory={availableMemory}
          memoryPerInput={(4 * hiddenSize * numLayers) / 1_000_000_000}
        />
      </div>
    </>
  );
}

/**
 * Standard view: one weight-only footprint bar per precision plus the batch
 * size / sequence length chart, both based on `calculateStandardMemory`.
 *
 * Renders nothing until every input is a non-zero number.
 */
function StandardCalculator({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
}: {
  deviceMemory: number;
  modelParams: number;
  hiddenSize: number;
  numLayers: number;
}) {
  if (!deviceMemory || !modelParams || !hiddenSize || !numLayers) {
    return null;
  }

  // Compute each precision's footprint once and reuse it for the bar, the
  // label and the available-memory budget; it used to be recomputed (and
  // logged) three times per precision on every render. Key order is kept as
  // before because the line chart draws in this order.
  const footprintByPrecision: Record<Precision, number> = {
    '4-bit': calculateStandardMemory(modelParams, '4-bit'),
    '8-bit': calculateStandardMemory(modelParams, '8-bit'),
    '16-bit': calculateStandardMemory(modelParams, '16-bit'),
    '32-bit': calculateStandardMemory(modelParams, '32-bit'),
  };

  const availableMemory: AvailableMemory = {
    '4-bit': deviceMemory - footprintByPrecision['4-bit'],
    '8-bit': deviceMemory - footprintByPrecision['8-bit'],
    '16-bit': deviceMemory - footprintByPrecision['16-bit'],
    '32-bit': deviceMemory - footprintByPrecision['32-bit'],
  };

  const precisions: Precision[] = ['32-bit', '16-bit', '8-bit', '4-bit'];

  return (
    <>
      {/* Model Footprint */}
      <div className="chart mb-8">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-2xl text-center mb-4">Model Footprint</div>
        <div className="space-y-8">
          {precisions.map((precision) => {
            const footprint = footprintByPrecision[precision];
            return (
              <div key={precision} style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="chart-row">
                <div className="chart-row-title">{precision.toUpperCase()}</div>
                <ModelSizeBarChart
                  modelSize={footprint}
                  largestModelSize={deviceMemory}
                  modelPrecision={precision}
                  deviceMemorySet={deviceMemory > 0}
                />
                <div className="chart-row-size ml-8">
                  {footprint.toFixed(2)} / {deviceMemory} GB
                </div>
              </div>
            );
          })}
        </div>
      </div>

      {/* Maximum Batch Size / Sequence Length */}
      <div className="chart">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-2xl text-center mb-4">
          Maximum Batch Size / Sequence Length
        </div>
        <InferenceRuntimeLineChart
          availableMemory={availableMemory}
          memoryPerInput={(4 * hiddenSize * numLayers) / 1_000_000_000}
        />
      </div>
    </>
  );
}

/**
 * Converts the raw string of a form control into state for a numeric field.
 *
 * @param raw - The control's current `value`.
 * @returns `null` for an empty or non-numeric string, otherwise the parsed
 *   number. Unlike a bare `Number(raw)`, an empty field is not turned into 0.
 */
function parseNumericField(raw: string): number | null {
  if (raw.trim() === '') {
    return null;
  }
  const parsed = Number(raw);
  return Number.isFinite(parsed) ? parsed : null;
}

/**
 * Main Calculator Page.
 *
 * Holds the model and device inputs (chosen from the MODELS / DEVICES lists or
 * typed in by hand) and renders either the standard or the prefill chunking
 * calculator for them. Inputs that are not set yet are passed down as 0, which
 * both calculators treat as "missing" and render nothing for.
 */
const Calculator = () => {
  const [modelParams, setModelParams] = useState<number | null>(null);
  const [hiddenSize, setHiddenSize] = useState<number | null>(null);
  const [numLayers, setNumLayers] = useState<number | null>(null);
  const [intermediateSize, setIntermediateSize] = useState<number | null>(null);
  const [deviceMemory, setDeviceMemory] = useState<number | null>(null);
  const [isPrefillChunking, setIsPrefillChunking] = useState<boolean>(false);
  const [modelSelectionTab, setModelSelectionTab] = useState<boolean>(true);
  const [deviceSelectionTab, setDeviceSelectionTab] = useState<boolean>(true);

  return (
    <div className="flex flex-col items-center justify-center min-h-screen px-4">
      {/* Toggle Between Standard and Prefill Chunking */}
      <div className="mb-4 flex space-x-4">
        <button
          style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
          className={`${!isPrefillChunking ? 'calculator-input-tab-active' : 'calculator-input-tab'}`}
          onClick={() => setIsPrefillChunking(false)}
        >
          Standard Calculator
        </button>
        <button
          style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
          className={`${isPrefillChunking ? 'calculator-input-tab-active' : 'calculator-input-tab'}`}
          onClick={() => setIsPrefillChunking(true)}
        >
          Calculator with Prefill Chunking
        </button>
      </div>

      {/* Model and Device Selection */}
      <div className="w-full max-w-4xl">
        <div style={{ animation: 'fadeIn 0.3s ease-in-out' }} className="text-4xl mb-4 text-center">Model Memory Calculator</div>
        <div className="mb-6 text-center">
          Use our Model Memory Calculator to help you estimate the memory footprint of your model
          and the maximum batch size/sequence length combination you can run on your device.
        </div>

        <div className="grid grid-cols-1 sm:grid-cols-2 gap-4 mb-6">
          {/* Model Selection */}
          <div className="calculator-input-box">
            <div className="text-2xl calculator-input-title">Model</div>
            <div className="calculator-input-content">
              <div className="mb-2">
                <button
                  style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
                  className={`${modelSelectionTab ? 'calculator-input-tab-active' : 'calculator-input-tab'}`}
                  onClick={() => setModelSelectionTab(true)}
                >
                  Model Selection
                </button>
                <button
                  style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
                  className={`${modelSelectionTab ? 'calculator-input-tab' : 'calculator-input-tab-active'}`}
                  onClick={() => setModelSelectionTab(false)}
                >
                  Custom Model
                </button>
              </div>
              <div>
                {modelSelectionTab ? (
                  <>
                    <label htmlFor="model">Select a Model</label>
                    <select
                      id="model"
                      className="calculator-select"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      onChange={(e) => {
                        // Options are identified by name (also their React
                        // key). Matching on the parameter count picked the
                        // first entry whenever two models had the same size.
                        const selectedModel = MODELS.find(
                          (model) => model.name === e.target.value
                        );
                        if (selectedModel) {
                          setModelParams(selectedModel.params);
                          setHiddenSize(selectedModel.hidden_size);
                          setNumLayers(selectedModel.num_hidden_layers);
                          setIntermediateSize(selectedModel.intermediate_size);
                        } else {
                          // "None selected": drop the previous model instead
                          // of leaving its values (and charts) in place.
                          setModelParams(null);
                          setHiddenSize(null);
                          setNumLayers(null);
                          setIntermediateSize(null);
                        }
                      }}
                    >
                      <option value="">None selected</option>
                      {MODELS.map((model) => (
                        <option key={model.name} value={model.name}>
                          {model.name}
                        </option>
                      ))}
                    </select>
                  </>
                ) : (
                  <>
                    {/*
                      The numeric fields below use `?? ''` rather than `|| ''`
                      so a typed 0 stays visible; with `||` the field was wiped
                      as soon as "0" was entered, which made values such as
                      0.5 impossible to type.
                    */}
                    <label htmlFor="modelParams">Model Parameters (in billions)</label>
                    <input
                      type="number"
                      id="modelParams"
                      className="calculator-input mb-2"
                      placeholder="e.g. 7 (for LLaMA-7B)"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      value={modelParams ?? ''}
                      min={0}
                      onChange={(e) => setModelParams(parseNumericField(e.target.value))}
                    />
                    <label htmlFor="hiddenSize">Hidden Size</label>
                    <input
                      type="number"
                      id="hiddenSize"
                      className="calculator-input mb-2"
                      placeholder="e.g. 4096 (for LLaMA-7B)"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      value={hiddenSize ?? ''}
                      min={1}
                      onChange={(e) => setHiddenSize(parseNumericField(e.target.value))}
                    />
                    <label htmlFor="numLayers">Number of Layers</label>
                    <input
                      type="number"
                      id="numLayers"
                      className="calculator-input"
                      placeholder="e.g. 32 (for LLaMA-7B)"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      value={numLayers ?? ''}
                      min={1}
                      onChange={(e) => setNumLayers(parseNumericField(e.target.value))}
                    />
                    {isPrefillChunking && (
                      <>
                        <label htmlFor="intermediateSize">Intermediate Size</label>
                        <input
                          type="number"
                          id="intermediateSize"
                          className="calculator-input"
                          placeholder="e.g. 11008 (for LLaMA-7B)"
                          style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                          value={intermediateSize ?? ''}
                          min={1}
                          onChange={(e) => setIntermediateSize(parseNumericField(e.target.value))}
                        />
                      </>
                    )}
                  </>
                )}
              </div>
            </div>
          </div>

          {/* Device Selection */}
          <div className="calculator-input-box">
            <div className="text-2xl calculator-input-title">Device</div>
            <div className="calculator-input-content">
              <div className="mb-2">
                <button
                  style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
                  className={`${deviceSelectionTab ? 'calculator-input-tab-active' : 'calculator-input-tab'}`}
                  onClick={() => {
                    setDeviceSelectionTab(true);
                    setDeviceMemory(null);
                  }}
                >
                  Device Selection
                </button>
                <button
                  style={{ transition: 'background-color 0.3s ease-in-out, color 0.3s ease-in-out' }}
                  className={`${deviceSelectionTab ? 'calculator-input-tab' : 'calculator-input-tab-active'}`}
                  onClick={() => {
                    setDeviceSelectionTab(false);
                    setDeviceMemory(null);
                  }}
                >
                  Custom Device
                </button>
              </div>
              <div>
                {deviceSelectionTab ? (
                  <>
                    <label htmlFor="device">Select a Device</label>
                    <select
                      id="device"
                      className="calculator-select"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      onChange={(e) => setDeviceMemory(parseNumericField(e.target.value))}
                    >
                      <option value="">None selected</option>
                      {DEVICES.map((device) => (
                        <option key={device.name} value={device.size}>
                          {device.name}
                        </option>
                      ))}
                    </select>
                  </>
                ) : (
                  <>
                    <label htmlFor="deviceMemory">Device RAM (in GB)</label>
                    <input
                      type="number"
                      id="deviceMemory"
                      className="calculator-input"
                      placeholder="e.g. 24"
                      style={{ transition: 'border 0.3s ease-in-out, box-shadow 0.3s ease-in-out' }}
                      value={deviceMemory ?? ''}
                      min={0}
                      onChange={(e) => setDeviceMemory(parseNumericField(e.target.value))}
                    />
                  </>
                )}
              </div>
            </div>
          </div>
        </div>

        {/* Render Appropriate Calculator Based on Toggle.
            Unset inputs are passed as 0 (instead of asserting non-null), which
            both calculators already treat as "nothing to render". */}
        {isPrefillChunking ? (
          <PrefillChunkingCalculator
            deviceMemory={deviceMemory ?? 0}
            modelParams={modelParams ?? 0}
            hiddenSize={hiddenSize ?? 0}
            numLayers={numLayers ?? 0}
            intermediateSize={intermediateSize ?? 0}
          />
        ) : (
          <StandardCalculator
            deviceMemory={deviceMemory ?? 0}
            modelParams={modelParams ?? 0}
            hiddenSize={hiddenSize ?? 0}
            numLayers={numLayers ?? 0}
          />
        )}
      </div>
    </div>
  );
};

export default Calculator;