diff --git a/app.py b/app.py
index f1c54fb8b3fee713ce3296bd3208a8146ab3cbee..7e32cb2603cb8ae19556299fe28beb9de6b9d24e 100644
--- a/app.py
+++ b/app.py
@@ -1,5 +1,5 @@
 import streamlit as st
-# from parse import parse_text
+from parse import parse
 from nltk import Tree
 import pandas as pd
 import re
@@ -31,19 +31,22 @@ if text:
   
   df = pd.DataFrame(zipped, columns=['Token', 'Tag', 'Prob.'])
   
-  # # Convert the bracket parse tree into an NLTK Tree
-  # t = Tree.fromstring(re.sub(r'(\.[^ )]+)+', '', parse_tree))
+  parse_tree = parse(tokens)
   
-  # tree_svg = TreePrettyPrinter(t).svg(nodecolor='black', leafcolor='black', funccolor='black')
+  # Convert the bracket parse tree into an NLTK Tree
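+  # e.g. "NP-SBJ" -> "NP": the regex below drops PTB-style function-tag suffixes before rendering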
+  t = Tree.fromstring(re.sub(r'-[^ )]*', '', parse_tree))
+  
+  tree_svg = TreePrettyPrinter(t).svg(nodecolor='black', leafcolor='black', funccolor='black')
   
   col1 = st.columns(1)[0]
   col1.header("POS tagging result:")
   col1.table(df)
   
-#   col2 = st.columns(1)[0]
-#   col2.header("Parsing result:")
-#   col2.write(parse_tree.replace('_', '\_').replace('$', '\$').replace('*', '\*'))
+  col2 = st.columns(1)[0]
+  col2.header("Parsing result:")
+  col2.write(parse_tree.replace('_', r'\_').replace('$', r'\$').replace('*', r'\*'))
 
-# # Display the graph in the Streamlit app
-#   col2.image(tree_svg, use_column_width=True)
+  # Display the graph in the Streamlit app
+  col2.image(tree_svg, use_column_width=True)
     
diff --git a/benepar/__init__.py b/benepar/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d6ad660648d0c407a0cda608e0e5fc027ca33c5
--- /dev/null
+++ b/benepar/__init__.py
@@ -0,0 +1,20 @@
+"""
+benepar: Berkeley Neural Parser
+"""
+
+# This file and all code in integrations/ relate to the version of the parser
+# released via PyPI. If you only need to run research experiments, it is safe
+# to delete the integrations/ folder and replace this __init__.py with an
+# empty file.
+
+__all__ = [
+    "Parser",
+    "InputSentence",
+    "download",
+    "BeneparComponent",
+    "NonConstituentException",
+]
+
+from .integrations.downloader import download
+from .integrations.nltk_plugin import Parser, InputSentence
+from .integrations.spacy_plugin import BeneparComponent, NonConstituentException
diff --git a/benepar/char_lstm.py b/benepar/char_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aefc5c18959865e9a75cbb476b21e0d2afd5678
--- /dev/null
+++ b/benepar/char_lstm.py
@@ -0,0 +1,166 @@
+"""
+Character LSTM implementation (matches https://arxiv.org/pdf/1805.01052.pdf)
+"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CharacterLSTM(nn.Module):
+    def __init__(self, num_embeddings, d_embedding, d_out, char_dropout=0.0, **kwargs):
+        super().__init__()
+
+        self.d_embedding = d_embedding
+        self.d_out = d_out
+
+        self.lstm = nn.LSTM(
+            self.d_embedding, self.d_out // 2, num_layers=1, bidirectional=True
+        )
+
+        self.emb = nn.Embedding(num_embeddings, self.d_embedding, **kwargs)
+        self.char_dropout = nn.Dropout(char_dropout)
+
+    def forward(self, chars_packed, valid_token_mask):
+        inp_embs = nn.utils.rnn.PackedSequence(
+            self.char_dropout(self.emb(chars_packed.data)),
+            batch_sizes=chars_packed.batch_sizes,
+            sorted_indices=chars_packed.sorted_indices,
+            unsorted_indices=chars_packed.unsorted_indices,
+        )
+
+        _, (lstm_out, _) = self.lstm(inp_embs)
+        lstm_out = torch.cat([lstm_out[0], lstm_out[1]], -1)
+
+        # Switch to a representation where there are dummy vectors for invalid
+        # tokens generated by padding.
+        res = lstm_out.new_zeros(
+            (valid_token_mask.shape[0], valid_token_mask.shape[1], lstm_out.shape[-1])
+        )
+        res[valid_token_mask] = lstm_out
+        return res
+
+
+class RetokenizerForCharLSTM:
+    # Assumes that these control characters are not present in treebank text
+    CHAR_UNK = "\0"
+    CHAR_ID_UNK = 0
+    CHAR_START_SENTENCE = "\1"
+    CHAR_START_WORD = "\2"
+    CHAR_STOP_WORD = "\3"
+    CHAR_STOP_SENTENCE = "\4"
+
+    def __init__(self, char_vocab):
+        self.char_vocab = char_vocab
+
+    @classmethod
+    def build_vocab(cls, sentences):
+        char_set = set()
+        for sentence in sentences:
+            if isinstance(sentence, tuple):
+                sentence = sentence[0]
+            for word in sentence:
+                char_set |= set(word)
+
+        # If codepoints are small (e.g. Latin alphabet), index by codepoint
+        # directly
+        highest_codepoint = max(ord(char) for char in char_set)
+        if highest_codepoint < 512:
+            if highest_codepoint < 256:
+                highest_codepoint = 256
+            else:
+                highest_codepoint = 512
+
+            char_vocab = {}
+            # This also takes care of constants like CHAR_UNK, etc.
+            for codepoint in range(highest_codepoint):
+                char_vocab[chr(codepoint)] = codepoint
+            return char_vocab
+        else:
+            char_vocab = {}
+            char_vocab[cls.CHAR_UNK] = 0
+            char_vocab[cls.CHAR_START_SENTENCE] = 1
+            char_vocab[cls.CHAR_START_WORD] = 2
+            char_vocab[cls.CHAR_STOP_WORD] = 3
+            char_vocab[cls.CHAR_STOP_SENTENCE] = 4
+            for id_, char in enumerate(sorted(char_set), start=5):
+                char_vocab[char] = id_
+            return char_vocab
+
+    def __call__(self, words, space_after="ignored", return_tensors=None):
+        if return_tensors != "np":
+            raise NotImplementedError("Only return_tensors='np' is supported.")
+
+        res = {}
+
+        # Sentence-level start/stop tokens are encoded as 3 pseudo-chars
+        # Within each word, account for 2 start/stop characters
+        max_word_len = max(3, max(len(word) for word in words)) + 2
+        char_ids = np.zeros((len(words) + 2, max_word_len), dtype=int)
+        word_lens = np.zeros(len(words) + 2, dtype=int)
+
+        char_ids[0, :5] = [
+            self.char_vocab[self.CHAR_START_WORD],
+            self.char_vocab[self.CHAR_START_SENTENCE],
+            self.char_vocab[self.CHAR_START_SENTENCE],
+            self.char_vocab[self.CHAR_START_SENTENCE],
+            self.char_vocab[self.CHAR_STOP_WORD],
+        ]
+        word_lens[0] = 5
+        for i, word in enumerate(words, start=1):
+            char_ids[i, 0] = self.char_vocab[self.CHAR_START_WORD]
+            for j, char in enumerate(word, start=1):
+                char_ids[i, j] = self.char_vocab.get(char, self.CHAR_ID_UNK)
+            char_ids[i, j + 1] = self.char_vocab[self.CHAR_STOP_WORD]
+            word_lens[i] = j + 2
+        char_ids[i + 1, :5] = [
+            self.char_vocab[self.CHAR_START_WORD],
+            self.char_vocab[self.CHAR_STOP_SENTENCE],
+            self.char_vocab[self.CHAR_STOP_SENTENCE],
+            self.char_vocab[self.CHAR_STOP_SENTENCE],
+            self.char_vocab[self.CHAR_STOP_WORD],
+        ]
+        word_lens[i + 1] = 5
+
+        res["char_ids"] = char_ids
+        res["word_lens"] = word_lens
+        res["valid_token_mask"] = np.ones_like(word_lens, dtype=bool)
+
+        return res
+
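+    # Shape sketch (illustrative): for words ["Fly", "safely", "."],
+    # char_ids has len(words) + 2 rows -- one per word, each wrapped in
+    # CHAR_START_WORD/CHAR_STOP_WORD, plus sentinel rows of three
+    # CHAR_START_SENTENCE (resp. CHAR_STOP_SENTENCE) pseudo-chars at the
+    # start and end -- and word_lens[i] counts the occupied slots of row i.
+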
+    def pad(self, examples, return_tensors=None):
+        if return_tensors != "pt":
+            raise NotImplementedError("Only return_tensors='pt' is supported.")
+        max_word_len = max(example["char_ids"].shape[-1] for example in examples)
+        char_ids = torch.cat(
+            [
+                F.pad(
+                    torch.tensor(example["char_ids"]),
+                    (0, max_word_len - example["char_ids"].shape[-1]),
+                )
+                for example in examples
+            ]
+        )
+        word_lens = torch.cat(
+            [torch.tensor(example["word_lens"]) for example in examples]
+        )
+        valid_token_mask = nn.utils.rnn.pad_sequence(
+            [torch.tensor(example["valid_token_mask"]) for example in examples],
+            batch_first=True,
+            padding_value=False,
+        )
+
+        char_ids = nn.utils.rnn.pack_padded_sequence(
+            char_ids, word_lens, batch_first=True, enforce_sorted=False
+        )
+        return {
+            "char_ids": char_ids,
+            "valid_token_mask": valid_token_mask,
+        }
diff --git a/benepar/decode_chart.py b/benepar/decode_chart.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d32ed1bdbe3bef17f509ceffdd1138267a36b0e
--- /dev/null
+++ b/benepar/decode_chart.py
@@ -0,0 +1,306 @@
+"""
+Parsing formulated as span classification (https://arxiv.org/abs/1705.03919)
+"""
+
+import nltk
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_struct
+
+from .parse_base import CompressedParserOutput
+
+
+def pad_charts(charts, padding_value=-100):
+    """Pad a list of variable-length charts with `padding_value`."""
+    batch_size = len(charts)
+    max_len = max(chart.shape[0] for chart in charts)
+    padded_charts = torch.full(
+        (batch_size, max_len, max_len),
+        padding_value,
+        dtype=charts[0].dtype,
+        device=charts[0].device,
+    )
+    for i, chart in enumerate(charts):
+        chart_size = chart.shape[0]
+        padded_charts[i, :chart_size, :chart_size] = chart
+    return padded_charts
+
+
+def collapse_unary_strip_pos(tree, strip_top=True):
+    """Collapse unary chains and strip part of speech tags."""
+
+    def strip_pos(tree):
+        if len(tree) == 1 and isinstance(tree[0], str):
+            return tree[0]
+        else:
+            return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree])
+
+    collapsed_tree = strip_pos(tree)
+    collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::")
+    if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"):
+        if strip_top:
+            if len(collapsed_tree) == 1:
+                collapsed_tree = collapsed_tree[0]
+            else:
+                collapsed_tree.set_label("")
+        elif len(collapsed_tree) == 1:
+            collapsed_tree[0].set_label(
+                collapsed_tree.label() + "::" + collapsed_tree[0].label())
+            collapsed_tree = collapsed_tree[0]
+    return collapsed_tree
+
+
+def _get_labeled_spans(tree, spans_out, start):
+    if isinstance(tree, str):
+        return start + 1
+
+    assert len(tree) > 1 or isinstance(
+        tree[0], str
+    ), "Must call collapse_unary_strip_pos first"
+    end = start
+    for child in tree:
+        end = _get_labeled_spans(child, spans_out, end)
+    # Spans are returned as closed intervals on both ends
+    spans_out.append((start, end - 1, tree.label()))
+    return end
+
+
+def get_labeled_spans(tree):
+    """Converts a tree into a list of labeled spans.
+
+    Args:
+        tree: an nltk.tree.Tree object
+
+    Returns:
+        A list of (span_start, span_end, span_label) tuples. The start and end
+        indices indicate the first and last words of the span (a closed
+        interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will
+        result in a single span labeled "S::VP".
+    """
+    tree = collapse_unary_strip_pos(tree)
+    spans_out = []
+    _get_labeled_spans(tree, spans_out, start=0)
+    return spans_out
+
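+# Worked example (illustrative): for the tree
+# (TOP (S (NP (DT the) (NN cat)) (VP (VBD slept)))), POS tags are stripped,
+# the TOP node is dropped, and get_labeled_spans returns
+# [(0, 1, 'NP'), (2, 2, 'VP'), (0, 2, 'S')]: children precede their parents,
+# and indices are closed intervals over word positions.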
+
+def uncollapse_unary(tree, ensure_top=False):
+    """Un-collapse unary chains."""
+    if isinstance(tree, str):
+        return tree
+    else:
+        labels = tree.label().split("::")
+        if ensure_top and labels[0] != "TOP":
+            labels = ["TOP"] + labels
+        children = []
+        for child in tree:
+            child = uncollapse_unary(child)
+            children.append(child)
+        for label in labels[::-1]:
+            children = [nltk.tree.Tree(label, children)]
+        return children[0]
+
+
+class ChartDecoder:
+    """A chart decoder for parsing formulated as span classification."""
+
+    def __init__(self, label_vocab, force_root_constituent=True):
+        """Constructs a new ChartDecoder object.
+        Args:
+            label_vocab: A mapping from span labels to integer indices.
+            force_root_constituent: If True, the span covering the whole
+                sentence is always assigned a non-null constituent label.
+        """
+        self.label_vocab = label_vocab
+        self.label_from_index = {i: label for label, i in label_vocab.items()}
+        self.force_root_constituent = force_root_constituent
+
+    @staticmethod
+    def build_vocab(trees):
+        label_set = set()
+        for tree in trees:
+            for _, _, label in get_labeled_spans(tree):
+                if label:
+                    label_set.add(label)
+        label_set = [""] + sorted(label_set)
+        return {label: i for i, label in enumerate(label_set)}
+    
+    @staticmethod
+    def infer_force_root_constituent(trees):
+        for tree in trees:
+            for _, _, label in get_labeled_spans(tree):
+                if not label:
+                    return False
+        return True
+
+    def chart_from_tree(self, tree):
+        spans = get_labeled_spans(tree)
+        num_words = len(tree.leaves())
+        chart = np.full((num_words, num_words), -100, dtype=int)
+        chart = np.tril(chart, -1)
+        # Now all invalid entries are filled with -100, and valid entries with 0
+        for start, end, label in spans:
+            # Previously unseen unary chains can occur in the dev/test sets.
+            # For now, we ignore them and don't mark the corresponding chart
+            # entry as a constituent.
+            if label in self.label_vocab:
+                chart[start, end] = self.label_vocab[label]
+        return chart
+
+    def charts_from_pytorch_scores_batched(self, scores, lengths):
+        """Runs CKY to recover span labels from scores (e.g. logits).
+
+        This method uses pytorch-struct to speed up decoding compared to the
+        pure-Python implementation of CKY used by tree_from_scores().
+
+        Args:
+            scores: a pytorch tensor of shape (batch size, max length,
+                max length, label vocab size).
+            lengths: a pytorch tensor of shape (batch size,)
+
+        Returns:
+            A list of numpy arrays, each of shape (sentence length, sentence
+                length).
+        """
+        scores = scores.detach()
+        scores = scores - scores[..., :1]
+        if self.force_root_constituent:
+            scores[torch.arange(scores.shape[0]), 0, lengths - 1, 0] -= 1e9
+        dist = torch_struct.TreeCRF(scores, lengths=lengths)
+        amax = dist.argmax
+        amax[..., 0] += 1e-9
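+        # The tiny bump on the null-label channel makes argmax well-defined:
+        # cells not selected by the predicted tree (all zeros) decode to 0.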
+        padded_charts = amax.argmax(-1)
+        padded_charts = padded_charts.detach().cpu().numpy()
+        return [
+            chart[:length, :length] for chart, length in zip(padded_charts, lengths)
+        ]
+
+    def compressed_output_from_chart(self, chart):
+        chart_with_filled_diagonal = chart.copy()
+        np.fill_diagonal(chart_with_filled_diagonal, 1)
+        chart_with_filled_diagonal[0, -1] = 1
+        starts, inclusive_ends = np.where(chart_with_filled_diagonal)
+        preorder_sort = np.lexsort((-inclusive_ends, starts))
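+        # lexsort keys: primary = start ascending, secondary = end descending,
+        # so spans come out in pre-order (parents before their children).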
+        starts = starts[preorder_sort]
+        inclusive_ends = inclusive_ends[preorder_sort]
+        labels = chart[starts, inclusive_ends]
+        ends = inclusive_ends + 1
+        return CompressedParserOutput(starts=starts, ends=ends, labels=labels)
+
+    def tree_from_chart(self, chart, leaves):
+        compressed_output = self.compressed_output_from_chart(chart)
+        return compressed_output.to_tree(leaves, self.label_from_index)
+
+    def tree_from_scores(self, scores, leaves):
+        """Runs CKY to decode a tree from scores (e.g. logits).
+
+        If speed is important, consider using charts_from_pytorch_scores_batched
+        followed by compressed_output_from_chart or tree_from_chart instead.
+
+        Args:
+            scores: a chart of scores (or logits) of shape
+                (sentence length, sentence length, label vocab size). The first
+                two dimensions may be padded to a longer length, but all padded
+                values will be ignored.
+            leaves: the leaf nodes to use in the constructed tree. These
+                may be of type str or nltk.Tree, or (word, tag) tuples that
+                will be used to construct the leaf node objects.
+
+        Returns:
+            An nltk.Tree object.
+        """
+        leaves = [
+            nltk.Tree(node[1], [node[0]]) if isinstance(node, tuple) else node
+            for node in leaves
+        ]
+
+        chart = {}
+        scores = scores - scores[:, :, 0, None]
+        for length in range(1, len(leaves) + 1):
+            for left in range(0, len(leaves) + 1 - length):
+                right = left + length
+
+                label_scores = scores[left, right - 1]
+                label_scores = label_scores - label_scores[0]
+
+                argmax_label_index = int(
+                    label_scores.argmax()
+                    if length < len(leaves) or not self.force_root_constituent
+                    else label_scores[1:].argmax() + 1
+                )
+                argmax_label = self.label_from_index[argmax_label_index]
+                label = argmax_label
+                label_score = label_scores[argmax_label_index]
+
+                if length == 1:
+                    tree = leaves[left]
+                    if label:
+                        tree = nltk.tree.Tree(label, [tree])
+                    chart[left, right] = [tree], label_score
+                    continue
+
+                best_split = max(
+                    range(left + 1, right),
+                    key=lambda split: (chart[left, split][1] + chart[split, right][1]),
+                )
+
+                left_trees, left_score = chart[left, best_split]
+                right_trees, right_score = chart[best_split, right]
+
+                children = left_trees + right_trees
+                if label:
+                    children = [nltk.tree.Tree(label, children)]
+
+                chart[left, right] = (children, label_score + left_score + right_score)
+
+        children, score = chart[0, len(leaves)]
+        tree = nltk.tree.Tree("TOP", children)
+        tree = uncollapse_unary(tree)
+        return tree
+
+
+class SpanClassificationMarginLoss(nn.Module):
+    def __init__(self, force_root_constituent=True, reduction="mean"):
+        super().__init__()
+        self.force_root_constituent = force_root_constituent
+        if reduction not in ("none", "mean", "sum"):
+            raise ValueError(f"Invalid value for reduction: {reduction}")
+        self.reduction = reduction
+
+    def forward(self, logits, labels):
+        gold_event = F.one_hot(F.relu(labels), num_classes=logits.shape[-1])
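+        # gold_event is a 0/1 indicator over gold (span, label) entries; relu
+        # clamps the -100 padding labels to 0 (the null label) before one-hot.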
+
+        logits = logits - logits[..., :1]
+        lengths = (labels[:, 0, :] != -100).sum(-1)
+        augment = (1 - gold_event).to(torch.float)
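+        # Cost augmentation: every non-gold entry receives a +1 bonus, so the
+        # TreeCRF max below performs loss-augmented (max-margin) decoding.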
+        if self.force_root_constituent:
+            augment[torch.arange(augment.shape[0]), 0, lengths - 1, 0] -= 1e9
+        dist = torch_struct.TreeCRF(logits + augment, lengths=lengths)
+
+        pred_score = dist.max
+        gold_score = (logits * gold_event).sum((1, 2, 3))
+
+        margin_losses = F.relu(pred_score - gold_score)
+
+        if self.reduction == "none":
+            return margin_losses
+        elif self.reduction == "mean":
+            return margin_losses.mean()
+        elif self.reduction == "sum":
+            return margin_losses.sum()
+        else:
+            assert False, f"Unexpected reduction: {self.reduction}"
diff --git a/benepar/integrations/__init__.py b/benepar/integrations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/benepar/integrations/downloader.py b/benepar/integrations/downloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..019aa4e286fcce338659bab0894068512d5f71ce
--- /dev/null
+++ b/benepar/integrations/downloader.py
@@ -0,0 +1,37 @@
+import os
+
+BENEPAR_SERVER_INDEX = "https://kitaev.com/benepar/index.xml"
+
+_downloader = None
+def get_downloader():
+    global _downloader
+    if _downloader is None:
+        import nltk.downloader
+        _downloader = nltk.downloader.Downloader(server_index_url=BENEPAR_SERVER_INDEX)
+    return _downloader
+
+def download(*args, **kwargs):
+    return get_downloader().download(*args, **kwargs)
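+# Example (uses the model name from the Parser docstring):
+#     benepar.download("benepar_en3")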
+
+def locate_model(name):
+    if os.path.exists(name):
+        return name
+    elif "/" not in name and "." not in name:
+        import nltk.data
+        try:
+            nltk_loc = nltk.data.find(f"models/{name}")
+            return nltk_loc.path
+        except LookupError as e:
+            arg = e.args[0].replace("nltk.download", "benepar.download")
+        
+        raise LookupError(arg)
+    
+    raise LookupError("Can't find {}".format(name))
+
+def load_trained_model(model_name_or_path):
+    model_path = locate_model(model_name_or_path)
+    from ..parse_chart import ChartParser
+    parser = ChartParser.from_trained(model_path)
+    return parser
diff --git a/benepar/integrations/nltk_plugin.py b/benepar/integrations/nltk_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7a454d79eed45700c0882779cc62d5a0190c5b6
--- /dev/null
+++ b/benepar/integrations/nltk_plugin.py
@@ -0,0 +1,285 @@
+import dataclasses
+import itertools
+from typing import List, Optional, Tuple
+
+import nltk
+import torch
+
+from .downloader import load_trained_model
+from ..parse_base import BaseParser, BaseInputExample
+from ..ptb_unescape import ptb_unescape, guess_space_after
+
+
+TOKENIZER_LOOKUP = {
+    "en": "english",
+    "de": "german",
+    "fr": "french",
+    "pl": "polish",
+    "sv": "swedish",
+}
+
+LANGUAGE_GUESS = {
+    "ar": ("X", "XP", "WHADVP", "WHNP", "WHPP"),
+    "zh": ("VSB", "VRD", "VPT", "VNV"),
+    "en": ("WHNP", "WHADJP", "SINV", "SQ"),
+    "de": ("AA", "AP", "CCP", "CH", "CNP", "VZ"),
+    "fr": ("P+", "P+D+", "PRO+", "PROREL+"),
+    "he": ("PREDP", "SYN_REL", "SYN_yyDOT"),
+    "pl": ("formaczas", "znakkonca"),
+    "sv": ("PSEUDO", "AVP", "XP"),
+}
+
+
+def guess_language(label_vocab):
+    """Guess parser language based on its syntactic label inventory.
+
+    The parser training scripts are designed to accept arbitrary input tree
+    files with minimal language-specific behavior, but at inference time we may
+    need to know the language identity in order to invoke other pipeline
+    elements, such as tokenizers.
+    """
+    for language, required_labels in LANGUAGE_GUESS.items():
+        if all(label in label_vocab for label in required_labels):
+            return language
+    return None
+
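+# For example, a label vocabulary containing all of "WHNP", "WHADJP", "SINV",
+# and "SQ" is guessed to be English ("en"); if no entry matches, the language
+# stays unknown and no NLTK tokenizer can be looked up.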
+
+@dataclasses.dataclass
+class InputSentence(BaseInputExample):
+    """Parser input for a single sentence.
+
+    At least one of `words` and `escaped_words` is required for each input
+    sentence. The remaining fields are optional: the parser will attempt to
+    derive the value for any missing fields using the fields that are provided.
+
+    `words` and `space_after` together form a reversible tokenization of the
+    input text: they represent, respectively, the Unicode text for each word and
+    an indicator for whether the word is followed by whitespace. These are used
+    as inputs by the parser.
+
+    `tags` is a list of part-of-speech tags, if available prior to running the
+    parser. The parser does not actually use these tags as input, but it will
+    pass them through to its output. If `tags` is None, the parser will perform
+    its own part of speech tagging (if the parser was not trained to also do
+    tagging, "UNK" part-of-speech tags will be used in the output instead).
+
+    `escaped_words` are the representations of each leaf to use in the output
+    tree. If `words` is provided, `escaped_words` will not be used by the neural
+    network portion of the parser, and will only be incorporated when
+    constructing the output tree. Therefore, `escaped_words` may be used to
+    accommodate any dataset-specific text encoding, such as transliteration.
+
+    Here is an example of the differences between these fields for English PTB:
+        (raw text):     "Fly safely."
+        words:          "       Fly     safely  .       "
+        space_after:    False   True    False   False   False
+        tags:           ``      VB      RB      .       ''
+        escaped_words:  ``      Fly     safely  .       ''
+    """
+
+    words: Optional[List[str]] = None
+    space_after: Optional[List[bool]] = None
+    tags: Optional[List[str]] = None
+    escaped_words: Optional[List[str]] = None
+
+    @property
+    def tree(self):
+        return None
+
+    def leaves(self):
+        return self.escaped_words
+
+    def pos(self):
+        if self.tags is not None:
+            return list(zip(self.escaped_words, self.tags))
+        else:
+            return [(word, "UNK") for word in self.escaped_words]
+
+
+class Parser:
+    """Berkeley Neural Parser (benepar), integrated with NLTK.
+
+    Use this class to apply the Berkeley Neural Parser to pre-tokenized datasets
+    and treebanks, or when integrating the parser into an NLP pipeline that
+    already performs tokenization, sentence splitting, and (optionally)
+    part-of-speech tagging. For parsing starting with raw text, it is strongly
+    encouraged that you use spaCy and benepar.BeneparComponent instead.
+
+    Sample usage:
+    >>> parser = benepar.Parser("benepar_en3")
+    >>> input_sentence = benepar.InputSentence(
+        words=['"', 'Fly', 'safely', '.', '"'],
+        space_after=[False, True, False, False, False],
+        tags=['``', 'VB', 'RB', '.', "''"],
+        escaped_words=['``', 'Fly', 'safely', '.', "''"],
+    )
+    >>> parser.parse(input_sentence)
+
+    Not all fields of benepar.InputSentence are required, but at least one of
+    `words` and `escaped_words` must not be None. The parser will attempt to
+    guess the value for missing fields. For example,
+    >>> input_sentence = benepar.InputSentence(
+        words=['"', 'Fly', 'safely', '.', '"'],
+    )
+    >>> parser.parse(input_sentence)
+
+    Although this class is primarily designed for use with data that has already
+    been tokenized, to help with interactive use and debugging it also accepts
+    simple text string inputs. However, using this class to parse from raw text
+    is STRONGLY DISCOURAGED for any application where parsing accuracy matters.
+    When parsing from raw text, use spaCy and benepar.BeneparComponent instead.
+    The reason is that parser models do not ship with a tokenizer or sentence
+    splitter, and some models may not include a part-of-speech tagger either. A
+    toolkit must be used to fill in these pipeline components, and spaCy
+    outperforms NLTK in all of these areas (sometimes by a large margin).
+    >>> parser.parse('"Fly safely."')  # For debugging/interactive use only.
+    """
+
+    def __init__(self, name, batch_size=64, language_code=None):
+        """Load a trained parser model.
+
+        Args:
+            name (str): Model name, or path to pytorch saved model
+            batch_size (int): Maximum number of sentences to process per batch
+            language_code (str, optional): language code for the parser (e.g.
+                'en', 'he', 'zh', etc). Our official trained models will set
+                this automatically, so this argument is only needed if training
+                on new languages or treebanks.
+        """
+        self._parser = load_trained_model(name)
+        if torch.cuda.is_available():
+            self._parser.cuda()
+        if language_code is not None:
+            self._language_code = language_code
+        else:
+            self._language_code = guess_language(self._parser.config["label_vocab"])
+        self._tokenizer_lang = TOKENIZER_LOOKUP.get(self._language_code, None)
+
+        self.batch_size = batch_size
+
+    def parse(self, sentence):
+        """Parse a single sentence
+
+        Args:
+            sentence (InputSentence or List[str] or str): Sentence to parse.
+                If the input is of List[str], it is assumed to be a sequence of
+                words and will behave the same as only setting the `words` field
+                of InputSentence. If the input is of type str, the sentence will
+                be tokenized using the default NLTK tokenizer (not recommended:
+                if parsing from raw text, use spaCy and benepar.BeneparComponent
+                instead).
+
+        Returns:
+            nltk.Tree
+        """
+        return list(self.parse_sents([sentence]))[0]
+
+    def parse_sents(self, sents):
+        """Parse multiple sentences in batches.
+
+        Args:
+            sents (Iterable[InputSentence]): An iterable of sentences to be
+                parsed. `sents` may also be a string, in which case it will be
+                segmented into sentences using the default NLTK sentence
+                splitter (not recommended: if parsing from raw text, use spaCy
+                and benepar.BeneparComponent instead). Otherwise, each element
+                of `sents` will be treated as a sentence. The elements of
+                `sents` may also be List[str] or str: see Parser.parse() for
+                documentation regarding these cases.
+
+        Yields:
+            nltk.Tree objects, one per input sentence.
+        """
+        if isinstance(sents, str):
+            if self._tokenizer_lang is None:
+                raise ValueError(
+                    "No tokenizer available for this language. "
+                    "Please split into individual sentences and tokens "
+                    "before calling the parser."
+                )
+            sents = nltk.sent_tokenize(sents, self._tokenizer_lang)
+
+        end_sentinel = object()
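+        # Grouper idiom: zip_longest over batch_size references to the same
+        # iterator yields fixed-size batches, padding the final batch with
+        # end_sentinel.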
+        for batch_sents in itertools.zip_longest(
+            *([iter(sents)] * self.batch_size), fillvalue=end_sentinel
+        ):
+            batch_inputs = []
+            for sent in batch_sents:
+                if sent is end_sentinel:
+                    break
+                elif isinstance(sent, str):
+                    if self._tokenizer_lang is None:
+                        raise ValueError(
+                            "No word tokenizer available for this language. "
+                            "Please tokenize before calling the parser."
+                        )
+                    escaped_words = nltk.word_tokenize(sent, self._tokenizer_lang)
+                    sent = InputSentence(escaped_words=escaped_words)
+                elif isinstance(sent, (list, tuple)):
+                    sent = InputSentence(words=sent)
+                elif not isinstance(sent, InputSentence):
+                    raise ValueError(
+                        "Sentences must be one of: InputSentence, list, tuple, or str"
+                    )
+                batch_inputs.append(self._with_missing_fields_filled(sent))
+
+            for inp, output in zip(
+                batch_inputs, self._parser.parse(batch_inputs, return_compressed=True)
+            ):
+                # If POS tags were provided as input, ignore any tags predicted
+                # by the parser.
+                if inp.tags is not None:
+                    output = output.without_predicted_tags()
+                yield output.to_tree(
+                    inp.pos(),
+                    self._parser.decoder.label_from_index,
+                    self._parser.tag_from_index,
+                )
+
+    def _with_missing_fields_filled(self, sent):
+        if not isinstance(sent, InputSentence):
+            raise ValueError("Input is not an instance of InputSentence")
+        if sent.words is None and sent.escaped_words is None:
+            raise ValueError("At least one of words or escaped_words is required")
+        elif sent.words is None:
+            sent = dataclasses.replace(sent, words=ptb_unescape(sent.escaped_words))
+        elif sent.escaped_words is None:
+            escaped_words = [
+                word.replace("(", "-LRB-")
+                .replace(")", "-RRB-")
+                .replace("{", "-LCB-")
+                .replace("}", "-RCB-")
+                .replace("[", "-LSB-")
+                .replace("]", "-RSB-")
+                for word in sent.words
+            ]
+            sent = dataclasses.replace(sent, escaped_words=escaped_words)
+        else:
+            if len(sent.words) != len(sent.escaped_words):
+                raise ValueError(
+                    f"Length of words ({len(sent.words)}) does not match "
+                    f"escaped_words ({len(sent.escaped_words)})"
+                )
+
+        if sent.space_after is None:
+            if self._language_code == "zh":
+                space_after = [False for _ in sent.words]
+            elif self._language_code in ("ar", "he"):
+                space_after = [True for _ in sent.words]
+            else:
+                space_after = guess_space_after(sent.words)
+            sent = dataclasses.replace(sent, space_after=space_after)
+        elif len(sent.words) != len(sent.space_after):
+            raise ValueError(
+                f"Length of words ({len(sent.words)}) does not match "
+                f"space_after ({len(sent.space_after)})"
+            )
+
+        assert len(sent.words) == len(sent.escaped_words) == len(sent.space_after)
+        return sent
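+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): assumes the "benepar_en3"
+    # model has already been downloaded, e.g. via benepar.download().
+    parser = Parser("benepar_en3")
+    tree = parser.parse(["The", "quick", "brown", "fox", "jumps", "."])
+    tree.pretty_print()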
diff --git a/benepar/integrations/spacy_extensions.py b/benepar/integrations/spacy_extensions.py
new file mode 100644
index 0000000000000000000000000000000000000000..572dc45fa8371d97f758a39d213834ce33bed998
--- /dev/null
+++ b/benepar/integrations/spacy_extensions.py
@@ -0,0 +1,179 @@
+NOT_PARSED_SENTINEL = object()
+
+
+class NonConstituentException(Exception):
+    pass
+
+
+class ConstituentData:
+    def __init__(self, starts, ends, labels, loc_to_constituent, label_vocab):
+        self.starts = starts
+        self.ends = ends
+        self.labels = labels
+        self.loc_to_constituent = loc_to_constituent
+        self.label_vocab = label_vocab
+
+
+def get_constituent(span):
+    constituent_data = span.doc._._constituent_data
+    if constituent_data is NOT_PARSED_SENTINEL:
+        raise Exception(
+            "No constituency parse is available for this document."
+            " Consider adding a BeneparComponent to the pipeline."
+        )
+
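+    # Constituents are stored in pre-order, so every constituent starting at
+    # span.start lies in a contiguous block; loc_to_constituent maps a word
+    # index to the first constituent that starts there.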
+    search_start = constituent_data.loc_to_constituent[span.start]
+    if span.start + 1 < len(constituent_data.loc_to_constituent):
+        search_end = constituent_data.loc_to_constituent[span.start + 1]
+    else:
+        search_end = len(constituent_data.ends)
+    found_position = None
+    for position in range(search_start, search_end):
+        if constituent_data.ends[position] <= span.end:
+            if constituent_data.ends[position] == span.end:
+                found_position = position
+            break
+
+    if found_position is None:
+        raise NonConstituentException("Span is not a constituent: {}".format(span))
+    return constituent_data, found_position
+
+
+def get_labels(span):
+    constituent_data, position = get_constituent(span)
+    label_num = constituent_data.labels[position]
+    return constituent_data.label_vocab[label_num]
+
+
+def parse_string(span):
+    constituent_data, position = get_constituent(span)
+    label_vocab = constituent_data.label_vocab
+    doc = span.doc
+
+    idx = position - 1
+
+    def make_str():
+        nonlocal idx
+        idx += 1
+        i, j, label_idx = (
+            constituent_data.starts[idx],
+            constituent_data.ends[idx],
+            constituent_data.labels[idx],
+        )
+        label = label_vocab[label_idx]
+        if (i + 1) >= j:
+            token = doc[i]
+            s = (
+                "("
+                + u"{} {}".format(token.tag_, token.text)
+                .replace("(", "-LRB-")
+                .replace(")", "-RRB-")
+                .replace("{", "-LCB-")
+                .replace("}", "-RCB-")
+                .replace("[", "-LSB-")
+                .replace("]", "-RSB-")
+                + ")"
+            )
+        else:
+            children = []
+            while (
+                (idx + 1) < len(constituent_data.starts)
+                and i <= constituent_data.starts[idx + 1]
+                and constituent_data.ends[idx + 1] <= j
+            ):
+                children.append(make_str())
+
+            s = u" ".join(children)
+
+        for sublabel in reversed(label):
+            s = u"({} {})".format(sublabel, s)
+        return s
+
+    return make_str()
+
+
+def get_subconstituents(span):
+    constituent_data, position = get_constituent(span)
+    label_vocab = constituent_data.label_vocab
+    doc = span.doc
+
+    while position < len(constituent_data.starts):
+        start = constituent_data.starts[position]
+        end = constituent_data.ends[position]
+
+        if span.end <= start or span.end < end:
+            break
+
+        yield doc[start:end]
+        position += 1
+
+
+def get_child_spans(span):
+    constituent_data, position = get_constituent(span)
+    label_vocab = constituent_data.label_vocab
+    doc = span.doc
+
+    child_start_expected = span.start
+    position += 1
+    while position < len(constituent_data.starts):
+        start = constituent_data.starts[position]
+        end = constituent_data.ends[position]
+
+        if span.end <= start or span.end < end:
+            break
+
+        if start == child_start_expected:
+            yield doc[start:end]
+            child_start_expected = end
+
+        position += 1
+
+
+def get_parent_span(span):
+    constituent_data, position = get_constituent(span)
+    label_vocab = constituent_data.label_vocab
+    doc = span.doc
+    sent = span.sent
+
+    position -= 1
+    while position >= 0:
+        start = constituent_data.starts[position]
+        end = constituent_data.ends[position]
+
+        if start <= span.start and span.end <= end:
+            return doc[start:end]
+        if end < sent.start:
+            break
+        position -= 1
+
+    return None
+
+
+def install_spacy_extensions():
+    from spacy.tokens import Doc, Span, Token
+
+    # None is not allowed as a default extension value!
+    Doc.set_extension("_constituent_data", default=NOT_PARSED_SENTINEL)
+
+    Span.set_extension("labels", getter=get_labels)
+    Span.set_extension("parse_string", getter=parse_string)
+    Span.set_extension("constituents", getter=get_subconstituents)
+    Span.set_extension("parent", getter=get_parent_span)
+    Span.set_extension("children", getter=get_child_spans)
+
+    Token.set_extension(
+        "labels", getter=lambda token: get_labels(token.doc[token.i : token.i + 1])
+    )
+    Token.set_extension(
+        "parse_string",
+        getter=lambda token: parse_string(token.doc[token.i : token.i + 1]),
+    )
+    Token.set_extension(
+        "parent", getter=lambda token: get_parent_span(token.doc[token.i : token.i + 1])
+    )
+
+
+try:
+    install_spacy_extensions()
+except ImportError:
+    pass
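+
+# Usage sketch (illustrative, assuming a pipeline containing BeneparComponent
+# has already processed `doc`):
+#
+#   sent = list(doc.sents)[0]
+#   sent._.parse_string      # bracketed constituency parse of the sentence
+#   sent._.labels            # label tuple for the span, e.g. ('S',)
+#   list(sent._.children)    # immediate child constituent spans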
diff --git a/benepar/integrations/spacy_plugin.py b/benepar/integrations/spacy_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..41ca8b6e41a6a3368a7c1d207a99704b68a82491
--- /dev/null
+++ b/benepar/integrations/spacy_plugin.py
@@ -0,0 +1,206 @@
+import numpy as np
+
+from .downloader import load_trained_model
+from ..parse_base import BaseParser, BaseInputExample
+from .spacy_extensions import ConstituentData, NonConstituentException
+
+import torch
+
+
+class PartialConstituentData:
+    def __init__(self):
+        self.starts = [np.array([], dtype=int)]
+        self.ends = [np.array([], dtype=int)]
+        self.labels = [np.array([], dtype=int)]
+
+    def finalize(self, doc, label_vocab):
+        self.starts = np.hstack(self.starts)
+        self.ends = np.hstack(self.ends)
+        self.labels = np.hstack(self.labels)
+
+        # TODO(nikita): Python for loops aren't very fast
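+        # loc_to_constituent[w] is the index of the first constituent starting
+        # at word w (-1 if none); spacy_extensions uses it for span lookups.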
+        loc_to_constituent = np.full(len(doc), -1, dtype=int)
+        prev = None
+        for position in range(self.starts.shape[0]):
+            if self.starts[position] != prev:
+                prev = self.starts[position]
+                loc_to_constituent[self.starts[position]] = position
+
+        return ConstituentData(
+            self.starts, self.ends, self.labels, loc_to_constituent, label_vocab
+        )
+
+
+class SentenceWrapper(BaseInputExample):
+    TEXT_NORMALIZATION_MAPPING = {
+        "`": "'",
+        "«": '"',
+        "»": '"',
+        "‘": "'",
+        "’": "'",
+        "“": '"',
+        "”": '"',
+        "„": '"',
+        "‹": "'",
+        "›": "'",
+        "—": "--",  # em dash
+    }
+
+    def __init__(self, spacy_sent):
+        self.sent = spacy_sent
+
+    @property
+    def words(self):
+        return [
+            self.TEXT_NORMALIZATION_MAPPING.get(token.text, token.text)
+            for token in self.sent
+        ]
+
+    @property
+    def space_after(self):
+        return [bool(token.whitespace_) for token in self.sent]
+
+    @property
+    def tree(self):
+        return None
+
+    def leaves(self):
+        return self.words
+
+    def pos(self):
+        return [(word, "UNK") for word in self.words]
+
+
+class BeneparComponent:
+    """
+    Berkeley Neural Parser (benepar) component for spaCy.
+
+    Sample usage:
+    >>> nlp = spacy.load('en_core_web_md')
+    >>> if spacy.__version__.startswith('2'):
+    ...     nlp.add_pipe(BeneparComponent("benepar_en3"))
+    ... else:
+    ...     nlp.add_pipe("benepar", config={"model": "benepar_en3"})
+    >>> doc = nlp("The quick brown fox jumps over the lazy dog.")
+    >>> sent = list(doc.sents)[0]
+    >>> print(sent._.parse_string)
+
+    This component is only responsible for constituency parsing and (for some
+    trained models) part-of-speech tagging. It should be preceded in the
+    pipeline by other components that can, at minimum, perform tokenization and
+    sentence segmentation.
+    """
+
+    name = "benepar"
+
+    def __init__(
+        self,
+        name,
+        subbatch_max_tokens=500,
+        disable_tagger=False,
+        batch_size="ignored",
+    ):
+        """Load a trained parser model.
+
+        Args:
+            name (str): Model name, or path to pytorch saved model
+            subbatch_max_tokens (int): Maximum number of tokens to process in
+                each batch
+            disable_tagger (bool, default False): Unless disabled, the parser
+                will set predicted part-of-speech tags for the document,
+                overwriting any existing tags provided by spaCy models or
+                previous pipeline steps. This option has no effect for parser
+                models that do not have a part-of-speech tagger built in.
+            batch_size: deprecated and ignored; use subbatch_max_tokens instead
+        """
+        self._parser = load_trained_model(name)
+        if torch.cuda.is_available():
+            self._parser.cuda()
+
+        self.subbatch_max_tokens = subbatch_max_tokens
+        self.disable_tagger = disable_tagger
+
+        self._label_vocab = self._parser.config["label_vocab"]
+        label_vocab_size = max(self._label_vocab.values()) + 1
+        self._label_from_index = [()] * label_vocab_size
+        for label, i in self._label_vocab.items():
+            if label:
+                self._label_from_index[i] = tuple(label.split("::"))
+            else:
+                self._label_from_index[i] = ()
+        self._label_from_index = tuple(self._label_from_index)
+
+        if not self.disable_tagger:
+            tag_vocab = self._parser.config["tag_vocab"]
+            tag_vocab_size = max(tag_vocab.values()) + 1
+            self._tag_from_index = [()] * tag_vocab_size
+            for tag, i in tag_vocab.items():
+                self._tag_from_index[i] = tag
+            self._tag_from_index = tuple(self._tag_from_index)
+        else:
+            self._tag_from_index = None
+
+    def __call__(self, doc):
+        """Update the input document with predicted constituency parses."""
+        # TODO(https://github.com/nikitakit/self-attentive-parser/issues/16): handle
+        # tokens that consist entirely of whitespace.
+        constituent_data = PartialConstituentData()
+        wrapped_sents = [SentenceWrapper(sent) for sent in doc.sents]
+        for sent, parse in zip(
+            doc.sents,
+            self._parser.parse(
+                wrapped_sents,
+                return_compressed=True,
+                subbatch_max_tokens=self.subbatch_max_tokens,
+            ),
+        ):
+            constituent_data.starts.append(parse.starts + sent.start)
+            constituent_data.ends.append(parse.ends + sent.start)
+            constituent_data.labels.append(parse.labels)
+
+            if parse.tags is not None and not self.disable_tagger:
+                for i, tag_id in enumerate(parse.tags):
+                    sent[i].tag_ = self._tag_from_index[tag_id]
+
+        doc._._constituent_data = constituent_data.finalize(doc, self._label_from_index)
+        return doc
+
+
+def create_benepar_component(
+    nlp,
+    name,
+    model: str,
+    subbatch_max_tokens: int,
+    disable_tagger: bool,
+):
+    return BeneparComponent(
+        model,
+        subbatch_max_tokens=subbatch_max_tokens,
+        disable_tagger=disable_tagger,
+    )
+
+
+def register_benepar_component_factory():
+    # Starting with spaCy 3.0, nlp.add_pipe no longer directly accepts
+    # BeneparComponent instances. We must instead register a component factory.
+    import spacy
+
+    if spacy.__version__.startswith("2"):
+        return
+
+    from spacy.language import Language
+
+    Language.factory(
+        "benepar",
+        default_config={
+            "subbatch_max_tokens": 500,
+            "disable_tagger": False,
+        },
+        func=create_benepar_component,
+    )
+
+
+try:
+    register_benepar_component_factory()
+except ImportError:
+    pass
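+
+
+if __name__ == "__main__":
+    # End-to-end sketch (illustrative only): assumes spaCy 3.x, the
+    # en_core_web_md pipeline, and a downloaded benepar_en3 model.
+    import spacy
+
+    nlp = spacy.load("en_core_web_md")
+    nlp.add_pipe("benepar", config={"model": "benepar_en3"})
+    doc = nlp("The quick brown fox jumps over the lazy dog.")
+    print(list(doc.sents)[0]._.parse_string)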
diff --git a/benepar/nkutil.py b/benepar/nkutil.py
new file mode 100644
index 0000000000000000000000000000000000000000..290ad20474d1406f9091aebbbbc960562c9075c1
--- /dev/null
+++ b/benepar/nkutil.py
@@ -0,0 +1,51 @@
+class HParams:
+    _skip_keys = ["populate_arguments", "set_from_args", "print", "to_dict"]
+
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    def __setitem__(self, item, value):
+        if not hasattr(self, item):
+            raise KeyError(f"Hyperparameter {item} has not been declared yet")
+        setattr(self, item, value)
+
+    def to_dict(self):
+        res = {}
+        for k in dir(self):
+            if k.startswith("_") or k in self._skip_keys:
+                continue
+            res[k] = self[k]
+        return res
+
+    def populate_arguments(self, parser):
+        for k in dir(self):
+            if k.startswith("_") or k in self._skip_keys:
+                continue
+            v = self[k]
+            k = k.replace("_", "-")
+            if type(v) in (int, float, str):
+                parser.add_argument(f"--{k}", type=type(v), default=v)
+            elif isinstance(v, bool):
+                if not v:
+                    parser.add_argument(f"--{k}", action="store_true")
+                else:
+                    parser.add_argument(f"--no-{k}", action="store_false")
+
+    def set_from_args(self, args):
+        for k in dir(self):
+            if k.startswith("_") or k in self._skip_keys:
+                continue
+            if hasattr(args, k):
+                self[k] = getattr(args, k)
+            elif hasattr(args, f"no_{k}"):
+                self[k] = getattr(args, f"no_{k}")
+
+    def print(self):
+        for k in dir(self):
+            if k.startswith("_") or k in self._skip_keys:
+                continue
+            print(k, repr(self[k]))
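+
+
+if __name__ == "__main__":
+    # Illustrative round trip (not part of benepar itself): declare defaults,
+    # expose them as command-line flags, then read parsed values back.
+    import argparse
+
+    hparams = HParams(d_model=1024, learning_rate=5e-5, use_pretrained=True)
+    argparser = argparse.ArgumentParser()
+    hparams.populate_arguments(argparser)
+    hparams.set_from_args(argparser.parse_args(["--d-model", "512"]))
+    hparams.print()  # d_model is now 512; other values keep their defaults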
diff --git a/benepar/parse_base.py b/benepar/parse_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be49169f6ed97148ba6d109c7512a3c0e5feb05
--- /dev/null
+++ b/benepar/parse_base.py
@@ -0,0 +1,216 @@
+from abc import ABC, abstractmethod
+import dataclasses
+from typing import Any, Iterable, List, Optional, Tuple, Union
+
+import nltk
+import numpy as np
+
+
+class BaseInputExample(ABC):
+    """Parser input for a single sentence (abstract interface)."""
+
+    # Subclasses must define the following attributes or properties.
+    # `words` is a list of unicode representations for each word in the sentence
+    # and `space_after` is a list of booleans that indicate whether there is
+    # whitespace after a word. Together, these should form a reversible
+    # tokenization of raw text input. `tree` is an optional gold parse tree.
+    words: List[str]
+    space_after: List[bool]
+    tree: Optional[nltk.Tree]
+
+    @abstractmethod
+    def leaves(self) -> Optional[List[str]]:
+        """Returns leaves to use in the parse tree.
+
+        While `words` must be raw unicode text, these should be whatever is
+        standard for the treebank. For example, '(' in words might correspond to
+        '-LRB-' in leaves, and leaves might include other transformations such
+        as transliteration.
+        """
+        pass
+
+    @abstractmethod
+    def pos(self) -> Optional[List[Tuple[str, str]]]:
+        """Returns a list of (leaf, part-of-speech tag) tuples."""
+        pass
+
+
+@dataclasses.dataclass
+class CompressedParserOutput:
+    """Parser output, encoded as a collection of numpy arrays.
+
+    By default, a parser will return nltk.Tree objects. These have much nicer
+    APIs than the CompressedParserOutput class, and the code involved is simpler
+    and more readable. As a trade-off, code dealing with nltk.Tree objects is
+    slower: the nltk.Tree type itself has some overhead, and algorithms dealing
+    with it are implemented in pure Python as opposed to C or even CUDA. The
+    CompressedParserOutput type is an alternative that has some optimizations
+    for the sole purpose of speeding up inference.
+
+    If trying a new parser type for research purposes, it's safe to ignore this
+    class and the return_compressed argument to parse(). If the parser works
+    well and is being released, the return_compressed argument can then be added
+    with a dedicated fast implementation, or simply by using the from_tree
+    method defined below.
+    """
+
+    # A parse tree is represented as a set of constituents. In the case of
+    # non-binary trees, only the labeled non-terminal nodes are included: there
+    # are no dummy nodes inserted for binarization purposes. However, single
+    # words are always included in the set of constituents, and they may have a
+    # null label if there is no phrasal category above the part-of-speech tag.
+    # All constituents are sorted according to pre-order traversal, and each has
+    # an associated start (the index of the first word in the constituent), end
+    # (1 + the index of the last word in the constituent), and label (an index
+    # into an external label_vocab dictionary). These are then stored
+    # in three numpy arrays:
+    starts: Iterable[int]  # Must be a numpy array
+    ends: Iterable[int]  # Must be a numpy array
+    labels: Iterable[int]  # Must be a numpy array
+
+    # Part of speech tag ids as output by the parser (may be None if the parser
+    # does not do POS tagging). These indices are associated with an external
+    # tag_vocab dictionary.
+    tags: Optional[Iterable[int]] = None  # Must be None or a numpy array
+
+    def without_predicted_tags(self):
+        return dataclasses.replace(self, tags=None)
+
+    def with_tags(self, tags):
+        return dataclasses.replace(self, tags=tags)
+
+    @classmethod
+    def from_tree(
+        cls, tree: nltk.Tree, label_vocab: dict, tag_vocab: Optional[dict] = None
+    ) -> "CompressedParserOutput":
+        num_words = len(tree.leaves())
+        starts = np.empty(2 * num_words, dtype=int)
+        ends = np.empty(2 * num_words, dtype=int)
+        labels = np.empty(2 * num_words, dtype=int)
+
+        def helper(tree, start, write_idx):
+            nonlocal starts, ends, labels
+            label = []
+            while len(tree) == 1 and not isinstance(tree[0], str):
+                if tree.label() != "TOP":
+                    label.append(tree.label())
+                tree = tree[0]
+
+            if len(tree) == 1 and isinstance(tree[0], str):
+                starts[write_idx] = start
+                ends[write_idx] = start + 1
+                labels[write_idx] = label_vocab["::".join(label)]
+                return start + 1, write_idx + 1
+
+            label.append(tree.label())
+            starts[write_idx] = start
+            labels[write_idx] = label_vocab["::".join(label)]
+
+            end = start
+            new_write_idx = write_idx + 1
+            for child in tree:
+                end, new_write_idx = helper(child, end, new_write_idx)
+
+            ends[write_idx] = end
+            return end, new_write_idx
+
+        _, num_constituents = helper(tree, 0, 0)
+        starts = starts[:num_constituents]
+        ends = ends[:num_constituents]
+        labels = labels[:num_constituents]
+
+        if tag_vocab is None:
+            tags = None
+        else:
+            tags = np.array([tag_vocab[tag] for _, tag in tree.pos()], dtype=int)
+
+        return cls(starts=starts, ends=ends, labels=labels, tags=tags)
+
+    def to_tree(self, leaves, label_from_index: dict, tag_from_index: dict = None):
+        if self.tags is not None:
+            if tag_from_index is None:
+                raise ValueError(
+                    "tags_from_index is required to convert predicted pos tags"
+                )
+            predicted_tags = [tag_from_index[i] for i in self.tags]
+            assert len(leaves) == len(predicted_tags)
+            leaves = [
+                nltk.Tree(tag, [leaf[0] if isinstance(leaf, tuple) else leaf])
+                for tag, leaf in zip(predicted_tags, leaves)
+            ]
+        else:
+            leaves = [
+                nltk.Tree(leaf[1], [leaf[0]])
+                if isinstance(leaf, tuple)
+                else (nltk.Tree("UNK", [leaf]) if isinstance(leaf, str) else leaf)
+                for leaf in leaves
+            ]
+
+        idx = -1
+
+        def helper():
+            nonlocal idx
+            idx += 1
+            i, j, label = (
+                self.starts[idx],
+                self.ends[idx],
+                label_from_index[self.labels[idx]],
+            )
+            if (i + 1) >= j:
+                children = [leaves[i]]
+            else:
+                children = []
+                while (
+                    (idx + 1) < len(self.starts)
+                    and i <= self.starts[idx + 1]
+                    and self.ends[idx + 1] <= j
+                ):
+                    children.extend(helper())
+
+            if label:
+                for sublabel in reversed(label.split("::")):
+                    children = [nltk.Tree(sublabel, children)]
+
+            return children
+
+        children = helper()
+        return nltk.Tree("TOP", children)
+
+
+class BaseParser(ABC):
+    """Parser (abstract interface)"""
+
+    @classmethod
+    @abstractmethod
+    def from_trained(
+        cls, model_name: str, config: dict = None, state_dict: dict = None
+    ) -> "BaseParser":
+        """Load a trained parser."""
+        pass
+
+    @abstractmethod
+    def parallelize(self, *args, **kwargs):
+        """Spread out pre-trained model layers across GPUs."""
+        pass
+
+    @abstractmethod
+    def parse(
+        self,
+        examples: Iterable[BaseInputExample],
+        return_compressed: bool = False,
+        return_scores: bool = False,
+        subbatch_max_tokens: Optional[int] = None,
+    ) -> Union[Iterable[nltk.Tree], Iterable[Any]]:
+        """Parse sentences."""
+        pass
+
+    @abstractmethod
+    def encode_and_collate_subbatches(
+        self, examples: List[BaseInputExample], subbatch_max_tokens: int
+    ) -> List[dict]:
+        """Split batch into sub-batches and convert to tensor features"""
+        pass
+
+    @abstractmethod
+    def compute_loss(self, batch: dict):
+        pass
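+
+
+if __name__ == "__main__":
+    # Round-trip sketch (illustrative; the tiny vocab below is an assumption,
+    # not a released model's label inventory).
+    tree = nltk.Tree.fromstring("(TOP (S (NP (DT the) (NN cat)) (VP (VBD sat))))")
+    label_vocab = {"": 0, "S": 1, "NP": 2, "VP": 3}
+    compressed = CompressedParserOutput.from_tree(tree, label_vocab)
+    label_from_index = {i: label for label, i in label_vocab.items()}
+    print(compressed.to_tree(tree.pos(), label_from_index))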
diff --git a/benepar/parse_chart.py b/benepar/parse_chart.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf8314885a8b77f01dd71d0636c34eb85d7f5ae
--- /dev/null
+++ b/benepar/parse_chart.py
@@ -0,0 +1,434 @@
+import os
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers import AutoConfig, AutoModel
+
+from . import char_lstm
+from . import decode_chart
+from . import nkutil
+from .partitioned_transformer import (
+    ConcatPositionalEncoding,
+    FeatureDropout,
+    PartitionedTransformerEncoder,
+    PartitionedTransformerEncoderLayer,
+)
+from . import parse_base
+from . import retokenization
+from . import subbatching
+
+
+class ChartParser(nn.Module, parse_base.BaseParser):
+    def __init__(
+        self,
+        tag_vocab,
+        label_vocab,
+        char_vocab,
+        hparams,
+        pretrained_model_path=None,
+    ):
+        super().__init__()
+        self.config = locals()
+        self.config.pop("self")
+        self.config.pop("__class__")
+        self.config.pop("pretrained_model_path")
+        self.config["hparams"] = hparams.to_dict()
+
+        self.tag_vocab = tag_vocab
+        self.label_vocab = label_vocab
+        self.char_vocab = char_vocab
+
+        self.d_model = hparams.d_model
+
+        self.char_encoder = None
+        self.pretrained_model = None
+        if hparams.use_chars_lstm:
+            assert (
+                not hparams.use_pretrained
+            ), "use_chars_lstm and use_pretrained are mutually exclusive"
+            self.retokenizer = char_lstm.RetokenizerForCharLSTM(self.char_vocab)
+            self.char_encoder = char_lstm.CharacterLSTM(
+                max(self.char_vocab.values()) + 1,
+                hparams.d_char_emb,
+                hparams.d_model // 2,  # Half-size to leave room for
+                # partitioned positional encoding
+                char_dropout=hparams.char_lstm_input_dropout,
+            )
+        elif hparams.use_pretrained:
+            if pretrained_model_path is None:
+                self.retokenizer = retokenization.Retokenizer(
+                    hparams.pretrained_model, retain_start_stop=True
+                )
+                self.pretrained_model = AutoModel.from_pretrained(
+                    hparams.pretrained_model
+                )
+            else:
+                self.retokenizer = retokenization.Retokenizer(
+                    pretrained_model_path, retain_start_stop=True
+                )
+                self.pretrained_model = AutoModel.from_config(
+                    AutoConfig.from_pretrained(pretrained_model_path)
+                )
+            d_pretrained = self.pretrained_model.config.hidden_size
+
+            if hparams.use_encoder:
+                self.project_pretrained = nn.Linear(
+                    d_pretrained, hparams.d_model // 2, bias=False
+                )
+            else:
+                self.project_pretrained = nn.Linear(
+                    d_pretrained, hparams.d_model, bias=False
+                )
+
+        if hparams.use_encoder:
+            self.morpho_emb_dropout = FeatureDropout(hparams.morpho_emb_dropout)
+            self.add_timing = ConcatPositionalEncoding(
+                d_model=hparams.d_model,
+                max_len=hparams.encoder_max_len,
+            )
+            encoder_layer = PartitionedTransformerEncoderLayer(
+                hparams.d_model,
+                n_head=hparams.num_heads,
+                d_qkv=hparams.d_kv,
+                d_ff=hparams.d_ff,
+                ff_dropout=hparams.relu_dropout,
+                residual_dropout=hparams.residual_dropout,
+                attention_dropout=hparams.attention_dropout,
+            )
+            self.encoder = PartitionedTransformerEncoder(
+                encoder_layer, hparams.num_layers
+            )
+        else:
+            self.morpho_emb_dropout = None
+            self.add_timing = None
+            self.encoder = None
+
+        self.f_label = nn.Sequential(
+            nn.Linear(hparams.d_model, hparams.d_label_hidden),
+            nn.LayerNorm(hparams.d_label_hidden),
+            nn.ReLU(),
+            nn.Linear(hparams.d_label_hidden, max(label_vocab.values())),
+        )
+
+        if hparams.predict_tags:
+            self.f_tag = nn.Sequential(
+                nn.Linear(hparams.d_model, hparams.d_tag_hidden),
+                nn.LayerNorm(hparams.d_tag_hidden),
+                nn.ReLU(),
+                nn.Linear(hparams.d_tag_hidden, max(tag_vocab.values()) + 1),
+            )
+            self.tag_loss_scale = hparams.tag_loss_scale
+            self.tag_from_index = {i: label for label, i in tag_vocab.items()}
+        else:
+            self.f_tag = None
+            self.tag_from_index = None
+
+        self.decoder = decode_chart.ChartDecoder(
+            label_vocab=self.label_vocab,
+            force_root_constituent=hparams.force_root_constituent,
+        )
+        self.criterion = decode_chart.SpanClassificationMarginLoss(
+            reduction="sum", force_root_constituent=hparams.force_root_constituent
+        )
+
+        self.parallelized_devices = None
+
+    @property
+    def device(self):
+        if self.parallelized_devices is not None:
+            return self.parallelized_devices[0]
+        else:
+            return next(self.f_label.parameters()).device
+
+    @property
+    def output_device(self):
+        if self.parallelized_devices is not None:
+            return self.parallelized_devices[1]
+        else:
+            return next(self.f_label.parameters()).device
+
+    def parallelize(self, *args, **kwargs):
+        self.parallelized_devices = (torch.device("cuda", 0), torch.device("cuda", 1))
+        for child in self.children():
+            if child != self.pretrained_model:
+                child.to(self.output_device)
+        self.pretrained_model.parallelize(*args, **kwargs)
+
+    @classmethod
+    def from_trained(cls, model_path):
+        if os.path.isdir(model_path):
+            # Multi-file format used when exporting models for release.
+            # Unlike the checkpoints saved during training, these files include
+            # all tokenizer parameters and a copy of the pre-trained model
+            # config (rather than downloading these on-demand).
+            config = AutoConfig.from_pretrained(model_path).benepar
+            state_dict = torch.load(
+                os.path.join(model_path, "benepar_model.bin"), map_location="cpu"
+            )
+            config["pretrained_model_path"] = model_path
+        else:
+            # Single-file format used for saving checkpoints during training.
+            data = torch.load(model_path, map_location="cpu")
+            config = data["config"]
+            state_dict = data["state_dict"]
+
+        hparams = config["hparams"]
+
+        if "force_root_constituent" not in hparams:
+            hparams["force_root_constituent"] = True
+
+        config["hparams"] = nkutil.HParams(**hparams)
+        parser = cls(**config)
+        parser.load_state_dict(state_dict)
+        return parser
+
+    def encode(self, example):
+        if self.char_encoder is not None:
+            encoded = self.retokenizer(example.words, return_tensors="np")
+        else:
+            encoded = self.retokenizer(example.words, example.space_after)
+
+        if example.tree is not None:
+            encoded["span_labels"] = torch.tensor(
+                self.decoder.chart_from_tree(example.tree)
+            )
+            if self.f_tag is not None:
+                encoded["tag_labels"] = torch.tensor(
+                    [-100] + [self.tag_vocab[tag] for _, tag in example.pos()] + [-100]
+                )
+        return encoded
+
+    def pad_encoded(self, encoded_batch):
+        batch = self.retokenizer.pad(
+            [
+                {
+                    k: v
+                    for k, v in example.items()
+                    if (k != "span_labels" and k != "tag_labels")
+                }
+                for example in encoded_batch
+            ],
+            return_tensors="pt",
+        )
+        if encoded_batch and "span_labels" in encoded_batch[0]:
+            batch["span_labels"] = decode_chart.pad_charts(
+                [example["span_labels"] for example in encoded_batch]
+            )
+        if encoded_batch and "tag_labels" in encoded_batch[0]:
+            batch["tag_labels"] = nn.utils.rnn.pad_sequence(
+                [example["tag_labels"] for example in encoded_batch],
+                batch_first=True,
+                padding_value=-100,
+            )
+        return batch
+
+    def _get_lens(self, encoded_batch):
+        if self.pretrained_model is not None:
+            return [len(encoded["input_ids"]) for encoded in encoded_batch]
+        return [len(encoded["valid_token_mask"]) for encoded in encoded_batch]
+
+    def encode_and_collate_subbatches(self, examples, subbatch_max_tokens):
+        batch_size = len(examples)
+        batch_num_tokens = sum(len(x.words) for x in examples)
+        encoded = [self.encode(example) for example in examples]
+
+        res = []
+        for ids, subbatch_encoded in subbatching.split(
+            encoded, costs=self._get_lens(encoded), max_cost=subbatch_max_tokens
+        ):
+            subbatch = self.pad_encoded(subbatch_encoded)
+            subbatch["batch_size"] = batch_size
+            subbatch["batch_num_tokens"] = batch_num_tokens
+            res.append((len(ids), subbatch))
+        return res
+
+    def forward(self, batch):
+        valid_token_mask = batch["valid_token_mask"].to(self.output_device)
+
+        if (
+            self.encoder is not None
+            and valid_token_mask.shape[1] > self.add_timing.timing_table.shape[0]
+        ):
+            raise ValueError(
+                "Sentence of length {} exceeds the maximum supported length of "
+                "{}".format(
+                    valid_token_mask.shape[1] - 2,
+                    self.add_timing.timing_table.shape[0] - 2,
+                )
+            )
+
+        if self.char_encoder is not None:
+            assert isinstance(self.char_encoder, char_lstm.CharacterLSTM)
+            char_ids = batch["char_ids"].to(self.device)
+            extra_content_annotations = self.char_encoder(char_ids, valid_token_mask)
+        elif self.pretrained_model is not None:
+            input_ids = batch["input_ids"].to(self.device)
+            words_from_tokens = batch["words_from_tokens"].to(self.output_device)
+            pretrained_attention_mask = batch["attention_mask"].to(self.device)
+
+            extra_kwargs = {}
+            if "token_type_ids" in batch:
+                extra_kwargs["token_type_ids"] = batch["token_type_ids"].to(self.device)
+            if "decoder_input_ids" in batch:
+                extra_kwargs["decoder_input_ids"] = batch["decoder_input_ids"].to(
+                    self.device
+                )
+                extra_kwargs["decoder_attention_mask"] = batch[
+                    "decoder_attention_mask"
+                ].to(self.device)
+
+            pretrained_out = self.pretrained_model(
+                input_ids, attention_mask=pretrained_attention_mask, **extra_kwargs
+            )
+            features = pretrained_out.last_hidden_state.to(self.output_device)
+            features = features[
+                torch.arange(features.shape[0])[:, None],
+                # Note that words_from_tokens uses index -100 for invalid positions
+                F.relu(words_from_tokens),
+            ]
+            features.masked_fill_(~valid_token_mask[:, :, None], 0)
+            if self.encoder is not None:
+                extra_content_annotations = self.project_pretrained(features)
+
+        if self.encoder is not None:
+            encoder_in = self.add_timing(
+                self.morpho_emb_dropout(extra_content_annotations)
+            )
+
+            annotations = self.encoder(encoder_in, valid_token_mask)
+            # Rearrange the annotations to ensure that the transition to
+            # fenceposts captures an even split between position and content.
+
+            annotations = torch.cat(
+                [
+                    annotations[..., 0::2],
+                    annotations[..., 1::2],
+                ],
+                -1,
+            )
+        else:
+            assert self.pretrained_model is not None
+            annotations = self.project_pretrained(features)
+
+        if self.f_tag is not None:
+            tag_scores = self.f_tag(annotations)
+        else:
+            tag_scores = None
+
+        fencepost_annotations = torch.cat(
+            [
+                annotations[:, :-1, : self.d_model // 2],
+                annotations[:, 1:, self.d_model // 2 :],
+            ],
+            -1,
+        )
+
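+        # A span (i, j) is scored from the difference of its fencepost vectors:
+        # after the slice below, span_features[b, i, j - 1] equals
+        # fencepost_annotations[b, j] - fencepost_annotations[b, i].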
+        # Note that the bias added to the final layer norm is useless because
+        # this subtraction gets rid of it
+        span_features = (
+            torch.unsqueeze(fencepost_annotations, 1)
+            - torch.unsqueeze(fencepost_annotations, 2)
+        )[:, :-1, 1:]
+        span_scores = self.f_label(span_features)
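+        # Prepend an implicit score of zero for the null label, so real labels
+        # occupy indices 1..max(label_vocab.values()).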
+        span_scores = torch.cat(
+            [span_scores.new_zeros(span_scores.shape[:-1] + (1,)), span_scores], -1
+        )
+        return span_scores, tag_scores
+
+    def compute_loss(self, batch):
+        span_scores, tag_scores = self.forward(batch)
+        span_labels = batch["span_labels"].to(span_scores.device)
+        span_loss = self.criterion(span_scores, span_labels)
+        # Divide by the total batch size, not by the subbatch size
+        span_loss = span_loss / batch["batch_size"]
+        if tag_scores is None:
+            return span_loss
+        else:
+            tag_labels = batch["tag_labels"].to(tag_scores.device)
+            tag_loss = self.tag_loss_scale * F.cross_entropy(
+                tag_scores.reshape((-1, tag_scores.shape[-1])),
+                tag_labels.reshape((-1,)),
+                reduction="sum",
+                ignore_index=-100,
+            )
+            tag_loss = tag_loss / batch["batch_num_tokens"]
+            return span_loss + tag_loss
+
+    def _parse_encoded(
+        self, examples, encoded, return_compressed=False, return_scores=False
+    ):
+        with torch.no_grad():
+            batch = self.pad_encoded(encoded)
+            span_scores, tag_scores = self.forward(batch)
+            if return_scores:
+                span_scores_np = span_scores.cpu().numpy()
+            else:
+                # Start/stop tokens don't count, so subtract 2
+                lengths = batch["valid_token_mask"].sum(-1) - 2
+                charts_np = self.decoder.charts_from_pytorch_scores_batched(
+                    span_scores, lengths.to(span_scores.device)
+                )
+            if tag_scores is not None:
+                tag_ids_np = tag_scores.argmax(-1).cpu().numpy()
+            else:
+                tag_ids_np = None
+
+        for i in range(len(encoded)):
+            example_len = len(examples[i].words)
+            if return_scores:
+                yield span_scores_np[i, :example_len, :example_len]
+            elif return_compressed:
+                output = self.decoder.compressed_output_from_chart(charts_np[i])
+                if tag_ids_np is not None:
+                    output = output.with_tags(tag_ids_np[i, 1 : example_len + 1])
+                yield output
+            else:
+                if tag_scores is None:
+                    leaves = examples[i].pos()
+                else:
+                    predicted_tags = [
+                        self.tag_from_index[tag_id]
+                        for tag_id in tag_ids_np[i, 1 : example_len + 1]
+                    ]
+                    leaves = [
+                        (word, predicted_tag)
+                        for predicted_tag, (word, gold_tag) in zip(
+                            predicted_tags, examples[i].pos()
+                        )
+                    ]
+                yield self.decoder.tree_from_chart(charts_np[i], leaves=leaves)
+
+    def parse(
+        self,
+        examples,
+        return_compressed=False,
+        return_scores=False,
+        subbatch_max_tokens=None,
+    ):
+        training = self.training
+        self.eval()
+        encoded = [self.encode(example) for example in examples]
+        if subbatch_max_tokens is not None:
+            res = subbatching.map(
+                self._parse_encoded,
+                examples,
+                encoded,
+                costs=self._get_lens(encoded),
+                max_cost=subbatch_max_tokens,
+                return_compressed=return_compressed,
+                return_scores=return_scores,
+            )
+        else:
+            res = self._parse_encoded(
+                examples,
+                encoded,
+                return_compressed=return_compressed,
+                return_scores=return_scores,
+            )
+            res = list(res)
+        self.train(training)
+        return res
diff --git a/benepar/partitioned_transformer.py b/benepar/partitioned_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a078b41b5c26e1ec283794d735e0e0e9bbe29201
--- /dev/null
+++ b/benepar/partitioned_transformer.py
@@ -0,0 +1,206 @@
+"""
+Transformer with partitioned content and position features.
+
+See section 3 of https://arxiv.org/pdf/1805.01052.pdf
+"""
+
+import copy
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
+    @staticmethod
+    def forward(ctx, input, p=0.5, train=False, inplace=False):
+        if p < 0 or p > 1:
+            raise ValueError(
+                "dropout probability has to be between 0 and 1, but got {}".format(p)
+            )
+
+        ctx.p = p
+        ctx.train = train
+        ctx.inplace = inplace
+
+        if ctx.inplace:
+            ctx.mark_dirty(input)
+            output = input
+        else:
+            output = input.clone()
+
+        if ctx.p > 0 and ctx.train:
+            ctx.noise = torch.empty(
+                (input.size(0), input.size(-1)),
+                dtype=input.dtype,
+                layout=input.layout,
+                device=input.device,
+            )
+            if ctx.p == 1:
+                ctx.noise.fill_(0)
+            else:
+                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
+            ctx.noise = ctx.noise[:, None, :]
+            output.mul_(ctx.noise)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.p > 0 and ctx.train:
+            return grad_output.mul(ctx.noise), None, None, None
+        else:
+            return grad_output, None, None, None
+
+
+class FeatureDropout(nn.Dropout):
+    """
+    Feature-level dropout: takes an input of size len x num_features and drops
+    each feature with probability p. A feature is dropped across the full
+    portion of the input that corresponds to a single batch element.
+    """
+
+    def forward(self, x):
+        if isinstance(x, tuple):
+            x_c, x_p = x
+            x_c = FeatureDropoutFunction.apply(x_c, self.p, self.training, self.inplace)
+            x_p = FeatureDropoutFunction.apply(x_p, self.p, self.training, self.inplace)
+            return x_c, x_p
+        else:
+            return FeatureDropoutFunction.apply(x, self.p, self.training, self.inplace)
+
+
+class PartitionedReLU(nn.ReLU):
+    def forward(self, x):
+        if isinstance(x, tuple):
+            x_c, x_p = x
+        else:
+            x_c, x_p = torch.chunk(x, 2, dim=-1)
+        return super().forward(x_c), super().forward(x_p)
+
+
+class PartitionedLinear(nn.Module):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__()
+        self.linear_c = nn.Linear(in_features // 2, out_features // 2, bias)
+        self.linear_p = nn.Linear(in_features // 2, out_features // 2, bias)
+
+    def forward(self, x):
+        if isinstance(x, tuple):
+            x_c, x_p = x
+        else:
+            x_c, x_p = torch.chunk(x, 2, dim=-1)
+
+        out_c = self.linear_c(x_c)
+        out_p = self.linear_p(x_p)
+        return out_c, out_p
+
+
+class PartitionedMultiHeadAttention(nn.Module):
+    def __init__(
+        self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02
+    ):
+        super().__init__()
+
+        self.w_qkv_c = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))
+        self.w_qkv_p = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))
+        self.w_o_c = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))
+        self.w_o_p = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))
+
+        bound = math.sqrt(3.0) * initializer_range
+        for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]:
+            nn.init.uniform_(param, -bound, bound)
+        self.scaling_factor = 1 / d_qkv ** 0.5
+
+        self.dropout = nn.Dropout(attention_dropout)
+
+    def forward(self, x, mask=None):
+        if isinstance(x, tuple):
+            x_c, x_p = x
+        else:
+            x_c, x_p = torch.chunk(x, 2, dim=-1)
+        qkv_c = torch.einsum("btf,hfca->bhtca", x_c, self.w_qkv_c)
+        qkv_p = torch.einsum("btf,hfca->bhtca", x_p, self.w_qkv_p)
+        q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)]
+        q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)]
+        q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor
+        k = torch.cat([k_c, k_p], dim=-1)
+        v = torch.cat([v_c, v_p], dim=-1)
+        dots = torch.einsum("bhqa,bhka->bhqk", q, k)
+        if mask is not None:
+            dots.data.masked_fill_(~mask[:, None, None, :], -float("inf"))
+        probs = F.softmax(dots, dim=-1)
+        probs = self.dropout(probs)
+        o = torch.einsum("bhqk,bhka->bhqa", probs, v)
+        o_c, o_p = torch.chunk(o, 2, dim=-1)
+        out_c = torch.einsum("bhta,haf->btf", o_c, self.w_o_c)
+        out_p = torch.einsum("bhta,haf->btf", o_p, self.w_o_p)
+        return out_c, out_p
+
+
+class PartitionedTransformerEncoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        n_head,
+        d_qkv,
+        d_ff,
+        ff_dropout=0.1,
+        residual_dropout=0.1,
+        attention_dropout=0.1,
+        activation=PartitionedReLU(),
+    ):
+        super().__init__()
+        self.self_attn = PartitionedMultiHeadAttention(
+            d_model, n_head, d_qkv, attention_dropout=attention_dropout
+        )
+        self.linear1 = PartitionedLinear(d_model, d_ff)
+        self.ff_dropout = FeatureDropout(ff_dropout)
+        self.linear2 = PartitionedLinear(d_ff, d_model)
+
+        self.norm_attn = nn.LayerNorm(d_model)
+        self.norm_ff = nn.LayerNorm(d_model)
+        self.residual_dropout_attn = FeatureDropout(residual_dropout)
+        self.residual_dropout_ff = FeatureDropout(residual_dropout)
+
+        self.activation = activation
+
+    def forward(self, x, mask=None):
+        residual = self.self_attn(x, mask=mask)
+        residual = torch.cat(residual, dim=-1)
+        residual = self.residual_dropout_attn(residual)
+        x = self.norm_attn(x + residual)
+        residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x))))
+        residual = torch.cat(residual, dim=-1)
+        residual = self.residual_dropout_ff(residual)
+        x = self.norm_ff(x + residual)
+        return x
+
+
+class PartitionedTransformerEncoder(nn.Module):
+    def __init__(self, encoder_layer, n_layers):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for _ in range(n_layers)]
+        )
+
+    def forward(self, x, mask=None):
+        for layer in self.layers:
+            x = layer(x, mask=mask)
+        return x
+
+
+class ConcatPositionalEncoding(nn.Module):
+    def __init__(self, d_model=256, max_len=512):
+        super().__init__()
+        self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model // 2))
+        nn.init.normal_(self.timing_table)
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, x):
+        timing = self.timing_table[None, : x.shape[1], :]
+        x, timing = torch.broadcast_tensors(x, timing)
+        out = torch.cat([x, timing], dim=-1)
+        out = self.norm(out)
+        return out
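+
+
+if __name__ == "__main__":
+    # Shape-check sketch (illustrative): a tiny encoder over random content
+    # features, confirming the content/position halves stay partitioned.
+    layer = PartitionedTransformerEncoderLayer(
+        d_model=64, n_head=4, d_qkv=32, d_ff=128
+    )
+    encoder = PartitionedTransformerEncoder(layer, n_layers=2)
+    content = torch.randn(2, 10, 32)  # batch of 2, 10 tokens, d_model // 2
+    x = ConcatPositionalEncoding(d_model=64, max_len=16)(content)
+    mask = torch.ones(2, 10, dtype=torch.bool)
+    print(encoder(x, mask=mask).shape)  # expected: torch.Size([2, 10, 64])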
diff --git a/benepar/ptb_unescape.py b/benepar/ptb_unescape.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9403d492257003c9145c494314b6e670da61fcc
--- /dev/null
+++ b/benepar/ptb_unescape.py
@@ -0,0 +1,83 @@
+PTB_UNESCAPE_MAPPING = {
+    "«": '"',
+    "»": '"',
+    "‘": "'",
+    "’": "'",
+    "“": '"',
+    "”": '"',
+    "„": '"',
+    "‹": "'",
+    "›": "'",
+    "\u2013": "--",  # en dash
+    "\u2014": "--",  # em dash
+}
+
+NO_SPACE_BEFORE = {"-RRB-", "-RCB-", "-RSB-", "''"} | set("%.,!?:;")
+NO_SPACE_AFTER = {"-LRB-", "-LCB-", "-LSB-", "``", "`"} | set("$#")
+NO_SPACE_BEFORE_TOKENS_ENGLISH = {"'", "'s", "'ll", "'re", "'d", "'m", "'ve"}
+PTB_DASH_ESCAPED = {"-RRB-", "-RCB-", "-RSB-", "-LRB-", "-LCB-", "-LSB-", "--"}
+
+
+def ptb_unescape(words):
+    cleaned_words = []
+    for word in words:
+        word = PTB_UNESCAPE_MAPPING.get(word, word)
+        # This un-escaping for / and * was not yet added for the
+        # parser version in https://arxiv.org/abs/1812.11760v1
+        # and related model releases (e.g. benepar_en2)
+        word = word.replace("\\/", "/").replace("\\*", "*")
+        # Mid-token punctuation occurs in biomedical text
+        word = word.replace("-LSB-", "[").replace("-RSB-", "]")
+        word = word.replace("-LRB-", "(").replace("-RRB-", ")")
+        word = word.replace("-LCB-", "{").replace("-RCB-", "}")
+        word = word.replace("``", '"').replace("`", "'").replace("''", '"')
+        cleaned_words.append(word)
+    return cleaned_words
+
+
+def guess_space_after_non_english(escaped_words):
+    sp_after = [True for _ in escaped_words]
+    for i, word in enumerate(escaped_words):
+        if i > 0 and (
+            (
+                word.startswith("-")
+                and not any(word.startswith(x) for x in PTB_DASH_ESCAPED)
+            )
+            or any(word.startswith(x) for x in NO_SPACE_BEFORE)
+            or word == "'"
+        ):
+            sp_after[i - 1] = False
+        if (
+            word.endswith("-") and not any(word.endswith(x) for x in PTB_DASH_ESCAPED)
+        ) or any(word.endswith(x) for x in NO_SPACE_AFTER):
+            sp_after[i] = False
+
+    return sp_after
+
+
+def guess_space_after(escaped_words, for_english=True):
+    if not for_english:
+        return guess_space_after_non_english(escaped_words)
+
+    sp_after = [True for _ in escaped_words]
+    for i, word in enumerate(escaped_words):
+        if word.lower() == "n't" and i > 0:
+            sp_after[i - 1] = False
+        elif word.lower() == "not" and i > 0 and escaped_words[i - 1].lower() == "can":
+            sp_after[i - 1] = False
+
+        if i > 0 and (
+            (
+                word.startswith("-")
+                and not any(word.startswith(x) for x in PTB_DASH_ESCAPED)
+            )
+            or any(word.startswith(x) for x in NO_SPACE_BEFORE)
+            or word.lower() in NO_SPACE_BEFORE_TOKENS_ENGLISH
+        ):
+            sp_after[i - 1] = False
+        if (
+            word.endswith("-") and not any(word.endswith(x) for x in PTB_DASH_ESCAPED)
+        ) or any(word.endswith(x) for x in NO_SPACE_AFTER):
+            sp_after[i] = False
+
+    return sp_after
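+
+
+if __name__ == "__main__":
+    # Illustrative round trip on PTB-escaped tokens.
+    words = ["-LRB-", "Do", "n't", "panic", "!", "-RRB-"]
+    print(ptb_unescape(words))       # ['(', 'Do', "n't", 'panic', '!', ')']
+    print(guess_space_after(words))  # [False, False, True, False, False, True]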
diff --git a/benepar/retokenization.py b/benepar/retokenization.py
new file mode 100644
index 0000000000000000000000000000000000000000..42f77188c5faf721f0587aeeac6da302ac3d8be3
--- /dev/null
+++ b/benepar/retokenization.py
@@ -0,0 +1,258 @@
+"""
+Converts from linguistically motivated word-based tokenization to subword
+tokenization used by pre-trained models.
+"""
+
+import numpy as np
+import torch
+import transformers
+
+
+def retokenize(
+    tokenizer,
+    words,
+    space_after,
+    return_attention_mask=True,
+    return_offsets_mapping=False,
+    return_tensors=None,
+    **kwargs
+):
+    """Re-tokenize into subwords.
+
+    Args:
+        tokenizer: An instance of transformers.PreTrainedTokenizerFast
+        words: List of words
+        space_after: A list of the same length as `words`, indicating whether
+            whitespace follows each word.
+        **kwargs: all remaining arguments are passed on to tokenizer.__call__
+
+    Returns:
+        The output of tokenizer.__call__, with one additional dictionary field:
+        - **words_from_tokens** -- List of the same length as `words`, where
+          each entry is the index of the *last* subword that overlaps the
+          corresponding word.
+    """
+    s = "".join([w + (" " if sp else "") for w, sp in zip(words, space_after)])
+    word_offset_starts = np.cumsum(
+        [0] + [len(w) + (1 if sp else 0) for w, sp in zip(words, space_after)]
+    )[:-1]
+    word_offset_ends = word_offset_starts + np.asarray([len(w) for w in words])
+
+    tokenized = tokenizer(
+        s,
+        return_attention_mask=return_attention_mask,
+        return_offsets_mapping=True,
+        return_tensors=return_tensors,
+        **kwargs
+    )
+    if return_offsets_mapping:
+        token_offset_mapping = tokenized["offset_mapping"]
+    else:
+        token_offset_mapping = tokenized.pop("offset_mapping")
+    if return_tensors is not None:
+        token_offset_mapping = np.asarray(token_offset_mapping)[0].tolist()
+
+    offset_mapping_iter = iter(
+        [
+            (i, (start, end))
+            for (i, (start, end)) in enumerate(token_offset_mapping)
+            if start != end
+        ]
+    )
+    token_idx, (token_start, token_end) = next(offset_mapping_iter)
+    words_from_tokens = [-100] * len(words)
+    for word_idx, (word_start, word_end) in enumerate(
+        zip(word_offset_starts, word_offset_ends)
+    ):
+        while token_end <= word_start:
+            token_idx, (token_start, token_end) = next(offset_mapping_iter)
+        if token_end > word_end:
+            words_from_tokens[word_idx] = token_idx
+        while token_end <= word_end:
+            words_from_tokens[word_idx] = token_idx
+            try:
+                token_idx, (token_start, token_end) = next(offset_mapping_iter)
+            except StopIteration:
+                assert word_idx == len(words) - 1
+                break
+    if return_tensors == "np":
+        words_from_tokens = np.asarray(words_from_tokens, dtype=int)
+    elif return_tensors == "pt":
+        words_from_tokens = torch.tensor(words_from_tokens, dtype=torch.long)
+    elif return_tensors == "tf":
+        raise NotImplementedError("Returning tf tensors is not implemented")
+    tokenized["words_from_tokens"] = words_from_tokens
+    return tokenized
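+
+# A minimal usage sketch (illustrative; assumes the "bert-base-uncased"
+# checkpoint is reachable via transformers):
+#
+#     >>> tok = transformers.AutoTokenizer.from_pretrained(
+#     ...     "bert-base-uncased", use_fast=True)
+#     >>> out = retokenize(tok, ["a", "high-tech", "deal"],
+#     ...                  space_after=[True, True, False])
+#     >>> len(out["words_from_tokens"]) == 3
+#     True
+#
+# "high-tech" splits into several subwords; words_from_tokens[1] is the index
+# of the last of them within out["input_ids"].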
+
+
+class Retokenizer:
+    def __init__(self, pretrained_model_name_or_path, retain_start_stop=False):
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path, use_fast=True
+        )
+        if not self.tokenizer.is_fast:
+            raise NotImplementedError(
+                "Converting from treebank tokenization to tokenization used by a "
+                "pre-trained model requires a 'fast' tokenizer, which appears to not "
+                "be available for this pre-trained model type."
+            )
+        self.retain_start_stop = retain_start_stop
+        self.is_t5 = "T5Tokenizer" in str(type(self.tokenizer))
+        self.is_gpt2 = "GPT2Tokenizer" in str(type(self.tokenizer))
+
+        if self.is_gpt2:
+            # The provided GPT-2 tokenizer does not specify a padding token by default
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        if self.retain_start_stop:
+            # When retain_start_stop is set, the next layer after the pre-trained model
+            # expects start and stop token embeddings. For BERT these can naturally be
+            # the feature vectors for CLS and SEP, but pre-trained models differ in the
+            # special tokens that they use. This code attempts to find special token
+            # positions for each pre-trained model.
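+            # For instance (an illustrative case): a BERT-style tokenizer maps
+            # [-100] to [cls_id, -100, sep_id], so num_prefix_tokens = 1 and
+            # num_suffix_tokens = 1 below, giving start_token_idx = 0 and
+            # stop_token_idx = -1.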
+            dummy_ids = self.tokenizer.build_inputs_with_special_tokens([-100])
+            if self.is_t5:
+                # For T5 we use the output from the decoder, which accepts inputs that
+                # are shifted relative to the encoder.
+                dummy_ids = [self.tokenizer.pad_token_id] + dummy_ids
+            if self.is_gpt2:
+                # For GPT-2, we append an eos token if special tokens are needed
+                dummy_ids = dummy_ids + [self.tokenizer.eos_token_id]
+            try:
+                input_idx = dummy_ids.index(-100)
+            except ValueError:
+                raise NotImplementedError(
+                    "Could not automatically infer how to extract start/stop tokens "
+                    "from this pre-trained model"
+                )
+            num_prefix_tokens = input_idx
+            num_suffix_tokens = len(dummy_ids) - input_idx - 1
+            self.start_token_idx = None
+            self.stop_token_idx = None
+            if num_prefix_tokens > 0:
+                self.start_token_idx = num_prefix_tokens - 1
+            if num_suffix_tokens > 0:
+                self.stop_token_idx = -num_suffix_tokens
+            if self.start_token_idx is None and num_suffix_tokens > 0:
+                self.start_token_idx = -1
+            if self.stop_token_idx is None and num_prefix_tokens > 0:
+                self.stop_token_idx = 0
+            if self.start_token_idx is None or self.stop_token_idx is None:
+                assert num_prefix_tokens == 0 and num_suffix_tokens == 0
+                raise NotImplementedError(
+                    "Could not automatically infer how to extract start/stop tokens "
+                    "from this pre-trained model because the associated tokenizer "
+                    "appears not to add any special start/stop/cls/sep/etc. tokens "
+                    "to the sequence."
+                )
+
+    def __call__(self, words, space_after, **kwargs):
+        example = retokenize(self.tokenizer, words, space_after, **kwargs)
+        if self.is_t5:
+            # decoder_input_ids (which are shifted wrt input_ids) will be created after
+            # padding, but we adjust words_from_tokens now, in anticipation.
+            if isinstance(example["words_from_tokens"], list):
+                example["words_from_tokens"] = [
+                    x + 1 for x in example["words_from_tokens"]
+                ]
+            else:
+                example["words_from_tokens"] += 1
+        if self.retain_start_stop:
+            num_tokens = len(example["input_ids"])
+            if self.is_t5:
+                num_tokens += 1
+            if self.is_gpt2:
+                num_tokens += 1
+                if kwargs.get("return_tensors") == "pt":
+                    example["input_ids"] = torch.cat(
+                        example["input_ids"],
+                        torch.tensor([self.tokenizer.eos_token_id]),
+                    )
+                    example["attention_mask"] = torch.cat(
+                        example["attention_mask"], torch.tensor([1])
+                    )
+                else:
+                    example["input_ids"].append(self.tokenizer.eos_token_id)
+                    example["attention_mask"].append(1)
+            if num_tokens > self.tokenizer.model_max_length:
+                raise ValueError(
+                    f"Sentence of length {num_tokens} (in sub-word tokens) exceeds the "
+                    f"maximum supported length of {self.tokenizer.model_max_length}"
+                )
+            start_token_idx = (
+                self.start_token_idx
+                if self.start_token_idx >= 0
+                else num_tokens + self.start_token_idx
+            )
+            stop_token_idx = (
+                self.stop_token_idx
+                if self.stop_token_idx >= 0
+                else num_tokens + self.stop_token_idx
+            )
+            if kwargs.get("return_tensors") == "pt":
+                example["words_from_tokens"] = torch.cat(
+                    [
+                        torch.tensor([start_token_idx]),
+                        example["words_from_tokens"],
+                        torch.tensor([stop_token_idx]),
+                    ]
+                )
+            else:
+                example["words_from_tokens"] = (
+                    [start_token_idx] + example["words_from_tokens"] + [stop_token_idx]
+                )
+        return example
+
+    def pad(self, encoded_inputs, return_tensors=None, **kwargs):
+        if return_tensors != "pt":
+            raise NotImplementedError("Only return_tensors='pt' is supported.")
+        res = self.tokenizer.pad(
+            [
+                {k: v for k, v in example.items() if k != "words_from_tokens"}
+                for example in encoded_inputs
+            ],
+            return_tensors=return_tensors,
+            **kwargs
+        )
+        if self.tokenizer.padding_side == "right":
+            res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence(
+                [
+                    torch.tensor(example["words_from_tokens"])
+                    for example in encoded_inputs
+                ],
+                batch_first=True,
+                padding_value=-100,
+            )
+        else:
+            # XLNet adds padding tokens on the left of the sequence, so
+            # words_from_tokens must be adjusted to skip the added padding tokens.
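+            # For example (illustrative): if two padding tokens were prepended
+            # to a sequence, every index in its words_from_tokens shifts by 2.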
+            assert self.tokenizer.padding_side == "left"
+            res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence(
+                [
+                    torch.tensor(example["words_from_tokens"])
+                    + (res["input_ids"].shape[-1] - len(example["input_ids"]))
+                    for example in encoded_inputs
+                ],
+                batch_first=True,
+                padding_value=-100,
+            )
+
+        if self.is_t5:
+            res["decoder_input_ids"] = torch.cat(
+                [
+                    torch.full_like(
+                        res["input_ids"][:, :1], self.tokenizer.pad_token_id
+                    ),
+                    res["input_ids"],
+                ],
+                1,
+            )
+            res["decoder_attention_mask"] = torch.cat(
+                [
+                    torch.ones_like(res["attention_mask"][:, :1]),
+                    res["attention_mask"],
+                ],
+                1,
+            )
+        res["valid_token_mask"] = res["words_from_tokens"] != -100
+        return res
diff --git a/benepar/spacy_plugin.py b/benepar/spacy_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..3923cc46064c5f098a2276a4312c09bc65b8891a
--- /dev/null
+++ b/benepar/spacy_plugin.py
@@ -0,0 +1,13 @@
+__all__ = ["BeneparComponent", "NonConstituentException"]
+
+import warnings
+
+from .integrations.spacy_plugin import BeneparComponent, NonConstituentException
+
+warnings.warn(
+    "BeneparComponent and NonConstituentException have been moved to the benepar "
+    "module. Use `from benepar import BeneparComponent, NonConstituentException` "
+    "instead of benepar.spacy_plugin. The benepar.spacy_plugin namespace is deprecated "
+    "and will be removed in a future version.",
+    FutureWarning,
+)
diff --git a/benepar/subbatching.py b/benepar/subbatching.py
new file mode 100644
index 0000000000000000000000000000000000000000..53bed87ce8743034a670e358acc947709c57ee3d
--- /dev/null
+++ b/benepar/subbatching.py
@@ -0,0 +1,62 @@
+"""
+Utilities for splitting batches of examples into smaller sub-batches.
+
+This is useful during training when the batch size is too large to fit in GPU
+memory, in which case gradients must be accumulated across multiple
+sub-batches. It is also useful for batching examples during evaluation. Unlike
+a naive approach, this code groups examples of similar lengths to reduce the
+amount of wasted computation due to padding.
+"""
+
+import numpy as np
+
+
+def split(*data, costs, max_cost):
+    """Splits a batch of input items into sub-batches.
+
+    Args:
+        *data: One or more lists of input items, all of the same length
+        costs: A list of costs for each item
+        max_cost: Maximum total cost for each sub-batch
+
+    Yields:
+        (subbatch_item_ids, *subbatch_data) tuples, where subbatch_item_ids
+        gives the indices of the selected items in the original input lists.
+    """
+    costs = np.asarray(costs, dtype=int)
+    costs_argsort = np.argsort(costs).tolist()
+
+    subbatch_size = 1
+    while costs_argsort:
+        if subbatch_size == len(costs_argsort) or (
+            subbatch_size * costs[costs_argsort[subbatch_size]] > max_cost
+        ):
+            subbatch_item_ids = costs_argsort[:subbatch_size]
+            subbatch_data = [[items[i] for i in subbatch_item_ids] for items in data]
+            yield (subbatch_item_ids,) + tuple(subbatch_data)
+            costs_argsort = costs_argsort[subbatch_size:]
+            subbatch_size = 1
+        else:
+            subbatch_size += 1
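+
+# A small illustrative example: with costs [5, 2, 9, 3] and max_cost=10, the
+# items are visited in increasing-cost order, so split() first yields item ids
+# [1, 3, 0] (costs 2, 3, 5) and then [2] (cost 9):
+#
+#     >>> [ids for ids, _ in split(list("abcd"), costs=[5, 2, 9, 3],
+#     ...                          max_cost=10)]
+#     [[1, 3, 0], [2]]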
+
+
+def map(func, *data, costs, max_cost, **common_kwargs):
+    """Maps a function over subbatches of input items.
+
+    Args:
+        func: Function to map over the data
+        *data: One or more lists of input items, all of the same length.
+        costs: A list of costs for each item
+        max_cost: Maximum total cost for each sub-batch
+        **common_kwargs: Keyword arguments to pass to all calls of func
+
+    Returns:
+        A list of outputs from calling func(*subbatch_data, **common_kwargs)
+        subbatch, and then rearranging the outputs from func into the original
+        item order.
+    """
+    res = [None] * len(data[0])
+    for item_ids, *subbatch_items in split(*data, costs=costs, max_cost=max_cost):
+        subbatch_out = func(*subbatch_items, **common_kwargs)
+        for item_id, item_out in zip(item_ids, subbatch_out):
+            res[item_id] = item_out
+    return res
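+
+# A small illustrative example: squaring numbers in sub-batches, with each
+# item's cost equal to its own value (note that `map` here is this module's
+# function, not the builtin). Outputs come back in the original item order:
+#
+#     >>> map(lambda xs: [x * x for x in xs], [5, 2, 9, 3],
+#     ...     costs=[5, 2, 9, 3], max_cost=10)
+#     [25, 4, 81, 9]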
diff --git a/parse.py b/parse.py
index 9261a3bd3bb07cfd7490694be920d1da81eaebee..855f3f75234744bb8cfbfd8855b6e91dc9c2d64d 100644
--- a/parse.py
+++ b/parse.py
@@ -2,12 +2,10 @@ import re
 import sys
 import benepar
 from huggingface_hub import hf_hub_download
-
-model_path = "ParserModels/ENHG/new-convbert-german-europeana0_dev=83.03.pt"
- hf_hub_download(repo_id=model_path, filename='german-delex-parser_dev=83.10.pt')
-parser = benepar.Parser(model_path)
     
 def parse(words):
+    # Fetch the parser checkpoint from the Hugging Face Hub (cached locally
+    # after the first download) and load it with benepar. For repeated calls,
+    # consider caching the Parser instance instead of reloading it each time.
+    model_path = hf_hub_download(
+        repo_id="nielklug/enhg_parser",
+        filename='new-convbert-german-europeana0_dev=83.03.pt',
+    )
+    parser = benepar.Parser(model_path)
     words = [word.replace('(','-LRB-').replace(')','-RRB-') for word in words]
     input_sentence = benepar.InputSentence(words=words)
     tree = parser.parse(input_sentence)
@@ -17,11 +15,3 @@ def parse(words):
     tree = re.sub(r' \(', '(', tree)
     return tree
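+
+# A minimal usage sketch (illustrative): parse() takes a pre-tokenized list of
+# words and returns a bracketed parse-tree string, e.g.:
+#
+#     tree = parse(["Das", "ist", "ein", "Satz", "."])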
             
-            
-with open(sys.argv[1]) as file:
-    for line in file:
-        line = re.sub(r'(\S)([.,;:?!)"])', r'\1 \2', line.strip())
-        line = re.sub(r'(["(])(\S)', r'\1 \2', line)
-        words = line.split()
-        tree = parse(words)
-        print(tree)