import ast
import base64
import duckdb
import json
import re
import textwrap
from ulid import ULID

# Path of the query history file: Q.save() appends one JSON record per line
# to it by default, and Q.from_history() queries it.
HISTORY_FILE = "history.json"
# Row cap that Q.run_duckdb() applies via a LIMIT clause; a falsy value
# (0 / None) makes run_duckdb() execute the query without a cap.
MAX_ROWS = 10000

class SQLError(Exception):
    """Error raised when a SQL query did not produce a usable result."""

class NotFoundError(Exception):
    """Error raised when a requested query cannot be found."""

class Q(str):
    """A SQL query string that remembers how it was built and how it ran.

    The string value itself is the rendered SQL. Each instance additionally
    carries:

    * ``id`` -- a fresh ULID rendered as a string,
    * ``alias`` -- optional human-friendly name (``None`` when not given),
    * ``template`` -- the dedented, stripped template text,
    * ``kwargs`` -- the values substituted into the template,
    * ``definitions`` -- ``kwargs`` rendered as ``key = repr(value)`` lines,
    * ``rows`` / ``cols`` -- result shape of the last successful ``run()``,
    * ``start`` / ``end`` -- ULIDs created when ``run()`` started/finished,
    * ``source_id`` -- id of the history record the query was loaded from.
    """

    # Keywords that make ``is_safe`` report a template as unsafe.
    UNSAFE = ["CREATE", "DELETE", "DROP", "INSERT", "UPDATE"]
    rows = None

    def __new__(cls, template: str, **kwargs):
        """Create a new Q-string.

        Args:
            template: SQL text, optionally containing ``str.format``
                placeholders. It is dedented and stripped before use.
            **kwargs: Values for the placeholders. The special key ``alias``
                names the query and is never kept in ``instance.kwargs``.

        Returns:
            The new instance; its string value is the rendered template, or
            the unrendered template when it cannot be formatted.
        """
        _template = textwrap.dedent(template).strip()
        try:
            rendered = _template.format(**kwargs)
        except (KeyError, IndexError, ValueError):
            # Missing names, positional fields such as "{}" and stray braces
            # all mean "not formattable with these kwargs": keep the text
            # verbatim, e.g. already-rendered SQL containing literal braces.
            rendered = _template
        instance = str.__new__(cls, rendered)
        instance.id = str(ULID())
        # Always remove "alias" so a None/empty alias does not leak into
        # kwargs and definitions; falsy aliases are normalised to None.
        instance.alias = kwargs.pop("alias", None) or None
        instance.template = _template
        instance.kwargs = kwargs
        instance.definitions = "\n".join(f"{k} = {v!r}" for k, v in kwargs.items())

        for attr in ("rows", "cols", "source_id", "start", "end"):
            setattr(instance, attr, None)
        return instance

    def __repr__(self):
        """Neat repr for inspecting Q objects: one ``name: value`` per attribute.

        Multi-line values are placed on their own lines, indented two spaces.
        """
        strings = []
        for k, v in self.__dict__.items():
            # Work on the string form: textwrap.indent only accepts str.
            text = str(v)
            if "\n" in text:
                text = "\n" + textwrap.indent(text, "  ")
            strings.append(f"{k}: {text}")
        return "\n".join(strings)

    def run(self, sql_engine=None, save=False, _raise=False):
        """Execute the query and record its shape and start/end markers.

        Args:
            sql_engine: When ``None`` the query runs through ``run_duckdb``;
                otherwise it is handed to ``self.run_sql(sql_engine)``.
            save: When true, append this query to the history file after the
                run, whether it succeeded or not.
            _raise: When true, re-raise execution errors instead of returning
                their message.

        Returns:
            The result object of the executed query, or the error message as
            a string when execution failed and ``_raise`` is false.
        """
        self.start = ULID()
        # Forget the previous run's shape so a failure cannot be mistaken for
        # a success by callers that inspect ``rows`` (e.g. ``df``).
        self.rows = self.cols = None
        try:
            if sql_engine is None:
                res = self.run_duckdb()
            else:
                # NOTE(review): run_sql is not defined in this class as
                # visible here -- confirm it is provided elsewhere
                # (subclass/mixin); otherwise this branch always fails.
                res = self.run_sql(sql_engine)
            self.rows, self.cols = res.shape
            return res
        except Exception as e:
            if _raise:
                raise
            return str(e)
        finally:
            self.end = ULID()
            if save:
                self.save()

    def run_duckdb(self):
        """Run the query with DuckDB, capped at ``MAX_ROWS`` rows when set.

        Returns:
            Whatever ``duckdb.sql`` returns for the (possibly wrapped) query.
        """
        sql = str(self)
        if MAX_ROWS:
            # A trailing ";" is invalid inside the CTE, and the newlines keep
            # a trailing "--" comment from swallowing the closing parenthesis.
            inner = sql.rstrip("; \t\r\n")
            return duckdb.sql(f"WITH x AS (\n{inner}\n) SELECT * FROM x LIMIT {MAX_ROWS}")
        return duckdb.sql(sql)

    def df(self, sql_engine=None, save=False, _raise=False):
        """Run the query and return the result converted with ``.df()``.

        Arguments are forwarded to ``run``.

        Returns:
            ``res.df()`` with the query attached as its ``q`` attribute, or
            ``None`` when the run failed or produced no rows.
        """
        res = self.run(sql_engine=sql_engine, save=save, _raise=_raise)
        if not self.rows:
            return None
        result_df = res.df()
        result_df.q = self
        return result_df

    def save(self, file=HISTORY_FILE):
        """Append this query's JSON record as one line to ``file``."""
        with open(file, "a", encoding="utf-8") as f:
            f.write(self.json)
            f.write("\n")

    @staticmethod
    def _json_default(obj):
        """Serialise values ``json`` cannot handle natively.

        Objects exposing a ``datetime`` attribute (the ULIDs stored in
        ``start``/``end``) become millisecond-precision timestamps; anything
        else falls back to ``str`` so one odd kwarg cannot break saving.
        """
        stamp = getattr(obj, "datetime", None)
        if stamp is not None:
            # Spelled-out form of "%F %T.%f"; those shorthands are not
            # available on every platform.
            return stamp.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        return str(obj)

    @property
    def json(self):
        """JSON text holding the rendered query (``q``) and all attributes."""
        serialized = {"id": self.id, "q": str(self)}
        serialized.update(self.__dict__)
        return json.dumps(serialized, default=self._json_default)

    @property
    def is_safe(self):
        """Whether the template is free of every keyword listed in ``UNSAFE``.

        Keywords are matched case-insensitively as whole words, so identifiers
        such as ``created_at`` or ``updated_at`` are not flagged.

        NOTE(review): only ``template`` is inspected, not the values that are
        substituted into it -- confirm that is sufficient for callers.
        """
        if not self.UNSAFE:
            return True
        pattern = r"\b(?:" + "|".join(re.escape(cmd) for cmd in self.UNSAFE) + r")\b"
        return re.search(pattern, self.template, flags=re.IGNORECASE) is None

    @classmethod
    def from_dict(cls, query_dict: dict):
        """Build a query from a dict holding the template under ``"q"``.

        All remaining items are passed on as keyword arguments. The caller's
        dict is left untouched.

        Raises:
            KeyError: If ``query_dict`` has no ``"q"`` entry.
        """
        params = dict(query_dict)
        q = params.pop("q")
        return cls(q, **params)

    @classmethod
    def from_template_and_definitions(cls, template: str, definitions: str, alias: str|None = None):
        """Build a query from a template and a ``key=value``-per-line text.

        ``definitions`` is parsed with ``parse_definitions``; the original
        text (comments included) is kept as ``instance.definitions``.
        """
        query_dict = {"q": template, "alias": alias}
        query_dict.update(parse_definitions(definitions))
        instance = cls.from_dict(query_dict)
        instance.definitions = definitions
        return instance

    @classmethod
    def from_history(cls, query_id=None, alias=None):
        """Re-create a query stored in the history file, by id or by alias.

        Args:
            query_id: ``id`` of the stored query, or ``None`` to ignore.
            alias: ``alias`` of the stored query, or ``None`` to ignore.

        Returns:
            A new instance built from the stored template and kwargs, with
            ``source_id`` set to the stored record's id.

        Raises:
            NotFoundError: If no stored query matches.
            SQLError: If the history lookup itself did not succeed.
        """
        def quote(value):
            # SQL string literal with embedded single quotes doubled.
            return "'" + str(value).replace("'", "''") + "'"

        # Only filter on the values actually supplied, so an omitted argument
        # never matches a record whose id/alias is the text "None".
        conditions = []
        if query_id is not None:
            conditions.append(f"id={quote(query_id)}")
        if alias is not None:
            conditions.append(f"alias={quote(alias)}")
        if not conditions:
            raise NotFoundError(f"id '{query_id}' / alias '{alias}' not found")

        search_query = Q(
            """
            SELECT id, template, kwargs
            FROM {history_file}
            WHERE {condition}
            LIMIT 1
            """,
            history_file=quote(HISTORY_FILE),
            condition=" OR ".join(conditions),
        )
        query = search_query.run()
        if search_query.rows == 1:
            source_id, template, kwargs = query.fetchall()[0]
            # Drop None entries; tolerate a record whose kwargs is null.
            kwargs = {k: v for k, v in (kwargs or {}).items() if v is not None}
            instance = cls(template, **kwargs)
            instance.source_id = source_id
            return instance
        if search_query.rows == 0:
            raise NotFoundError(f"id '{query_id}' / alias '{alias}' not found")
        # rows is still None: run() failed and returned the error message.
        raise SQLError(query)

    @property
    def base64(self):
        """The rendered query, base64-encoded and returned as text."""
        return base64.b64encode(self.encode()).decode()

    @classmethod
    def from_base64(cls, b64):
        """Initializing from base64-encoded URL paths."""
        return cls(base64.b64decode(b64).decode())


