import csv
import random

import spacy
import srsly
import tqdm
import yaml

# Pipeline parameters; this script reads the "max_docs" key from them.
# The file is opened in a context manager so the handle is closed promptly.
with open("params.yaml") as _params_file:
    params = yaml.safe_load(_params_file)

# spaCy pipeline used to find named entities in each document.
nlp = spacy.load("en_core_web_trf")

INPUT_FILE = "data/processed/wellcome_grant_descriptions.csv"
OUTPUT_FILE = "data/processed/entities.jsonl"
# A document is kept only if it contains at least one entity whose label is in
# INCLUDE_ENTS and no entity whose label is in EXCLUDE_ENTS.
INCLUDE_ENTS = {"GPE", "LOC"}
EXCLUDE_ENTS = {"PERSON"}


def process_documents(input_file: str, output_file: str) -> None:
    """Extract named entities from CSV documents and save a random sample.

    Reads the first column of the CSV at ``input_file`` (the first row is
    treated as a header and skipped), runs every value through the
    module-level spaCy pipeline ``nlp`` and keeps the documents that contain
    at least one entity labelled in ``INCLUDE_ENTS`` and no entity labelled in
    ``EXCLUDE_ENTS``. The kept documents are shuffled, truncated to
    ``params["max_docs"]`` and written to ``output_file`` as JSON Lines.

    Each written record has the form
    ``{"text": <document text>, "ents": [{"text", "label", "start", "end"}]}``
    where ``start``/``end`` are character offsets and ``ents`` lists every
    entity found in the document, not only the included labels.

    Args:
        input_file: Path to the CSV file; documents are read from column 0.
        output_file: Path of the JSONL file to write.

    Raises:
        KeyError: If ``params`` has no ``"max_docs"`` entry.
    """
    # Look the limit up first so a missing key fails before the slow NER pass.
    max_docs = params["max_docs"]

    print(f"Reading data from {input_file}...")

    # newline="" lets the csv module handle line endings itself, so quoted
    # fields that contain line breaks are parsed correctly.
    with open(input_file, "r", newline="") as f:
        reader = csv.reader(f)
        # Skip the header row; the default avoids StopIteration on an empty file.
        next(reader, None)
        # csv.reader yields [] for blank lines, which have no first column.
        data = [row[0] for row in reader if row]

    print(f"Processing {len(data)} documents...")

    entities = []

    for text in tqdm.tqdm(data):
        doc = nlp(text)

        # Record every entity found in the document.
        ents = [
            {
                "text": ent.text,
                "label": ent.label_,
                "start": ent.start_char,
                "end": ent.end_char,
            }
            for ent in doc.ents
        ]

        found_labels = {ent["label"] for ent in ents}

        # A document without entities has no labels and is skipped by the
        # first condition.
        if found_labels & INCLUDE_ENTS and not found_labels & EXCLUDE_ENTS:
            entities.append({"text": doc.text, "ents": ents})

    print(f"Randomly selecting {max_docs} documents...")

    random.shuffle(entities)
    entities = entities[:max_docs]

    print(f"Writing {len(entities)} documents to {output_file}...")

    srsly.write_jsonl(output_file, entities)


# Script entry point: run the extraction with the default input/output paths.
if __name__ == "__main__":
    process_documents(input_file=INPUT_FILE, output_file=OUTPUT_FILE)