import re
import sys
import os
import trace
import traceback
from typing import final
import numpy as np
from collections import defaultdict
import pandas as pd
import time

# 如果使用 spaCy 进行 NLP 处理
from regex import R
import spacy

# 如果使用某种情感分析工具,比如 Hugging Face 的模型
from transformers import pipeline

# 还需要导入 pickle 模块(如果你在代码的其他部分使用了它来处理序列化/反序列化)
import pickle
from gensim.models import KeyedVectors
import akshare as ak

from gensim.models import Word2Vec
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertTokenizer, BertForSequenceClassification




sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from us_stock import *




# Opt into the GPU *before* loading any pipeline: spaCy only allocates a
# pipeline's weights on the GPU if prefer_gpu()/require_gpu() ran before
# spacy.load(). Calling it afterwards leaves an already-loaded model on the
# CPU. Switch to spacy.require_gpu() to fail hard when no GPU is available.
_spacy_gpu_enabled = spacy.prefer_gpu()

# Load the English medium pipeline, downloading it on first use.
try:
    nlp = spacy.load("en_core_web_md")
except OSError:
    # spacy.load raises OSError when the model package is not installed.
    print("Downloading model 'en_core_web_md'...")
    from spacy.cli import download
    download("en_core_web_md")
    nlp = spacy.load("en_core_web_md")

# Report whether spaCy is running on the GPU.
print("Is NPL GPU used Preprocessing.py:", _spacy_gpu_enabled)


# 使用合适的模型和tokenizer
# tokenizer_one = AutoTokenizer.from_pretrained("ProsusAI/finbert")
# sa_model_one = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert")


# tokenizer_two = BertTokenizer.from_pretrained('yiyanghkust/finbert-tone')
# sa_model_two = BertForSequenceClassification.from_pretrained('yiyanghkust/finbert-tone',num_labels=3)

import multiprocessing

# Lock guarding lazy model construction in get_tokenizer_and_model.
# NOTE(review): this is a multiprocessing.Lock, so with the "fork" start method
# child processes presumably share it (serializing model loads across
# processes) while the _models cache stays per-process. Confirm that is intended.
_tokenizer_lock = multiprocessing.Lock()
# Cache mapping model_type -> (tokenizer, model), filled by get_tokenizer_and_model.
_models = {}

def get_tokenizer_and_model(model_type="one"):
    """Lazily build and cache the (tokenizer, model) pair for a FinBERT variant.

    Args:
        model_type: "one" selects ProsusAI/finbert; any other value selects
            yiyanghkust/finbert-tone (loaded with 3 labels).

    Returns:
        tuple: ``(tokenizer, model)``, cached in ``_models`` under ``model_type``
        so later calls reuse the same objects.
    """
    if model_type in _models:
        return _models[model_type]

    with _tokenizer_lock:
        # Re-check under the lock: another caller may have filled the cache
        # while this one was waiting.
        if model_type not in _models:
            if model_type == "one":
                repo = "ProsusAI/finbert"
                tokenizer = AutoTokenizer.from_pretrained(repo)
                classifier = AutoModelForSequenceClassification.from_pretrained(repo)
            else:
                repo = 'yiyanghkust/finbert-tone'
                tokenizer = BertTokenizer.from_pretrained(repo)
                classifier = BertForSequenceClassification.from_pretrained(repo, num_labels=3)
            _models[model_type] = (tokenizer, classifier)

    return _models[model_type]

# US index quotes fetched from Sina through akshare at import time.
# The symbols presumably correspond to the S&P 500 (.INX), Dow Jones (.DJI),
# Nasdaq Composite (.IXIC) and Nasdaq-100 (.NDX) — confirm.
# NOTE(review): these are network calls made on every import, and the results
# are not referenced by any code visible in this module; if no other module
# imports them they could be dropped or fetched lazily.
index_us_stock_index_INX = ak.index_us_stock_sina(symbol=".INX")
index_us_stock_index_DJI = ak.index_us_stock_sina(symbol=".DJI")
index_us_stock_index_IXIC = ak.index_us_stock_sina(symbol=".IXIC")
index_us_stock_index_NDX = ak.index_us_stock_sina(symbol=".NDX")


