"""FinGPT-Forecaster (Chinese) Gradio demo.

Loads a LoRA-finetuned Llama-2-7B forecaster, fetches recent A-share news and
quarterly fundamentals online (EastMoney via requests/akshare), builds a
Chinese analysis prompt in Llama-2 chat format, and serves a Gradio UI.

NOTE(review): the recovered source was mangled (newlines collapsed, and spans
that looked like HTML tags were stripped out).  Sections marked
"reconstructed" below were restored from surrounding context and the public
FinGPT forecaster demo -- confirm them against the original repository.
"""

import json
import math
import random
import re
# NOTE(review): original had a redundant second `from datetime import date`;
# merged into the single import below.
from datetime import date, datetime, timedelta

import akshare as ak
import gradio as gr
import pandas as pd
import requests
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

from Ashare_data import *
from Inference_datapipe import *

# ------------------------------------------------------------------------------
# Model loading
# ------------------------------------------------------------------------------
# The original reused the name `model` for both the HF repo id and the loaded
# model object; the id now has its own name for clarity.
base_model = "meta-llama/Llama-2-7b-chat-hf"
peft_model = "FinGPT/fingpt-forecaster_sz50_llama2-7B_lora"

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    trust_remote_code=True,
    device_map='auto',
    offload_folder="offload/",
)
model = PeftModel.from_pretrained(model, peft_model, offload_folder="offload/")
model = model.eval()

# ------------------------------------------------------------------------------
# Inference data: Llama-2 chat-format markers and the system prompt
# ------------------------------------------------------------------------------
B_INST, E_INST = "[INST]", "[/INST]"
# NOTE(review): these markers were corrupted to "<>" by tag stripping;
# restored to the standard Llama-2 system-prompt delimiters.
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
    "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"


# ------------------------------------------------------------------------------
# Utils
# ------------------------------------------------------------------------------
def get_curday():
    """Return today's date as a 'YYYYMMDD' string."""
    return date.today().strftime("%Y%m%d")


def n_weeks_before(date_string, n, format="%Y%m%d"):
    """Return the date `n` weeks before `date_string` ('YYYYMMDD').

    A negative `n` moves forward in time.  The result is rendered with
    `format`.  (Fixed: the original bound the result to a local named `date`,
    shadowing the imported `datetime.date`.)
    """
    d = datetime.strptime(date_string, "%Y%m%d") - timedelta(days=7 * n)
    return d.strftime(format=format)


def check_news_quality(n, last_n, week_end_date, repeat_rate=0.6):
    """Heuristically decide whether news item `n` should be kept.

    Rejects items whose content starts with a digit (typically market-data
    tables), is missing ('nan'), or was published after `week_end_date`
    ('YYYY-MM-DD'); also rejects items that heavily overlap the previously
    accepted item `last_n` (duplicates are assumed to appear adjacently).
    Returns True when the item is usable.
    """
    try:
        # Content availability check.
        if not (not (str(n['新闻内容'])[0].isdigit())
                and not (str(n['新闻内容']) == 'nan')
                and n['发布时间'][:8] <= week_end_date.replace('-', '')):
            return False
        # A 'nan' previous item cannot be compared against -- keep this one.
        elif str(last_n['新闻内容']) == 'nan':
            return True
        # Character-set overlap of the first 20 content chars, or title
        # overlap ratio, above `repeat_rate` => treat as duplicated news.
        elif len(set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])) >= 20 * repeat_rate \
                or len(set(n['新闻标题']) & set(last_n['新闻标题'])) / len(last_n['新闻标题']) > repeat_rate:
            return False
        else:
            return True
    except TypeError:
        print(n)
        print(last_n)
        raise Exception("Check Error")


def sample_news(news, k=5):
    """Return `k` items sampled from `news`, preserving original order."""
    return [news[i] for i in sorted(random.sample(range(len(news)), k))]


def return_transform(ret):
    """Map a weekly return ratio to a compact label.

    E.g. 0.023 -> '涨3', -0.07 -> '跌5+', 0.0 -> '平'.  The magnitude is the
    percentage rounded up, capped at '5+'.
    """
    up_down = '涨' if ret >= 0 else '跌'
    integer = math.ceil(abs(100 * ret))
    if integer == 0:
        return "平"
    return up_down + (str(integer) if integer <= 5 else '5+')


