FormuLLaMa-Demo / utils/dataset_utils.py
import os
import tempfile
import multiprocessing
from concurrent.futures import ThreadPoolExecutor

import kaggle
import requests
import pandas as pd
from bs4 import BeautifulSoup

def _generate_sources() -> pd.DataFrame:
    """Generate a dataset containing the URLs to retrieve data from."""
    dataset = pd.DataFrame({'type': [], 'name': [], 'url': []})

    with tempfile.TemporaryDirectory() as temp_dir:
        kaggle.api.dataset_download_files('rohanrao/formula-1-world-championship-1950-2020', path=temp_dir, unzip=True)

        # Circuits
        df = pd.read_csv(temp_dir + '/circuits.csv')

        # Keep only the 'name' and 'url' columns
        df = df[['name', 'url']]
        df['type'] = 'circuit'

        dataset = pd.concat([dataset, df], ignore_index=True)

        # Drivers
        df = pd.read_csv(temp_dir + '/drivers.csv')

        # Keep only the 'forename', 'surname' and 'url' columns
        df = df[['forename', 'surname', 'url']]

        # Join the 'forename' and 'surname' columns into 'name'
        df['name'] = df['forename'] + ' ' + df['surname']
        df = df[['name', 'url']]
        df['type'] = 'driver'

        dataset = pd.concat([dataset, df], ignore_index=True)

        # Constructors
        df = pd.read_csv(temp_dir + '/constructors.csv')

        # Remove broken links
        df = df[(df['url'] != 'http://en.wikipedia.org/wiki/Turner_(constructor)') & (df['url'] != 'http://en.wikipedia.org/wiki/Hall_(constructor)')]

        # Keep only the 'name' and 'url' columns
        df = df[['name', 'url']]
        df['type'] = 'constructor'

        dataset = pd.concat([dataset, df], ignore_index=True)

        # Races
        df = pd.read_csv(temp_dir + '/races.csv')

        # Build 'name' from race name, year and round, then keep only 'name' and 'url'
        df['name'] = df['name'] + " " + df['year'].astype(str) + "-" + df['round'].astype(str)
        df = df[['name', 'url']]
        df['type'] = 'race'

        dataset = pd.concat([dataset, df], ignore_index=True)

        # Seasons
        df = pd.read_csv(temp_dir + '/seasons.csv')

        # Keep only the 'year' and 'url' columns
        df = df[['year', 'url']]
        df['name'] = 'Year ' + df['year'].astype(str)
        df = df[['name', 'url']]
        df['type'] = 'season'

        dataset = pd.concat([dataset, df], ignore_index=True)

    return dataset

def _extract_paragraphs(url):
    """Download a web page and return the text of its <p> elements."""
    response = requests.get(url)
    html = response.text

    soup = BeautifulSoup(html, "html.parser")
    pars = soup.find_all("p")
    pars = [p.get_text() for p in pars]

    return pars

def generate_trainset(persist: bool = True, persist_path: str = './datasets', filename: str = 'train.csv') -> pd.DataFrame:
    """
    Generate the dataset used to train the model.

    Parameters:
        persist (bool): Whether to save the generated dataset to a file.
        persist_path (str): The directory where the generated dataset will be saved.
        filename (str): The name of the file to save the dataset to.

    Returns:
        pd.DataFrame: The generated DataFrame.
    """
    if os.path.exists(f"{persist_path}/{filename}"):
        return pd.read_csv(f"{persist_path}/{filename}")

    sources = _generate_sources()

    num_threads = multiprocessing.cpu_count()
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        paragraphs = list(executor.map(_extract_paragraphs, sources['url']))

    paragraphs = [" ".join(p[0:5]).strip("\n") for p in paragraphs]  # Keep only the first five paragraphs
    sources['description'] = paragraphs

    df = sources[['type', 'name', 'description']]

    if persist:
        os.makedirs(persist_path, exist_ok=True)
        df.to_csv(f"{persist_path}/{filename}", index=False)

    return df

def generate_ragset(persist: bool = True, persist_path: str = './datasets', filename: str = 'rag.csv') -> pd.DataFrame:
    """
    Generate the dataset used for Retrieval-Augmented Generation.

    Parameters:
        persist (bool): Whether to save the generated dataset to a file.
        persist_path (str): The directory where the generated dataset will be saved.
        filename (str): The name of the file to save the dataset to.

    Returns:
        pd.DataFrame: The generated DataFrame.
    """
    if os.path.exists(f"{persist_path}/{filename}"):
        return pd.read_csv(f"{persist_path}/{filename}")

    sources = _generate_sources()

    num_threads = multiprocessing.cpu_count()
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        paragraphs = list(executor.map(_extract_paragraphs, sources['url']))

    paragraphs = [" ".join(p).strip("\n") for p in paragraphs]  # Keep all the paragraphs
    sources['description'] = paragraphs

    df = sources[['type', 'name', 'description']]

    if persist:
        os.makedirs(persist_path, exist_ok=True)
        df.to_csv(f"{persist_path}/{filename}", index=False)

    return df
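

# Minimal usage sketch (not part of the original module): it assumes valid Kaggle API
# credentials are configured (e.g. ~/.kaggle/kaggle.json) and that network access to
# Kaggle and Wikipedia is available. Both calls fall back to the cached CSV files under
# ./datasets when those files already exist.
if __name__ == "__main__":
    train_df = generate_trainset()  # description = first five paragraphs per entry
    rag_df = generate_ragset()      # description = full article text per entry
    print(train_df.head())
    print(rag_df.head())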