class LazyWord2Vec:
    """Proxy around gensim ``KeyedVectors`` that defers loading until first use.

    The vectors are read from ``model_path`` the first time an attribute,
    item lookup or membership test needs them, and are then kept in memory.
    """

    def __init__(self, model_path):
        self.model_path = model_path
        self._model = None  # populated on first access by load_model()

    def load_model(self):
        """Load the vectors from disk once; subsequent calls do nothing."""
        if self._model is not None:
            return
        print(f"Loading Word2Vec model from path: {self.model_path}...")
        self._model = KeyedVectors.load(self.model_path)

    @property
    def model(self):
        """The underlying ``KeyedVectors`` instance, loaded on demand."""
        self.load_model()
        return self._model

    @property
    def vector_size(self):
        """Dimensionality of the stored word vectors."""
        return self.model.vector_size

    def __getitem__(self, key):
        return self.model[key]

    def __contains__(self, key):
        return key in self.model
    
# Load the pre-trained Google News Word2Vec model (lazily, see LazyWord2Vec).

from huggingface_hub import hf_hub_download
import os

# Hugging Face repository hosting the model files.
repo_id = "fse/word2vec-google-news-300"
filename = "word2vec-google-news-300.model"  # gensim KeyedVectors file

# Fetch the model file and its vector array from the Hugging Face Hub.
# The ".vectors.npy" companion is fetched only for its side effect:
# KeyedVectors.load presumably expects it next to the main file, which is where
# hf_hub_download places files from the same repo revision — confirm.
try:
    print(f"Downloading {filename} from Hugging Face Hub...")
    downloaded_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename
    )

    downloaded_path_npy = hf_hub_download(
        repo_id=repo_id,
        filename="word2vec-google-news-300.model.vectors.npy"
    )
    print(f"Model downloaded to {downloaded_path}")
except Exception as e:
    # Chain the original error so the root cause stays in the traceback.
    raise RuntimeError(f"Failed to download {filename} from Hugging Face Hub: {e}") from e

# Wrap the file in a lazy loader; the vectors are only read on first use.
print(f"Loading Word2Vec model from {downloaded_path}...")
word2vec_model = LazyWord2Vec(downloaded_path)


def pos_tagging(text):
    """Part-of-speech tag ``text`` with the module-level spaCy pipeline.

    Punctuation and stop-word tokens are skipped.

    Args:
        text: Input text.

    Returns:
        tuple: ``(tokens, pos_tags, tags)`` — parallel lists with each kept
        token's text, coarse POS tag (``token.pos_``) and fine-grained tag
        (``token.tag_``). If processing fails, the error is printed and
        ``("", "", "")`` is returned instead.
    """
    try:
        doc = nlp(text)
        tokens, pos_tags, tags = [], [], []
        for token in doc:
            if token.is_punct or token.is_stop:
                continue
            tokens.append(token.text)
            pos_tags.append(token.pos_)
            tags.append(token.tag_)
    except Exception as e:
        # str() keeps the report itself from failing on non-str input such as
        # None (slicing None would raise TypeError inside this handler).
        print(f"Error in pos_tagging for text: {str(text)[:50]}... Error: {str(e)}")
        return "", "", ""

    return tokens, pos_tags, tags


# 命名实体识别函数
def named_entity_recognition(text):
    """Extract named entities from ``text`` with the module-level spaCy pipeline.

    Args:
        text: Input text.

    Returns:
        list[tuple[str, str]]: ``(entity_text, label)`` pairs. When no entity
        is found, or processing fails (the error is printed), the placeholder
        ``[("", "")]`` is returned so the result is never empty.
    """
    try:
        doc = nlp(text)
        entities = [(ent.text, ent.label_) for ent in doc.ents]
    except Exception as e:
        # str() keeps the report itself from failing on non-str input such as None.
        print(f"Error in named_entity_recognition for text: {str(text)[:50]}... Error: {str(e)}")
        entities = []

    return entities or [("", "")]