def map_return_label(return_lb):
    """Expand a compact return label from `return_transform` into prose.

    E.g. '涨3' -> '上涨2-3%', '跌5+' -> '下跌超过5%', '平' -> '股价持平'.
    """
    lb = return_lb.replace('涨', '上涨')
    lb = lb.replace('跌', '下跌')
    lb = lb.replace('平', '股价持平')
    lb = lb.replace('1', '0-1%')
    lb = lb.replace('2', '1-2%')
    lb = lb.replace('3', '2-3%')
    lb = lb.replace('4', '3-4%')
    # '5+' must be handled before plain '5' would clobber it.
    if lb.endswith('+'):
        lb = lb.replace('5+', '超过5%')
    else:
        lb = lb.replace('5', '4-5%')
    return lb


# ------------------------------------------------------------------------------
# Get data from website
# ------------------------------------------------------------------------------
def stock_news_em(symbol: str = "300059", page=1) -> pd.DataFrame:
    """Fetch one page (100 items) of EastMoney news for `symbol`.

    Returns a DataFrame with columns 关键词/新闻标题/新闻内容/发布时间/文章来源/新闻链接.
    Raises KeyError when the requested page does not exist in the response.
    """
    url = "https://search-api-web.eastmoney.com/search/jsonp"
    params = {
        "cb": "jQuery3510875346244069884_1668256937995",
        "param": '{"uid":"",'
                 + f'"keyword":"{symbol}"'
                 + ',"type":["cmsArticleWebOld"],"client":"web","clientType":"web","clientVersion":"curr","param":{"cmsArticleWebOld":{"searchScope":"default","sort":"default",'
                 + f'"pageIndex":{page}'
                 + ',"pageSize":100,"preTag":"","postTag":""}}}',
        "_": "1668256937996",
    }
    r = requests.get(url, params=params)
    data_text = r.text
    # Strip the JSONP callback wrapper to get plain JSON.
    data_json = json.loads(
        data_text.strip("jQuery3510875346244069884_1668256937995(")[:-1]
    )
    temp_df = pd.DataFrame(data_json["result"]["cmsArticleWebOld"])
    temp_df.rename(
        columns={
            "date": "发布时间",
            "mediaName": "文章来源",
            "code": "-",
            "title": "新闻标题",
            "content": "新闻内容",
            "url": "新闻链接",
            "image": "-",
        },
        inplace=True,
    )
    temp_df["关键词"] = symbol
    temp_df = temp_df[
        [
            "关键词",
            "新闻标题",
            "新闻内容",
            "发布时间",
            "文章来源",
            "新闻链接",
        ]
    ]
    # NOTE(review): the search API marks the keyword with <em>...</em> tags;
    # the original patterns were eaten by tag stripping and are restored here.
    temp_df["新闻标题"] = (
        temp_df["新闻标题"]
        .str.replace(r"\(", "", regex=True)
        .str.replace(r"\)", "", regex=True)
    )
    temp_df["新闻标题"] = (
        temp_df["新闻标题"]
        .str.replace(r"<em>", "", regex=True)
        .str.replace(r"</em>", "", regex=True)
    )
    temp_df["新闻内容"] = (
        temp_df["新闻内容"]
        .str.replace(r"\(", "", regex=True)
        .str.replace(r"\)", "", regex=True)
    )
    temp_df["新闻内容"] = (
        temp_df["新闻内容"]
        .str.replace(r"<em>", "", regex=True)
        .str.replace(r"</em>", "", regex=True)
    )
    temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\u3000", "", regex=True)
    temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\r\n", " ", regex=True)
    return temp_df


def get_news(symbol, max_page=3):
    """Fetch up to `max_page - 1` pages of news for `symbol` into one DataFrame.

    Stops early (KeyError from `stock_news_em`) when no further pages exist.
    """
    df_list = []
    for page in range(1, max_page):
        try:
            df_list.append(stock_news_em(symbol, page))
        except KeyError:
            # Fixed: the original concatenated an int (`page`) to a str,
            # raising TypeError, and the message read backwards.
            print(f"{page - 1} pages obtained for symbol: {symbol}")
            break
    # NOTE(review): as in the original, this raises if no page succeeded.
    news_df = pd.concat(df_list, ignore_index=True)
    return news_df