def parse_definitions(definitions) -> dict:
    """Parse a string literal of "key=value" pairs, one per line, into kwargs.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. Each line is split on its first ``=``; all whitespace is removed
    from the key, while the value is only trimmed at both ends, so string
    values keep their inner spaces.

    Args:
        definitions: Text with one ``key=value`` pair per line, where each
            value is a Python literal.

    Returns:
        Dict mapping each key to its value as evaluated by
        ``ast.literal_eval``.

    Raises:
        ValueError, SyntaxError: If a value is not a valid Python literal
            (propagated from ``ast.literal_eval``).
    """
    kwargs = {}
    for raw_line in definitions.split("\n"):
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            continue
        kwargs[re.sub(r"\s+", "", key)] = ast.literal_eval(value.strip())
    return kwargs


# Example 1: a template plus a definitions text ("key=value" per line);
# the placeholders {x} and {colname} are filled from the parsed definitions.
EX1 = Q.from_template_and_definitions(
    template="SELECT {x} AS {colname}",
    definitions="\n".join([
        "# Define variables: one '=' per line",
        "x=42",
        "colname='answer'",
    ]),
    alias="example1",
)

# Example 2: a query whose FROM target is filled in from the `url` keyword
# argument (a CSV file at a remote address).
EX2 = Q(
    """
    SELECT
        Symbol,
        Number,
        Mass,
        Abundance
    FROM '{url}'
    """,
    url="https://raw.githubusercontent.com/ekwan/cctk/master/cctk/data/isotopes.csv",
    alias="example2",
)

# Example 3: a query over the history file, newest id first.
# NOTE(review): the path is hard-coded here rather than taken from
# HISTORY_FILE -- confirm whether the two are meant to stay in sync.
EX3 = Q(
    """
    SELECT *
    FROM 'history.json'
    ORDER BY id DESC
    """,
    alias="example3",
)

# Example 4: presumably a deliberately failing query (see its alias) for
# exercising error handling -- confirm intent.
EX4 = Q("SELECT nothing", alias="bad_example")