# 处理命名实体识别结果
def process_entities(entities):
    """Count how many entities of each label appear in an NER result.

    Args:
        entities: Iterable of ``(text, label)`` pairs, such as the output of
            ``named_entity_recognition``. A ``("", "")`` placeholder is counted
            under the empty label like any other pair.

    Returns:
        tuple: ``(counts, entity_types)`` where ``entity_types`` is the sorted
        list of distinct labels and ``counts`` is a numpy array with each
        label's frequency in the same order. If the input cannot be processed,
        the error is printed and ``(np.zeros(len(entities)), [])`` is returned
        (a length-0 array when the input has no length, e.g. None).
    """
    from collections import Counter

    try:
        label_counts = Counter(entity[1] for entity in entities)
        entity_types = sorted(label_counts)
        counts = np.array([label_counts[label] for label in entity_types])
    except Exception as e:
        print(f"Error in process_entities: {str(e)}")
        # Only size the fallback from the input when it actually has a length;
        # otherwise len() would raise a second error inside this handler.
        size = len(entities) if hasattr(entities, "__len__") else 0
        return np.zeros(size), []

    return counts, entity_types



# 处理词性标注结果
def process_pos_tags(pos_tags):
    """Count the occurrences of each part-of-speech tag.

    Items of ``pos_tags`` may be tag strings or ``(token, tag)`` pairs (the
    element at index 1 is used). Empty strings and any other items are ignored.

    Args:
        pos_tags: list or tuple of tags / tag pairs.

    Returns:
        tuple: ``(counts, pos_types)`` with ``pos_types`` the sorted distinct
        tags and ``counts`` a numpy array of their frequencies. When the input
        is empty or not a list/tuple, contains no usable tag, or processing
        fails, a message is printed and ``(np.zeros(1), [])`` is returned.
    """
    tally = {}
    try:
        if not pos_tags or not isinstance(pos_tags, (list, tuple)):
            print(f"Invalid POS tags: {pos_tags}")
            return np.zeros(1), []

        for item in pos_tags:
            if isinstance(item, str):
                if not item:
                    continue
                label = item
            elif isinstance(item, (list, tuple)) and len(item) > 1:
                label = item[1]
            else:
                continue
            tally[label] = tally.get(label, 0) + 1

        pos_types = sorted(tally)
        if not pos_types:
            print(f"No valid POS tags found: {pos_tags}")
            return np.zeros(1), []

        counts = np.array([tally[label] for label in pos_types])
    except Exception as e:
        print(f"Error in process_pos_tags: {str(e)} for POS tags: {pos_tags}")
        return np.zeros(1), []

    return counts, pos_types




# 函数:获取文档向量
def get_document_vector(words, model=word2vec_model):
    """Return the mean word vector of ``words`` as a document embedding.

    Words missing from ``model`` are ignored. If none of the words is known,
    or an error occurs (it is printed), a zero vector of length
    ``model.vector_size`` is returned.

    Args:
        words: Sequence of tokens (sliced with ``[:5]`` in the error message).
        model: Word-vector lookup supporting ``in``, ``[]`` and ``vector_size``;
            defaults to the module's lazily loaded Word2Vec model.
    """
    try:
        known = []
        for word in words:
            if word in model:
                known.append(model[word])
        if not known:
            return np.zeros(model.vector_size)
        return np.mean(known, axis=0)
    except Exception as e:
        print(f"Error in get_document_vector for words: {words[:5]}... Error: {str(e)}")
        return np.zeros(model.vector_size)



# 函数:获取情感得分
def process_long_text(text, tokenizer, max_length=512):
    """Split ``text`` into sentence-aligned chunks that fit a model's input size.

    Sentences from NLTK's punkt splitter are greedily packed into segments
    whose tokenized length stays within ``max_length`` minus the special tokens
    the tokenizer adds (e.g. [CLS]/[SEP]). A single sentence longer than that
    budget still becomes its own (over-long) segment.

    Args:
        text: Text to split.
        tokenizer: Hugging Face tokenizer used to measure token counts.
        max_length: Model input limit, including special tokens.

    Returns:
        list[str]: The stripped segments, in order.
    """
    import nltk

    # Make sure the punkt sentence-splitter resources are available locally.
    for resource_path, package in (('tokenizers/punkt', 'punkt'),
                                   ('tokenizers/punkt_tab', 'punkt_tab')):
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)

    # Token budget left for actual text once special tokens are accounted for.
    budget = max_length - tokenizer.num_special_tokens_to_add()

    segments = []
    current = ""
    for sentence in nltk.sent_tokenize(text):
        candidate = f"{current} {sentence}" if current else sentence
        if len(tokenizer.tokenize(candidate)) <= budget:
            current = candidate
            continue
        # Adding this sentence would overflow: close the current segment.
        if current:
            segments.append(current.strip())
        current = sentence

    if current:
        segments.append(current.strip())
    return segments

