Spaces:
Paused
Paused
import multiprocessing
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import kaggle
import pandas as pd
import requests
from bs4 import BeautifulSoup
def _generate_sources(kaggle_dataset: str = 'rohanrao/formula-1-world-championship-1950-2020') -> pd.DataFrame:
    """Generate a dataset containing urls to retrieve data from.

    Downloads the Kaggle Formula 1 dataset into a temporary directory and builds
    one row per circuit, driver, constructor, race and season, each carrying the
    ``url`` column of the corresponding CSV file.

    Parameters:
        kaggle_dataset (str): Kaggle dataset slug passed to
            ``kaggle.api.dataset_download_files``.

    Returns:
        pd.DataFrame: Columns ``type``, ``name`` and ``url`` (in that order)
        with a fresh 0..n-1 index. Row names are built per entity type:
        circuits/constructors use ``name``; drivers use "forename surname";
        races use "name year-round"; seasons use "Year <year>".
    """
    # Constructor pages excluded from scraping because their links are broken.
    broken_constructor_urls = [
        'http://en.wikipedia.org/wiki/Turner_(constructor)',
        'http://en.wikipedia.org/wiki/Hall_(constructor)',
    ]

    with tempfile.TemporaryDirectory() as temp_dir:
        kaggle.api.dataset_download_files(kaggle_dataset, path=temp_dir, unzip=True)

        def read_csv(csv_name: str) -> pd.DataFrame:
            """Read one CSV file from the unzipped download."""
            return pd.read_csv(os.path.join(temp_dir, csv_name))

        # `.assign` returns a new frame, so no in-place writes happen on
        # column/row selections (avoids SettingWithCopyWarning).
        circuits = read_csv('circuits.csv')
        circuits = circuits[['name', 'url']].assign(type='circuit')

        # Drivers: the display name is "forename surname".
        drivers = read_csv('drivers.csv')
        drivers = drivers.assign(name=drivers['forename'] + ' ' + drivers['surname'])
        drivers = drivers[['name', 'url']].assign(type='driver')

        constructors = read_csv('constructors.csv')
        constructors = constructors[~constructors['url'].isin(broken_constructor_urls)]
        constructors = constructors[['name', 'url']].assign(type='constructor')

        # Races: race names repeat every season, so disambiguate with "year-round".
        races = read_csv('races.csv')
        races = races.assign(
            name=races['name'] + ' ' + races['year'].astype(str) + '-' + races['round'].astype(str)
        )
        races = races[['name', 'url']].assign(type='race')

        seasons = read_csv('seasons.csv')
        seasons = seasons.assign(name='Year ' + seasons['year'].astype(str))
        seasons = seasons[['name', 'url']].assign(type='season')

    # A single concat instead of repeatedly growing an empty frame.
    dataset = pd.concat([circuits, drivers, constructors, races, seasons], ignore_index=True)
    # reindex (rather than dataset[[...]]) yields a standalone frame, so callers
    # can add columns to it without triggering SettingWithCopyWarning.
    return dataset.reindex(columns=['type', 'name', 'url'])
def _extract_paragraphs(url, timeout: float = 30):
    """Fetch ``url`` and return the text of every ``<p>`` element on the page.

    Parameters:
        url (str): Address of the page to download.
        timeout (float): Seconds ``requests`` waits for the server before
            raising ``requests.exceptions.Timeout``. Without a timeout a single
            unresponsive server would block the calling worker thread forever.

    Returns:
        list[str]: The text of each paragraph, in document order.
    """
    response = requests.get(url, timeout=timeout)
    # NOTE: the HTTP status is not checked; an error page's HTML is parsed
    # like any other page.
    soup = BeautifulSoup(response.text, "html.parser")
    return [paragraph.get_text() for paragraph in soup.find_all("p")]
def generate_trainset(persist: bool = True, persist_path: str = './datasets', filename='train.csv',
                      num_paragraphs: int = 5, max_workers: Optional[int] = None) -> pd.DataFrame:
    """
    Generate the dataset used to train the model.

    If ``persist_path/filename`` already exists, it is read and returned
    unchanged (regardless of ``persist``) and nothing is downloaded. Otherwise
    every source URL is fetched concurrently. The first ``num_paragraphs``
    paragraphs of each page are joined with spaces into a ``description``.

    Parameters:
        persist (bool): Whether to save the generated dataset to a file.
        persist_path (str): The directory where the generated dataset will be saved.
        filename (str): The name of the file to save the dataset.
        num_paragraphs (int): Number of leading paragraphs kept per page
            (default 5, i.e. ``p[0:5]``).
        max_workers (Optional[int]): Number of download threads; defaults to
            ``multiprocessing.cpu_count()``.

    Returns:
        pd.DataFrame: The generated DataFrame with columns ``type``, ``name``
        and ``description``.
    """
    output_file = os.path.join(persist_path, filename)
    if os.path.exists(output_file):
        return pd.read_csv(output_file)

    sources = _generate_sources()
    if max_workers is None:
        max_workers = multiprocessing.cpu_count()
    # Downloads are network-bound, so threads overlap the waiting.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        paragraphs = list(executor.map(_extract_paragraphs, sources['url']))

    sources['description'] = [" ".join(p[:num_paragraphs]).strip("\n") for p in paragraphs]
    df = sources[['type', 'name', 'description']]
    if persist:
        os.makedirs(persist_path, exist_ok=True)
        df.to_csv(output_file, index=False)
    return df
def generate_ragset(persist: bool = True, persist_path: str = './datasets', filename='rag.csv',
                    max_workers: Optional[int] = None) -> pd.DataFrame:
    """
    Generate the dataset used for Retrieval-Augmented Generation.

    If ``persist_path/filename`` already exists, it is read and returned
    unchanged (regardless of ``persist``) and nothing is downloaded. Otherwise
    every source URL is fetched concurrently. All paragraphs of each page are
    joined with spaces into a ``description``.

    Parameters:
        persist (bool): Whether to save the generated dataset to a file.
        persist_path (str): The directory where the generated dataset will be saved.
        filename (str): The name of the file to save the dataset.
        max_workers (Optional[int]): Number of download threads; defaults to
            ``multiprocessing.cpu_count()``.

    Returns:
        pd.DataFrame: The generated DataFrame with columns ``type``, ``name``
        and ``description``.
    """
    output_file = os.path.join(persist_path, filename)
    if os.path.exists(output_file):
        return pd.read_csv(output_file)

    sources = _generate_sources()
    if max_workers is None:
        max_workers = multiprocessing.cpu_count()
    # Downloads are network-bound, so threads overlap the waiting.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        paragraphs = list(executor.map(_extract_paragraphs, sources['url']))

    # Keep every paragraph of the page.
    sources['description'] = [" ".join(p).strip("\n") for p in paragraphs]
    df = sources[['type', 'name', 'description']]
    if persist:
        os.makedirs(persist_path, exist_ok=True)
        df.to_csv(output_file, index=False)
    return df