# MODIFY AS REQUIRED
import logging

import torch
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.model_selection import train_test_split
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, DataCollatorWithPadding, get_scheduler
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import BCEWithLogitsLoss
from tqdm.auto import tqdm
import evaluate

logging.basicConfig(level=logging.INFO)

from text_preprocessing import clean_tweet, clear_reply_mentions, normalizeTweet
from custom_model import CustomModel

'''
DATA_PATH = "../../data"
PROCESSED_PATH = f"{DATA_PATH}/processed"
PROCESSED_PATH_VIRAL = f'{DATA_PATH}/new/processed/viral'
PROCESSED_PATH_COVID = f'{DATA_PATH}/new/processed/covid'
'''

# Different models
BERT_BASE_UNCASED = "bert-base-uncased"
BERT_BASE_CASED = "bert-base-cased"
ROBERTA_BASE = "roberta-base"
BERT_TWEET = "vinai/bertweet-base"  # TODO: Don't forget to cite papers if you use some model
BERT_TINY = "prajjwal1/bert-tiny"

TWEET_MAX_LENGTH = 280

# TEST SPLIT RATIO + MODELS (ADD MORE MODELS FROM ABOVE)
MODELS = [BERT_TWEET, BERT_TINY, BERT_BASE_CASED, ROBERTA_BASE]
TEST_RATIO = 0.2

TOP_FEATURES = ["verified", "tweet_length", "possibly_sensitive", "sentiment",
                "nb_of_hashtags", "has_media", "nb_of_mentions"]


def preprocess_data(dataset):
    dataset.loc[:, 'has_media'] = dataset.has_media.astype("int")
    dataset.loc[:, 'possibly_sensitive'] = dataset.possibly_sensitive.astype("int")
    #dataset = dataset[dataset.sentiment_score > 0.7]
    dataset.loc[:, 'sentiment'] = dataset.sentiment.replace({'POSITIVE': 1, 'NEGATIVE': 0})
    dataset.loc[:, 'verified'] = dataset['verified'].astype(int)

    # remove tweets with 0 retweets (to eliminate their effects)
    #dataset = dataset[dataset.retweet_count > 0]

    ## UPDATE: Get tweets posted by the same user on the same day they posted a viral tweet
    # Get the date from the datetime: normalize() sets every timestamp's clock to midnight,
    # which is equivalent to keeping only the date part
    dataset['date'] = dataset.created_at.dt.normalize()

    viral_tweets = dataset[dataset.viral]
    non_viral_tweets = dataset[~dataset.viral]
    temp = non_viral_tweets.merge(viral_tweets[['author_id', 'date', 'id', 'viral']],
                                  on=['author_id', 'date'], suffixes=(None, '_y'))
    same_day_viral_ids = temp.id_y.unique()
    same_day_viral_tweets = viral_tweets[viral_tweets.id.isin(same_day_viral_ids)].drop_duplicates(subset=['author_id', 'date'])
    same_day_non_viral_tweets = temp.drop_duplicates(subset=['author_id', 'date'])
    logging.info(f"Number of viral tweets tweeted on the same day {len(same_day_viral_tweets)}")
    logging.info(f"Number of non viral tweets tweeted on the same day {len(same_day_non_viral_tweets)}")

    dataset = pd.concat([same_day_viral_tweets, same_day_non_viral_tweets], axis=0)
    dataset = dataset[['id', 'text'] + TOP_FEATURES + ['viral']]

    # Balance classes to have as many viral as non viral ones
    #dataset = pd.concat([positives, negatives.sample(n=len(positives))])
    #dataset = pd.concat([positives.iloc[:100], negatives.sample(n=len(positives)).iloc[:200]])

    # Clean text to prepare for tokenization
    #dataset = dataset.dropna()
    dataset.loc[:, "viral"] = dataset.viral.astype(int)

    # TODO: COMMENT IF YOU WANT TO KEEP TEXT AS IS
    dataset["cleaned_text"] = dataset.text.apply(lambda x: clean_tweet(x, demojize_emojis=False))
    dataset = dataset.dropna()

    # Pack the handcrafted features into one list-valued column for the custom model
    dataset["extra_features"] = dataset[TOP_FEATURES].values.tolist()
    dataset = dataset[['id', 'cleaned_text', 'extra_features', 'viral']]
    return dataset


def prepare_dataset(sample_data, balance=False):
    # Split into train and test so that each split keeps a fixed proportion of viral tweets
    train_dataset, eval_dataset = train_test_split(sample_data, test_size=TEST_RATIO,
                                                   random_state=42, stratify=sample_data.viral)
    # Balance the test set
    if balance:
        eval_virals = eval_dataset[eval_dataset.viral == 1]
        eval_non_virals = eval_dataset[eval_dataset.viral == 0]
        eval_dataset = pd.concat([eval_virals, eval_non_virals.sample(n=len(eval_virals))])

    logging.info('{:>5,} training samples with {:>5,} positives and {:>5,} negatives'.format(
        len(train_dataset), len(train_dataset[train_dataset.viral == 1]), len(train_dataset[train_dataset.viral == 0])))
    logging.info('{:>5,} validation samples with {:>5,} positives and {:>5,} negatives'.format(
        len(eval_dataset), len(eval_dataset[eval_dataset.viral == 1]), len(eval_dataset[eval_dataset.viral == 0])))

    train_dataset.to_parquet("train.parquet.gzip", compression='gzip')
    eval_dataset.to_parquet("test.parquet.gzip", compression='gzip')
    ds = load_dataset("parquet", data_files={'train': 'train.parquet.gzip', 'test': 'test.parquet.gzip'})
    return ds


def tokenize_function(example, tokenizer):
    # Truncate to the model's max length. Note that a tweet's maximum length is 280 characters.
    # TODO: check dynamic padding: https://huggingface.co/course/chapter3/2?fw=pt#dynamic-padding
    return tokenizer(example["cleaned_text"], truncation=True)


def test_all_models(ds, nb_extra_dims, models=MODELS):
    models_losses = {}
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    output = ""
    for checkpoint in models:
        # Free cached MPS memory between models (skipped on CPU)
        if device.type == "mps":
            torch.mps.empty_cache()
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        custom_model = CustomModel(checkpoint, num_extra_dims=nb_extra_dims, num_labels=2)
        custom_model.to(device)

        tokenized_datasets = ds.map(lambda x: tokenize_function(x, tokenizer=tokenizer), batched=True)
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        tokenized_datasets = tokenized_datasets.remove_columns(["__index_level_0__", "cleaned_text", "id"])
        tokenized_datasets = tokenized_datasets.rename_column("viral", "labels")
        tokenized_datasets.set_format("torch")

        batch_size = 32
        train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True,
                                      batch_size=batch_size, collate_fn=data_collator)
        eval_dataloader = DataLoader(tokenized_datasets["test"],
                                     batch_size=batch_size, collate_fn=data_collator)

        criterion = BCEWithLogitsLoss()
        optimizer = AdamW(custom_model.parameters(), lr=5e-5)

        num_epochs = 15
        num_training_steps = num_epochs * len(train_dataloader)
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=num_training_steps
        )

        progress_bar = tqdm(range(num_training_steps))

        losses = []
        custom_model.train()
        for epoch in range(num_epochs):
            for batch in train_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                logits = custom_model(**batch).squeeze()
                loss = criterion(logits, batch['labels'].float())
                #losses.append(loss.cpu().item())
                losses.append(loss.item())

                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)

        models_losses[checkpoint] = losses

        metric = evaluate.combine(["accuracy", "recall", "precision", "f1"])
        custom_model.eval()
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                logits = custom_model(**batch)

            #predictions = torch.argmax(outputs, dim=-1)
            # Round sigmoid outputs to hard 0/1 predictions; squeeze to match the label shape, as in training
            predictions = torch.round(torch.sigmoid(logits)).squeeze()
            metric.add_batch(predictions=predictions, references=batch["labels"])

        output += f"checkpoint: {checkpoint}: {metric.compute()}\n"

    logging.info(output)
    with open("same_day_as_viral_with_features_train_test_balanced_accuracy.txt", "w") as text_file:
        text_file.write(output)

    return models_losses


def main():
    # DATA FILE SHOULD BE AT THE ROOT WITH THIS SCRIPT
    all_tweets_labeled = pd.read_parquet('final_dataset_since_october_2022.parquet.gzip')
    dataset = preprocess_data(all_tweets_labeled)
    ds = prepare_dataset(dataset, balance=True)
    nb_extra_dims = len(TOP_FEATURES)
    test_all_models(ds, nb_extra_dims=nb_extra_dims)


if __name__ == "__main__":
    main()
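# Note: custom_model.CustomModel is not defined in this file. Based on how it is called above
# (extra_features collated alongside input_ids/attention_mask, and the output squeezed and fed
# to BCEWithLogitsLoss), it presumably concatenates the handcrafted tweet features with the
# transformer's pooled output before a single-logit classification head. The block below is a
# hedged sketch of that assumed interface, kept inside a string so it is never executed; the
# real implementation lives in custom_model.py and its names and shapes may differ.
'''
import torch
import torch.nn as nn
from transformers import AutoModel

class CustomModel(nn.Module):
    def __init__(self, checkpoint, num_extra_dims, num_labels=2):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(checkpoint)
        hidden_size = self.transformer.config.hidden_size
        # Classifier sees the [CLS] embedding concatenated with the extra features and
        # emits a single logit per example, matching the BCEWithLogitsLoss usage above.
        self.classifier = nn.Linear(hidden_size + num_extra_dims, 1)

    def forward(self, input_ids, attention_mask, extra_features, labels=None, **kwargs):
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0]                    # (batch, hidden_size)
        combined = torch.cat([cls_embedding, extra_features.float()], dim=-1)
        return self.classifier(combined)                                   # (batch, 1) logits
'''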