def _finbert_segment_scores(segments, tokenizer, model, positive_idx, negative_idx, neutral_idx):
    """Score each text segment with one 3-class FinBERT-style classifier.

    Args:
        segments: Text chunks to classify.
        tokenizer: Hugging Face tokenizer paired with ``model``.
        model: Sequence-classification model returning 3 logits.
        positive_idx, negative_idx, neutral_idx: Positions of the positive,
            negative and neutral classes in the model's output.

    Returns:
        tuple: ``(scores, weights)`` — one score clipped to [-1, 1] per segment
        and the segment's token count, used as its averaging weight.
    """
    import torch

    scores, weights = [], []
    with torch.no_grad():
        for segment in segments:
            inputs = tokenizer(segment, return_tensors="pt", truncation=True, max_length=512)
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
            positive = probs[positive_idx]
            negative = probs[negative_idx]
            neutral = probs[neutral_idx]

            # The neutral mass is added to whichever polar class dominates.
            # Because the probabilities sum to 1 this is 1 - 2*negative when
            # positive > negative, and 2*positive - 1 otherwise.
            # NOTE(review): a mostly-neutral segment therefore scores near +/-1;
            # confirm this is the intended scale.
            score = positive - negative
            score += neutral if positive > negative else -neutral
            scores.append(np.clip(score, -1.0, 1.0))
            weights.append(len(tokenizer.tokenize(segment)))
    return scores, weights


def _weighted_segment_average(scores, weights):
    """Token-weighted mean of segment scores; 0.0 when there is nothing to average."""
    if not scores or sum(weights) <= 0:
        # No segments, or only token-less ones: np.average would raise
        # ZeroDivisionError on all-zero weights.
        return 0.0
    return np.average(scores, weights=weights)


def get_sentiment_score(text):
    """Return a FinBERT-based sentiment score for ``text`` in [-1, 1].

    The text is split into sentence-aligned chunks for each of two models
    (ProsusAI/finbert and yiyanghkust/finbert-tone). Every chunk is scored, and
    each model's chunk scores are averaged weighted by token count. The two
    model results are then blended 30% / 70%.

    Returns 0.0 for the "EMPTY_TEXT" marker, and 0.0 (after printing the error
    and a traceback) if anything fails.
    """
    if text and text.strip() == "EMPTY_TEXT":
        return 0.0

    try:
        # Models are loaded lazily and cached by get_tokenizer_and_model.
        tokenizer_one, sa_model_one = get_tokenizer_and_model("one")
        tokenizer_two, sa_model_two = get_tokenizer_and_model("two")

        # Each tokenizer gets its own segmentation, since token counts differ.
        segments_one = process_long_text(text, tokenizer_one)
        segments_two = process_long_text(text, tokenizer_two)

        # Class order assumed for ProsusAI/finbert: positive, negative, neutral.
        scores_one, weights_one = _finbert_segment_scores(
            segments_one, tokenizer_one, sa_model_one,
            positive_idx=0, negative_idx=1, neutral_idx=2,
        )
        # Class order assumed for yiyanghkust/finbert-tone: neutral, positive, negative.
        scores_two, weights_two = _finbert_segment_scores(
            segments_two, tokenizer_two, sa_model_two,
            positive_idx=1, negative_idx=2, neutral_idx=0,
        )

        final_score_one = _weighted_segment_average(scores_one, weights_one)
        final_score_two = _weighted_segment_average(scores_two, weights_two)

        # Blend the two models' results.
        final_score = np.average([final_score_one, final_score_two], weights=[0.3, 0.7])
        return np.clip(final_score, -1.0, 1.0)

    except Exception as e:
        # str() keeps the report itself from failing on non-str input such as None.
        print(f"Error in get_sentiment_score for text: {str(text)[:50]}... Error: {str(e)}")
        traceback.print_exc()
        return 0.0