def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
    """Compute weekly prices/returns for `symbol` between the given dates.

    Returns a DataFrame with 起始日期/起始价/结算日期/结算价/周收益/简化周收益,
    with the final settlement date clamped to `end_date`.
    """
    # Load daily history (forward-adjusted by default).
    return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily",
                                     start_date=start_date, end_date=end_date,
                                     adjust=adjust)

    # Index by timestamp.
    return_data["日期"] = pd.to_datetime(return_data["日期"])
    return_data.set_index("日期", inplace=True)

    # Resample to weekly closes, forward-filling missing weeks.
    weekly_data = return_data["收盘"].resample("W").ffill()
    weekly_returns = weekly_data.pct_change()[1:]
    weekly_start_prices = weekly_data[:-1]
    weekly_end_prices = weekly_data[1:]
    weekly_data = pd.DataFrame({
        '起始日期': weekly_start_prices.index,
        '起始价': weekly_start_prices.values,
        '结算日期': weekly_end_prices.index,
        '结算价': weekly_end_prices.values,
        '周收益': weekly_returns.values
    })
    weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)

    # The resampled settlement date can overshoot end_date -- clamp it.
    if weekly_data.iloc[-1, 2] > pd.to_datetime(end_date):
        weekly_data.iloc[-1, 2] = pd.to_datetime(end_date)
    return weekly_data


def get_basic(symbol, data):
    """Attach the most recent quarterly fundamentals to each weekly row.

    For every row of `data`, finds the latest quarterly report dated before
    that week's settlement date and stores it JSON-encoded in a new
    '基本面' column.  Returns the mutated `data`.
    """
    key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率',
                      '流动比率', '速动比率', '资产负债率']

    # Quarterly fundamentals from THS, keeping only the key indicators.
    basic_quarter_financials = ak.stock_financial_abstract_ths(symbol=symbol, indicator="按单季度")
    basic_fin_dict = basic_quarter_financials.to_dict("index")
    basic_fin_list = [
        dict([(key, val) for key, val in basic_fin_dict[i].items()
              if (key in key_financials) and val])
        for i in range(len(basic_fin_dict))
    ]

    # Match the most current report preceding each news week.
    matched_basic_fin = []
    for i, row in data.iterrows():
        newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
        matched_basic = {}
        for basic in basic_fin_list:
            if basic["报告期"] < newsweek_enddate:
                matched_basic = basic
                break
        matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))
    data['基本面'] = matched_basic_fin
    return data


# ------------------------------------------------------------------------------
# Structure Data
# ------------------------------------------------------------------------------
def cur_financial_data(symbol, start_date, end_date, with_basics=True):
    """Build the weekly dataset (returns + matched news [+ fundamentals])."""
    # Weekly returns and raw news, sorted by publication time.
    data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)
    news_df = get_news(symbol=symbol)
    news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
    news_df.sort_values(by=["发布时间"], inplace=True)

    # Match weekly news to each return row.
    news_list = []
    for a, row in data.iterrows():
        week_start_date = row['起始日期'].strftime('%Y-%m-%d')
        week_end_date = row['结算日期'].strftime('%Y-%m-%d')
        print(symbol, ': ', week_start_date, ' - ', week_end_date)

        # NOTE(review): from here to the end of this function the source was
        # eaten by tag stripping; reconstructed from context -- confirm.
        weekly_news = news_df.loc[
            (news_df["发布时间"] > week_start_date) & (news_df["发布时间"] <= week_end_date)
        ]
        weekly_news = [
            {
                "发布时间": n["发布时间"].strftime('%Y%m%d'),
                "新闻标题": n['新闻标题'],
                "新闻内容": n['新闻内容'],
            }
            for _, n in weekly_news.iterrows()
        ]
        news_list.append(json.dumps(weekly_news, ensure_ascii=False))

    data["新闻"] = news_list

    if with_basics:
        data = get_basic(symbol=symbol, data=data)
    else:
        data['基本面'] = [json.dumps({})] * len(data)
    return data


