import os
import kaggle
import tempfile
import requests
import multiprocessing

import pandas as pd

from bs4 import BeautifulSoup
from concurrent.futures import ThreadPoolExecutor

def _generate_sources() -> pd.DataFrame:
	""" Generate a dataset containing urls to retrieve data from"""
	dataset = pd.DataFrame({'type': [], 'name': [], 'url': []})
	with tempfile.TemporaryDirectory() as temp_dir:
		kaggle.api.dataset_download_files('rohanrao/formula-1-world-championship-1950-2020', path=temp_dir, unzip=True)
		df = pd.read_csv(temp_dir + '/circuits.csv')
		
		# remove all columns except 'name' and 'url'
		df = df[['name', 'url']]
		df['type'] = 'circuit'
		dataset = pd.concat([dataset, df], ignore_index=True)

		# Drivers
		df = pd.read_csv(temp_dir + '/drivers.csv')

		# remove all columns except 'forename', 'surname' and 'url'
		df = df[['forename', 'surname', 'url']]

		# Join 'forename' and 'surname' columns
		df['name'] = df['forename'] + ' ' + df['surname']

		df = df[['name', 'url']]
		df['type'] = 'driver'

		dataset = pd.concat([dataset, df], ignore_index=True)

		# Constructors
		df = pd.read_csv(temp_dir + '/constructors.csv')

		# Remove broken links
		df = df[(df['url'] != 'http://en.wikipedia.org/wiki/Turner_(constructor)') & (df['url'] != 'http://en.wikipedia.org/wiki/Hall_(constructor)')]

		# remove all columns except 'name' and 'url'
		df = df[['name', 'url']]
		df['type'] = 'constructor'

		dataset = pd.concat([dataset, df], ignore_index=True)

		# Races
		df = pd.read_csv(temp_dir + '/races.csv')

		# Build a unique race name from 'name', 'year' and 'round', then keep only 'name' and 'url'
		df['name'] = df['name'] + " " + df['year'].astype(str) + "-" + df['round'].astype(str)
		df = df[['name', 'url']]
		df['type'] = 'race'

		dataset = pd.concat([dataset, df], ignore_index=True)

		# Seasons
		df = pd.read_csv(temp_dir + '/seasons.csv')

		# remove all columns except 'year' and 'url'
		df = df[['year', 'url']]
		df['name'] = 'Year ' + df['year'].astype(str)
		
		df = df[['name', 'url']]
		df['type'] = 'season'

		dataset = pd.concat([dataset, df], ignore_index=True)
	
	return dataset
	
def _extract_paragraphs(url: str) -> list:
	"""Download the page at `url` and return the text of all of its <p> elements."""
	response = requests.get(url)
	html = response.text

	soup = BeautifulSoup(html, "html.parser")

	pars = soup.find_all("p")
	pars = [p.get_text() for p in pars]
	return pars

def generate_trainset(persist: bool = True, persist_path: str = './datasets', filename: str = 'train.csv') -> pd.DataFrame:
	"""
	Generate the dataset used to train the model.

	Parameters:
	persist (bool): Whether to save the generated dataset to a file.
	persist_path (str): The directory where the generated dataset will be saved.
	filename (str): The name of the file in which to save the dataset.

	Returns:
	pd.DataFrame: The generated DataFrame.
	"""

	if os.path.exists(f"{persist_path}/{filename}"):
		return pd.read_csv(f"{persist_path}/{filename}")
	
	sources = _generate_sources()

	num_threads = multiprocessing.cpu_count()
	with ThreadPoolExecutor(max_workers=num_threads) as executor:
		paragraphs = list(executor.map(_extract_paragraphs, sources['url']))
		paragraphs = [" ".join(p[0:5]).strip("\n") for p in paragraphs] 		# Take the first 4 paragraphs
		sources['description'] = paragraphs
	df = sources[['type', 'name', 'description']]

	if persist:
		os.makedirs(persist_path, exist_ok=True)
		df.to_csv(f"{persist_path}/{filename}", index=False)

	return df

def generate_ragset(persist: bool = True, persist_path: str = './datasets', filename: str = 'rag.csv') -> pd.DataFrame:
	"""
	Generate the dataset used for Retrieval-Augmented Generation.

	Parameters:
	persist (bool): Whether to save the generated dataset to a file.
	persist_path (str): The directory where the generated dataset will be saved.
	filename (str): The name of the file in which to save the dataset.

	Returns:
	pd.DataFrame: The generated DataFrame.
	"""

	if os.path.exists(f"{persist_path}/{filename}"):
		return pd.read_csv(f"{persist_path}/{filename}")
	
	sources = _generate_sources()

	num_threads = multiprocessing.cpu_count()
	with ThreadPoolExecutor(max_workers=num_threads) as executor:
		paragraphs = list(executor.map(_extract_paragraphs, sources['url']))
		paragraphs = [" ".join(p).strip("\n") for p in paragraphs]				# Take all the paragraphs
		sources['description'] = paragraphs
	df = sources[['type', 'name', 'description']]

	if persist:
		os.makedirs(persist_path, exist_ok=True)
		df.to_csv(f"{persist_path}/{filename}", index=False)

	return df
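
# Illustrative usage sketch: running this file directly would build both datasets.
# It assumes Kaggle API credentials are configured (e.g. via ~/.kaggle/kaggle.json)
# and that the Wikipedia URLs referenced by the Kaggle dataset are reachable.
if __name__ == '__main__':
	trainset = generate_trainset()
	print(f"Generated train set with {len(trainset)} rows")

	ragset = generate_ragset()
	print(f"Generated RAG set with {len(ragset)} rows")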