def get_stock_info(stock_code: str, history_days=30):
    """Collect price windows around today's date for one stock and four US indices.

    Each history is split at the row dated today (``news_date``), or at the
    closest available date, into a "previous" window of ``history_days`` rows
    and a "following" window of 3 rows. Windows keep the first six non-date
    columns (per the original comments: open, close, high, low, volume,
    turnover) and are padded with -1 rows when the data runs short.

    Args:
        stock_code: Ticker passed to ``get_stock_history``. Empty/None or
            'NONE_SYMBOL_FOUND' (after stripping whitespace) produces -1
            placeholder windows instead of a lookup.
        history_days: Number of rows in each "previous" window.

    Returns:
        tuple: Ten single-element lists, each wrapping a list of rows:
        (previous_stock, following_stock,
         previous_inx, previous_dj, previous_ixic, previous_ndx,
         following_inx, following_dj, following_ixic, following_ndx).
    """
    # Explicit import instead of relying on `from us_stock import *`.
    from datetime import datetime

    following_days = 3  # rows kept after the target date
    news_date = datetime.now().strftime('%Y%m%d')

    def placeholder_rows(n_rows):
        """Return ``n_rows`` independent rows of six -1 values."""
        # A comprehension, not [[-1] * 6] * n, so rows are not aliases of one list.
        return [[-1] * 6 for _ in range(n_rows)]

    def process_history(stock_history, target_date, history_days=history_days, following_days=following_days):
        """Split one history frame into (previous_rows, following_rows) around target_date."""
        # No data: return -1 frames carrying the expected column names.
        if stock_history is None or stock_history.empty:
            columns = ['开盘', '收盘', '最高', '最低', '成交量', '成交额']
            empty_data_previous = pd.DataFrame({col: [-1] * history_days for col in columns})
            empty_data_following = pd.DataFrame({col: [-1] * following_days for col in columns})
            return empty_data_previous, empty_data_following

        if 'date' not in stock_history.columns:
            print(f"'date' column not found in stock history. Returning empty data.")
            return pd.DataFrame(placeholder_rows(history_days)), pd.DataFrame(placeholder_rows(following_days))

        # Work on a copy (the caller's frame is not mutated) with a fresh
        # 0..n-1 index, so index labels and row positions coincide below.
        stock_history = stock_history.copy().reset_index(drop=True)
        stock_history['date'] = pd.to_datetime(stock_history['date'])
        target_date = pd.to_datetime(target_date)

        exact_matches = stock_history.index[stock_history['date'] == target_date]
        if len(exact_matches) > 0:
            target_pos = int(exact_matches[0])
        else:
            # No exact match: use the row whose date is closest to the target.
            distances = (stock_history['date'] - target_date).abs()
            if distances.isna().all():
                # Every date is missing (NaT), so there is no usable anchor row.
                return pd.DataFrame(placeholder_rows(history_days)), pd.DataFrame(placeholder_rows(following_days))
            target_pos = int(distances.idxmin())

        # Up to history_days rows before the target plus the target row itself,
        # and the following_days rows after it.
        previous_rows = stock_history.iloc[max(0, target_pos - history_days):target_pos + 1]
        following_rows = stock_history.iloc[target_pos + 1:target_pos + 1 + following_days]

        # Drop the date column and renumber from 0: reindex() aligns on labels,
        # so without this, slices whose labels fall outside 0..n-1 (e.g. the
        # rows after the target) would be replaced wholesale by -1 padding.
        previous_rows = previous_rows.drop(columns=['date']).reset_index(drop=True)
        following_rows = following_rows.drop(columns=['date']).reset_index(drop=True)

        if len(previous_rows) < history_days:
            previous_rows = previous_rows.reindex(range(history_days), fill_value=-1)
        if len(following_rows) < following_days:
            following_rows = following_rows.reindex(range(following_days), fill_value=-1)

        # NOTE(review): with enough history the previous window holds
        # history_days + 1 rows and this slice drops the last one (the target
        # date), yet with a short history the target row is kept. Confirm
        # which behaviour is intended.
        return previous_rows.iloc[:history_days, :6], following_rows.iloc[:following_days, :6]

    def wrap(frame):
        """Wrap one window frame as a single-element list of row lists."""
        return [frame.values.tolist()]

    # Index series ids passed to get_stock_index_history, fetched in this order.
    index_ids = (('ndx', 1), ('dj', 2), ('inx', 3), ('ixic', 4))
    index_histories = {name: get_stock_index_history("", news_date, index_id) for name, index_id in index_ids}
    index_windows = {name: process_history(index_histories[name], news_date, history_days) for name, _ in index_ids}

    code = stock_code.strip() if stock_code else ''
    if not code or code == 'NONE_SYMBOL_FOUND':
        # No usable ticker: -1 placeholders for the individual stock.
        previous_stock_history = [placeholder_rows(history_days)]
        following_stock_history = [placeholder_rows(following_days)]
    else:
        stock_history = get_stock_history(code, news_date)
        previous_rows, following_rows = process_history(stock_history, news_date)
        previous_stock_history = wrap(previous_rows)
        following_stock_history = wrap(following_rows)

    return (
        previous_stock_history, following_stock_history,
        wrap(index_windows['inx'][0]), wrap(index_windows['dj'][0]),
        wrap(index_windows['ixic'][0]), wrap(index_windows['ndx'][0]),
        wrap(index_windows['inx'][1]), wrap(index_windows['dj'][1]),
        wrap(index_windows['ixic'][1]), wrap(index_windows['ndx'][1]),
    )



