"""Gen cmat for de/en text."""
# pylint: disable=invalid-name, too-many-branches
from typing import List, Optional

import more_itertools as mit
import numpy as np

# from logzero import logger
from loguru import logger
from tqdm import tqdm

# from model_pool import load_model_s
# from hf_model_s_cpu import model_s  # load_model_s directly
from st_mlbee.load_model_s import load_model_s

# from st_mlbee.cos_matrix2 import cos_matrix2
from .cos_matrix2 import cos_matrix2

# Earlier loading snippet, kept only as an inert string for reference.
_ = """
try:
    model_s = load_model_s()
except Exception as exc:
    logger.erorr(exc)
    raise
"""

# Load the default model-s (mikeee/model_s_512) once, at import time.
# Alternative: load_model_s("model_s_512v2")  # model-s mikeee/model-s-512v2
# A loading failure is logged and re-raised, so importing this module fails.
try:
    model_s = load_model_s()
except Exception as exc:
    logger.error(exc)
    raise


def gen_cmat(
    text1: List[str],
    text2: List[str],
    bsize: int = 32,  # default batch_size of model.encode
    model=None,
) -> np.ndarray:
    """Gen corr matrix for texts.

    Both texts are encoded in batches of ``bsize`` via ``model.encode`` and the
    collected vectors are handed to ``cos_matrix2``.

    Args:
    ----
    text1: list of strings, typically '''...'''.splitlines(); a bare string
        (e.g. 'this is a test') is treated as a one-element list
    text2: same as text1 (e.g. 'another test')
    bsize: batch size, default 32; non-positive values fall back to 32
    model: for encoding list of strings (must provide ``encode``),
        default model-s of mikeee/model_s_512

    Returns:
    -------
    numpy array of cmat, computed as cos_matrix2(vec2, vec1)

    Raises:
    ------
    Any exception raised by ``model.encode`` or ``cos_matrix2`` is logged
    and re-raised unchanged.

    """
    if model is None:
        model = model_s
    bsize = int(bsize)
    if bsize <= 0:
        bsize = 32

    # Wrap bare strings so each is encoded as one text, not per character.
    if isinstance(text1, str):
        text1 = [text1]
    if isinstance(text2, str):
        text2 = [text2]

    vec1 = []
    vec2 = []

    # One progress tick per batch: ceil(len / bsize) batches for each text.
    tot = sum(
        len(text) // bsize + bool(len(text) % bsize) for text in (text1, text2)
    )
    with tqdm(total=tot) as pbar:
        # text1 is fully encoded before text2, sharing a single progress bar.
        for text, vecs in ((text1, vec1), (text2, vec2)):
            for chunk in mit.chunked(text, bsize):
                try:
                    vec = model.encode(chunk)
                except Exception as exc:
                    logger.error(exc)
                    raise
                vecs.extend(vec)
                pbar.update()

    try:
        # note the order vec2, vec1
        cmat = cos_matrix2(np.array(vec2), np.array(vec1))
    except Exception as exc:
        logger.exception(exc)
        raise

    return cmat