import functools
import html
from typing import Dict

from PIL import ImageFont

# Markup for one token: the word, with its tag on a second line 2em below.
# Placeholders: {x}, {y} (position), {text} (the word), {tag} (the tag text).
TPL_DEP_WORDS = """
<text class="displacy-token" fill="currentColor" text-anchor="start" y="{y}">
    <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
    <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""

# Outer <svg> wrapper around all word and arc markup.
# Placeholders: {lang}, {id}, {width}, {height}, {dir}, {color}, {bg},
# {font} and {content} (the concatenated inner markup).
TPL_DEP_SVG = """
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
"""

# One dependency arc: the curved path, its label laid along that path via
# <textPath> (linked to the path through the shared "arrow-{id}-{i}" id), and
# the arrowhead. Stroke and fill colors are hard-coded to red.
# Placeholders: {id}, {i}, {stroke}, {arc}, {label_side}, {label}, {head}.
TPL_DEP_ARCS = """
<g class="displacy-arrow">
    <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="red"/>
    <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
        <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="red" text-anchor="middle">{label}</textPath>
    </text>
    <path class="displacy-arrowhead" d="{head}" fill="red"/>
</g>
"""


@functools.lru_cache(maxsize=32)
def _load_font(font_name, font_size):
    """Load a TrueType font once per (name, size) instead of once per call."""
    return ImageFont.truetype(font_name, font_size)


def get_pil_text_size(text, font_size, font_name):
    """Measure the pixel size of *text* rendered in a TrueType font.

    text (str): Text to measure.
    font_size (int): Font size passed to ``ImageFont.truetype``.
    font_name (str): Font file name or path passed to ``ImageFont.truetype``.
    RETURNS (tuple): ``(width, height)`` in pixels.
    """
    font = _load_font(font_name, font_size)
    # FreeTypeFont.getsize was deprecated in Pillow 9.2 and removed in 10.0;
    # keep using it where it still exists so older installs are unaffected.
    if hasattr(font, "getsize"):
        return font.getsize(text)
    # getbbox returns (left, top, right, bottom). getsize reported
    # (size + offset) per axis, which is exactly the right/bottom edge.
    _, _, right, bottom = font.getbbox(text)
    return right, bottom


def render_arrow(
        label: str, start: int, end: int, direction: str, i: int
) -> str:
    """Render one dependency arc (path, label and arrowhead) as SVG markup.

    label (str): Dependency label drawn along the arc; XML-escaped here.
    start (int): X-coordinate of the start word; the arc begins 10px to its right.
    end (int): X-coordinate of the end word; the arc ends 10px to its right.
    direction (str): 'left' draws the arrowhead at the start point; any other
        value draws it at the end point. The value 'rtl' additionally places
        the label on the path's 'right' side.
    i (int): Unique ID, typically the arrow index; used in the element id
        that links the label's <textPath> to the arc path.
    RETURNS (str): Rendered SVG markup.
    """
    x_start = start + 10
    x_end = end + 10
    # Endpoints sit at y=50; both Bezier control points are at y=5.
    arc = get_arc(x_start, 50, 5, x_end)
    arrowhead = get_arrowhead(direction, x_start, 50, x_end)
    label_side = "right" if direction == "rtl" else "left"
    return TPL_DEP_ARCS.format(
        id=0,
        i=i,
        stroke=2,
        head=arrowhead,
        label=html.escape(label),
        label_side=label_side,
        arc=arc,
    )


def get_arc(x_start: int, y: int, y_curve: int, x_end: int) -> str:
    """Build the SVG path data for a dependency arc.

    The path starts at (x_start, y) and ends at (x_end, y). It is a cubic
    Bezier curve whose two control points are (x_start, y_curve) and
    (x_end, y_curve).

    x_start (int): X-coordinate of the arc's start point.
    y (int): Y-coordinate shared by the start and end points.
    y_curve (int): Y-coordinate of both Bezier control points.
    x_end (int): X-coordinate of the arc's end point.
    RETURNS (str): Value for the path's 'd' attribute.
    """
    move_to = f"M{x_start},{y}"
    curve_to = f"C{x_start},{y_curve} {x_end},{y_curve} {x_end},{y}"
    return f"{move_to} {curve_to}"


def get_arrowhead(direction: str, x: int, y: int, end: int) -> str:
    """Build the SVG path data for a triangular arrowhead.

    The tip is drawn 2px below ``y`` at ``x`` when ``direction`` is 'left',
    and at ``end`` for any other value. The two base corners lie 6px above ``y``.

    direction (str): 'left' to point at the start point; otherwise the end point.
    x (int): X-coordinate of the arc's start point.
    y (int): Y-coordinate of the arc's start and end points.
    end (int): X-coordinate of the arc's end point.
    RETURNS (str): Value for the arrowhead path's 'd' attribute.
    """
    head_width = 6
    if direction != "left":
        tip, wing_a, wing_b = end, end + head_width - 2, end - head_width + 2
    else:
        tip, wing_a, wing_b = x, x - head_width + 2, x + head_width - 2
    tip_y = y + 2
    base_y = y - head_width
    return "M{},{} L{},{} {},{}".format(tip, tip_y, wing_a, base_y, wing_b, base_y)


def render_sentence_custom(unmatched_list: Dict, nlp):
    """Render a sentence with one dependency arc as a standalone SVG string.

    unmatched_list (Dict): Must provide the keys "sentence" (text to
        tokenize), "cur_word_index" and "target_word_index" (token indices
        of the arc's two endpoints) and "dep" (the arc label).
    nlp: Callable applied to the sentence; each item it yields is rendered
        via ``str()``. Presumably a spaCy ``Language`` pipeline; verify
        against callers.
    RETURNS (str): The SVG markup.
    RAISES IndexError: If no token index lies between the two endpoint indices.
    """
    doc = nlp(unmatched_list["sentence"])
    cur_index = unmatched_list["cur_word_index"]
    target_index = unmatched_list["target_word_index"]
    min_index = min(cur_index, target_index)
    max_index = max(cur_index, target_index)
    # "left" puts the arrowhead on the leftmost endpoint (the current word);
    # "rtl" puts it on the rightmost one (see get_arrowhead/render_arrow).
    direction_current = "left" if cur_index < target_index else "rtl"

    x_cursor = 10
    svg_words = []
    arc_xs = []  # x-coordinates of the tokens spanned by the arc
    for index, token in enumerate(doc):
        word = str(token) + " "
        # Measure the raw text; escape only when embedding it in markup.
        word_width = get_pil_text_size(word, 16, "arial.ttf")[0]
        svg_words.append(
            TPL_DEP_WORDS.format(text=html.escape(word), tag="", x=x_cursor, y=70)
        )
        if min_index <= index <= max_index:
            arc_xs.append(x_cursor)
            # Extra horizontal room between tokens inside the arc (not added
            # after the last two tokens of the span).
            if index < max_index - 1:
                x_cursor += 50
        x_cursor += word_width + 4

    if not arc_xs:
        raise IndexError(
            "cur_word_index/target_word_index do not point into the sentence"
        )

    # Only one arc is drawn, so its arrow index is 0.
    arc_svg = render_arrow(
        unmatched_list["dep"], arc_xs[0], arc_xs[-1], direction_current, 0
    )
    content = "".join(svg_words) + arc_svg

    return TPL_DEP_SVG.format(
        id=0,
        # Grow past the default width so long sentences are not clipped.
        width=max(1200, x_cursor),
        height=75,
        color="#000000",
        bg="#ffffff",
        font="Arial",
        content=content,
        dir="ltr",
        lang="en",
    )