def lemmatized_entry(entry):
    """Run the cleaning and lemmatization pipeline on one news entry.

    Thin wrapper around ``preprocessing_entry``.

    Args:
        entry: Raw news text.

    Returns:
        The list of lemmas produced by ``preprocessing_entry``.
    """
    # Step 1 - aggregate/clean the entry and lemmatize it.
    return preprocessing_entry(entry)






# 1. 数据清理
# 1.1 合并数据
# 1.2 去除噪声
# 1.3 大小写转换
# 1.4 去除停用词
# 1.5 词汇矫正与拼写检查
# 1.6 词干提取与词形还原



# Force GPU usage (disabled):
# spacy.require_gpu()

# Load the spaCy English pipeline.
# NOTE(review): `nlp` was already loaded with the same model earlier in this
# module and is loaded again further down; since the functions here look `nlp`
# up when called, this intermediate instance appears never to be used. Likely a
# leftover from merging separate files — consider removing.
nlp = spacy.load("en_core_web_md")

# Check whether the GPU is used (disabled):
# print("Is NPL GPU used Lemmatized:", spacy.prefer_gpu())




def preprocessing_entry(news_entry):
    """Clean a raw news entry and return its lemmas.

    Pipeline: merge_text -> disposal_noise (HTML/URL/whitespace cleanup) ->
    lower-casing -> remove_stopwords -> lemmatize_text. Spelling correction is
    currently disabled.

    Args:
        news_entry (str): Raw news text.

    Returns:
        list[str]: Lemmas of the remaining alphabetic or number-like tokens.
    """
    # 1.1 merge, 1.2 denoise, 1.3 lower-case
    normalized = disposal_noise(merge_text(news_entry)).lower()
    # 1.4 stop-word removal
    content_words = remove_stopwords(normalized)
    # 1.6 lemmatization (1.5 spelling correction is skipped)
    return lemmatize_text(content_words)



# 1.1 合并数据
def merge_text(news_entry):
    """Step 1.1 (merge data): currently a pass-through returning ``news_entry`` unchanged."""
    merged = news_entry
    return merged


# 1.2 去除噪声
def disposal_noise(text):
    """Strip markup noise from ``text`` and normalise whitespace (step 1.2).

    Steps:
      1. Replace HTML tags (``<...>``, including tags spanning several lines)
         with a space, so words on either side of a tag are not glued together.
      2. Remove URLs: any run of non-whitespace starting with ``http`` or ``www``.
      3. Collapse every whitespace run (spaces, tabs, newlines) into a single
         space and trim both ends.

    Args:
        text (str): Raw text.

    Returns:
        str: The cleaned text.
    """
    # [^>]* (unlike .*?) also matches newlines inside a tag.
    text = re.sub(r'<[^>]*>', ' ', text)
    # `http\S+` already covers https URLs.
    text = re.sub(r'http\S+|www\S+', '', text)
    # \s covers \n, \t and \r too, so no separate newline/tab pass is needed.
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# 1.4 去除停用词