def get_company_prompt_new(symbol):
    """Return (company introduction paragraph, stock short name) for `symbol`.

    NOTE(review): this whole function was lost to tag stripping in the
    recovered source; reconstructed from the public FinGPT demo -- confirm.
    """
    try:
        company_profile = dict(ak.stock_individual_info_em(symbol).values)
    except Exception:
        print("Company Info Request Time Out! Please wait and retry.")
    company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")

    template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
               "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"

    formatted_profile = template.format(**company_profile)
    stockname = company_profile['股票简称']
    return formatted_profile, stockname


def get_prompt_by_row_new(stock, row):
    """Build the prompt pieces for one weekly row.

    Returns (head paragraph, filtered news strings, fundamentals paragraph).
    """
    # NOTE(review): the first lines of this function were lost to tag
    # stripping; reconstructed from context -- confirm.
    week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) \
        else row['起始日期'].strftime('%Y-%m-%d')
    week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) \
        else row['结算日期'].strftime('%Y-%m-%d')
    term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
    chg = map_return_label(row['简化周收益'])
    head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
        week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)

    news = json.loads(row["新闻"])
    # `left` scans forward; `right` trails one accepted item behind so each
    # candidate is compared against the previously accepted news item.
    left, right = 0, 0
    filtered_news = []
    while left < len(news):
        n = news[left]
        if left == 0:
            # First item: only the content-availability checks apply.
            if (not (str(n['新闻内容'])[0].isdigit())
                    and not (str(n['新闻内容']) == 'nan')
                    and n['发布时间'][:8] <= week_end_date.replace('-', '')):
                filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
                left += 1
        else:
            news_check = check_news_quality(n, last_n=news[right],
                                            week_end_date=week_end_date, repeat_rate=0.5)
            if news_check:
                filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
            left += 1
            right += 1

    basics = json.loads(row['基本面'])
    if basics:
        basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
            stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    return head, filtered_news, basics


def get_all_prompts_online(symbol, with_basics=True, max_news_perweek=3, weeks_before=2):
    """Assemble the full model prompt for `symbol` ending today.

    Returns (info, prompt): the human-readable information block and the
    Llama-2 chat-formatted prompt string.
    """
    end_date = get_curday()
    start_date = n_weeks_before(end_date, weeks_before)
    company_prompt, stock = get_company_prompt_new(symbol)
    data = cur_financial_data(symbol=symbol, start_date=start_date,
                              end_date=end_date, with_basics=with_basics)

    prev_rows = []
    for row_idx, row in data.iterrows():
        head, news, basics = get_prompt_by_row_new(symbol, row)
        prev_rows.append((head, news, basics))

    prompt = ""
    # Oldest week first, sampling at most `max_news_perweek` items per week.
    for i in range(-len(prev_rows), 0):
        prompt += "\n" + prev_rows[i][0]
        sampled_news = sample_news(
            prev_rows[i][1],
            min(max_news_perweek, len(prev_rows[i][1]))
        )
        if sampled_news:
            prompt += "\n".join(sampled_news)
        else:
            prompt += "No relative news reported."

    next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
    end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
    period = "{}至{}".format(end_date, next_date)

    if with_basics:
        basics = prev_rows[-1][2]
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    info = company_prompt + '\n' + prompt + '\n' + basics
    new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
    prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + \
        f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
        f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST

    return info, prompt


def ask(symbol, weeks_before):
    """Gradio handler: build the prompt, run generation, return (info, answer)."""
    info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before)

    inputs = tokenizer(pt, return_tensors='pt')
    inputs = {key: value.to(model.device) for key, value in inputs.items()}
    print("Inputs loaded onto devices.")

    res = model.generate(
        **inputs,
        use_cache=True
    )
    output = tokenizer.decode(res[0], skip_special_tokens=True)
    # Keep only the text after the final [/INST] marker.
    output_cur = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
    return info, output_cur


server = gr.Interface(
    ask,
    inputs=[
        gr.Textbox(
            label="Symbol",
            value="600519",
            info="Companies from SZ50 are recommended"
        ),
        gr.Slider(
            minimum=1,
            maximum=3,
            value=2,
            step=1,
            label="weeks_before",
            info="Due to the token length constraint, you are recommended to input with 2"
        ),
    ],
    outputs=[
        gr.Textbox(
            label="Information"
        ),
        gr.Textbox(
            label="Response"
        )
    ],
    title="FinGPT-Forecaster-Chinese",
    description="""This version allows the prediction based on the most current date. We will upgrade it to allow customized date soon."""
)
server.launch()