def remove_stopwords(text):
    """Drop spaCy stop words and keep only alphabetic or number-like tokens (step 1.4).

    Args:
        text (str): Input text.

    Returns:
        str: The surviving token texts joined by single spaces.
    """
    kept = []
    for token in nlp(text):
        if token.is_stop:
            continue
        if token.is_alpha or token.like_num:
            kept.append(token.text)
    return ' '.join(kept)





# 1.5 拼写检查
# 该函数用于检查输入文本的拼写错误,并修正
# def correct_spelling(text):
#     corrected_text = []
#     doc = nlp(text)
#     for token in doc:
#         if token.is_alpha:  # 仅检查字母构成的单词
#             corrected_word = spell.correction(token.text)
#             if corrected_word is None:
#                 # 如果拼写检查没有建议,保留原始单词
#                 corrected_word = token.text
#             corrected_text.append(corrected_word)
#         else:
#             corrected_text.append(token.text)
#     return " ".join(corrected_text)


# 1.6 词干提取与词形还原
# 该函数用于对输入文本进行词形还原,返回一个包含词形还原后单词
def lemmatize_text(text):
    """Return the lemma of every alphabetic or number-like token in ``text`` (step 1.6).

    Punctuation and whitespace tokens are skipped.

    Args:
        text (str): Input text.

    Returns:
        list[str]: Lemmas in token order.
    """
    return [
        token.lemma_
        for token in nlp(text)
        if not token.is_punct
        and not token.is_space
        and (token.is_alpha or token.like_num)
    ]






# 2. 数据增强和特征提取
# 2.1 词性标注(Part-of-Speech Tagging)
# 为每个词标注其词性(如名词、动词、形容词等),这有助于后续的句法分析和信息提取。
# 工具:spaCy 或 NLTK
# 2.2 命名实体识别(NER)
# 识别文本中的命名实体,如人名、地名、组织机构等,提取出这些实体信息。
# 工具:spaCy 或 Stanford NER
# 2.3 句法分析与依存分析
# 分析句子结构,理解单词之间的关系(如主谓宾结构)。
# 工具:spaCy 或 NLTK



# 2 特征提取


# Force GPU usage (disabled):
#spacy.require_gpu()

# Load the spaCy English pipeline.
# NOTE(review): this is the last of several spacy.load("en_core_web_md") calls
# visible in this module. The functions here look `nlp` up when called, so as
# far as the visible code shows, this instance is the one actually used and
# the earlier loads only cost start-up time. Consider loading the model once.
nlp = spacy.load("en_core_web_md")

# Check whether the GPU is used (disabled):
# print("Is NPL GPU used Enchance_text.py:", spacy.prefer_gpu())



# 2.3 句法分析与依存分析
def dependency_parsing(text):
    """Return ``(token, dependency label, head token)`` triples for ``text`` (step 2.3).

    Punctuation and stop-word tokens are skipped. To keep only specific
    relations, additionally filter on ``token.dep_`` (e.g. 'nsubj', 'dobj').

    Args:
        text (str): Input text.

    Returns:
        list[tuple[str, str, str]]: One triple per kept token, in order.
    """
    return [
        (token.text, token.dep_, token.head.text)
        for token in nlp(text)
        if not (token.is_punct or token.is_stop)
    ]



def processing_entry(entry):
    """Build all text features for one news entry.

    The "EMPTY_TEXT" marker is replaced by a filler sentence for the text-based
    features, while sentiment is scored on the original ``entry``.

    Args:
        entry: Raw news text, or the "EMPTY_TEXT" marker.

    Returns:
        tuple: ``(lemmas, pos_tagging result, NER result, dependency parse,
        sentiment score)``. The dependency parse is currently always None.
    """
    if entry and entry.strip() == "EMPTY_TEXT":
        source_text = "It just a normal day."
    else:
        source_text = entry

    lemmas = preprocessing_entry(source_text)
    denoised = disposal_noise(source_text)
    pos_result = pos_tagging(denoised)
    entities = named_entity_recognition(denoised)
    # Dependency parsing (dependency_parsing) is currently switched off; the slot stays None.
    dependencies = None
    sentiment = get_sentiment_score(entry)

    return (lemmas, pos_result, entities, dependencies, sentiment)