import gradio as gr
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from Ashare_data import *
from Inference_datapipe import *
import re
import akshare as ak
import pandas as pd
import random
import json
import requests
import math
from datetime import date
from datetime import date, datetime, timedelta
# load model
# `model` first holds the base checkpoint name and is rebound below to the
# loaded network; `peft_model` names the LoRA adapter applied on top of it.
model = "meta-llama/Llama-2-7b-chat-hf"
peft_model = "FinGPT/fingpt-forecaster_sz50_llama2-7B_lora"

# Tokenizer pads on the right and reuses the EOS token as its pad token.
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

# Base weights first, then the adapter; both may offload to "offload/".
model = AutoModelForCausalLM.from_pretrained(
    model,
    device_map='auto',
    offload_folder="offload/",
    trust_remote_code=True,
)
# Inference only: switch the adapted model to eval mode.
model = PeftModel.from_pretrained(model, peft_model, offload_folder="offload/").eval()
# Inference Data
# get company news online

# Llama-2 chat prompt delimiters. The instruction is wrapped in
# [INST] ... [/INST]; the system prompt inside it is wrapped in
# <<SYS>> ... <</SYS>>. A bare "<>" is not a tag the chat format defines, so
# the system-prompt markers are spelled out in full here.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# System prompt: asks for a Chinese answer with three fixed sections
# (positive developments, potential concerns, prediction & analysis).
SYSTEM_PROMPT = (
    "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。"
    "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
)
# ------------------------------------------------------------------------------
# Utils
# ------------------------------------------------------------------------------
def get_curday():
    """Return today's date (local clock) as a compact ``YYYYMMDD`` string."""
    today = date.today()
    return f"{today:%Y%m%d}"
def n_weeks_before(date_string, n, format = "%Y%m%d"):
    """Shift a ``YYYYMMDD`` date back by ``n`` weeks and format the result.

    Args:
        date_string: Date to start from; always parsed as ``%Y%m%d``.
        n: Number of weeks to subtract. A negative value moves forward.
        format: ``strftime`` pattern for the returned string only; it does
            not affect how ``date_string`` is parsed.

    Returns:
        The shifted date rendered with ``format``.
    """
    parsed = datetime.strptime(date_string, "%Y%m%d")
    shifted = parsed - timedelta(days=7 * n)
    return shifted.strftime(format)
def check_news_quality(n, last_n, week_end_date, repeat_rate = 0.6):
    """Decide whether news item ``n`` should be kept.

    An item is rejected when its content is unusable (empty, the string
    ``'nan'``, or starting with a digit), when it was published after the end
    of the week, or when it largely repeats ``last_n``.

    Args:
        n: Candidate item; read keys are '新闻内容', '新闻标题' and '发布时间'
            (the first 8 characters of '发布时间' are compared as ``YYYYMMDD``).
        last_n: Item to compare against for duplication; read keys are
            '新闻内容' and '新闻标题'.
        week_end_date: Week end date; dashes are removed before comparison.
        repeat_rate: Overlap threshold. Content counts as duplicated when the
            first 20 characters of both items share at least
            ``20 * repeat_rate`` distinct characters; a title counts as
            duplicated when the shared distinct characters divided by the
            previous title's length exceed ``repeat_rate``.

    Returns:
        True if the item passes every check, False otherwise.

    Raises:
        Exception: "Check Error" when a field has an unexpected type
            (both items are printed first to help diagnose it).
    """
    try:
        # check content availability
        content_text = str(n['新闻内容'])
        if not content_text:
            # Empty content carries nothing to show and has no first
            # character to inspect.
            return False
        if (
            content_text[0].isdigit()
            or content_text == 'nan'
            or n['发布时间'][:8] > week_end_date.replace('-', '')
        ):
            return False

        # check highly duplicated news
        # (assume the duplicated contents happened adjacent)
        if str(last_n['新闻内容']) == 'nan':
            return True

        shared_content = set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])
        if len(shared_content) >= 20 * repeat_rate:
            return False

        last_title = last_n['新闻标题']
        shared_title = set(n['新闻标题']) & set(last_title)
        if len(last_title) == 0:
            # Nothing to compare against; avoids dividing by zero.
            return True
        return len(shared_title) / len(last_title) <= repeat_rate
    except TypeError as err:
        print(n)
        print(last_n)
        raise Exception("Check Error") from err
def sample_news(news, k=5):
    """Pick up to ``k`` random items from ``news``, keeping their original order.

    Args:
        news: Sequence of news items.
        k: Maximum number of items to return. Values larger than
            ``len(news)`` are clamped, so a short list is returned whole
            instead of raising.

    Returns:
        A list of the sampled items, ordered as they appear in ``news``.

    Raises:
        ValueError: if ``k`` is negative (raised by ``random.sample``).
    """
    k = min(k, len(news))
    # Sample positions rather than items so the original order can be restored.
    chosen_positions = sorted(random.sample(range(len(news)), k))
    return [news[pos] for pos in chosen_positions]
def return_transform(ret):
    """Bucket a fractional return into a short Chinese label.

    The return is converted to a percentage, its magnitude rounded up to a
    whole number, and labelled as:

    * "平" when the rounded magnitude is 0,
    * "涨N" / "跌N" for N in 1..5 (up for ``ret >= 0``, down otherwise),
    * "涨5+" / "跌5+" when the rounded magnitude exceeds 5.
    """
    magnitude = math.ceil(abs(100 * ret))
    if magnitude == 0:
        return "平"

    if ret >= 0:
        direction = '涨'
    else:
        direction = '跌'

    if magnitude > 5:
        bucket = '5+'
    else:
        bucket = str(magnitude)
    return direction + bucket
def map_return_label(return_lb):
    """Expand a compact return label (e.g. '涨2', '跌5+', '平') into prose.

    Direction characters become '上涨' / '下跌' / '股价持平' and each digit
    becomes its percentage range ('2' -> '1-2%'); a label ending in '+' maps
    '5+' to '超过5%'.
    """
    # Applied strictly in this order: each digit is substituted only after
    # all smaller digits, so digits introduced by an earlier range text
    # (e.g. the '1' in '1-2%') are never rewritten by a later step.
    substitutions = (
        ('涨', '上涨'),
        ('跌', '下跌'),
        ('平', '股价持平'),
        ('1', '0-1%'),
        ('2', '1-2%'),
        ('3', '2-3%'),
        ('4', '3-4%'),
    )
    label = return_lb
    for old, new in substitutions:
        label = label.replace(old, new)

    # The top bucket depends on whether the label is open-ended.
    if label.endswith('+'):
        return label.replace('5+', '超过5%')
    return label.replace('5', '4-5%')
# ------------------------------------------------------------------------------
# Get data from website
# ------------------------------------------------------------------------------
def stock_news_em(symbol: str = "300059", page = 1) -> pd.DataFrame:
    """Fetch one page of Eastmoney search results for ``symbol`` as a table.

    Sends a JSONP search request (100 articles per page), unwraps the
    callback, and returns the articles with Chinese column names and lightly
    cleaned text.

    Args:
        symbol: Search keyword, normally a stock code.
        page: 1-based result page index.

    Returns:
        DataFrame with columns 关键词, 新闻标题, 新闻内容, 发布时间, 文章来源,
        新闻链接. Parentheses are removed from titles and contents; in
        contents, ideographic spaces are removed and CRLF becomes a space.

    Raises:
        KeyError: if the response has no ``result``/``cmsArticleWebOld``
            entry or the article list lacks the expected fields.
        ValueError: if the response body is not a ``callback(...)`` wrapper
            around valid JSON.
        requests.RequestException: on transport failure or timeout.
    """
    url = "https://search-api-web.eastmoney.com/search/jsonp"
    callback = "jQuery3510875346244069884_1668256937995"
    # Built as a dict and serialized compactly so the keyword is always
    # escaped correctly inside the JSON document.
    search_param = {
        "uid": "",
        "keyword": str(symbol),
        "type": ["cmsArticleWebOld"],
        "client": "web",
        "clientType": "web",
        "clientVersion": "curr",
        "param": {
            "cmsArticleWebOld": {
                "searchScope": "default",
                "sort": "default",
                "pageIndex": int(page),
                "pageSize": 100,
                "preTag": "",
                "postTag": "",
            }
        },
    }
    params = {
        "cb": callback,
        "param": json.dumps(search_param, ensure_ascii=False, separators=(",", ":")),
        "_": "1668256937996",
    }
    # A timeout keeps a stalled connection from blocking the caller forever.
    r = requests.get(url, params=params, timeout=30)
    data_text = r.text
    # Unwrap "callback( ... )" by position: str.strip() takes a character
    # set, not a prefix, so it is not a reliable way to drop the callback.
    payload = data_text[data_text.index("(") + 1:data_text.rindex(")")]
    data_json = json.loads(payload)

    temp_df = pd.DataFrame(data_json["result"]["cmsArticleWebOld"])
    temp_df.rename(
        columns={
            "date": "发布时间",
            "mediaName": "文章来源",
            "title": "新闻标题",
            "content": "新闻内容",
            "url": "新闻链接",
        },
        inplace=True,
    )
    temp_df["关键词"] = symbol
    # .copy() so the column assignments below modify an independent frame.
    temp_df = temp_df[
        [
            "关键词",
            "新闻标题",
            "新闻内容",
            "发布时间",
            "文章来源",
            "新闻链接",
        ]
    ].copy()

    for column in ("新闻标题", "新闻内容"):
        temp_df[column] = (
            temp_df[column]
            .str.replace("(", "", regex=False)
            .str.replace(")", "", regex=False)
        )
    # NOTE(review): this function previously also applied
    # .str.replace(r"", "", regex=True) to both columns. An empty pattern
    # substitutes nothing, so those calls are omitted. If they were meant to
    # strip search-highlight markup, confirm the tags and add them here.
    temp_df["新闻内容"] = (
        temp_df["新闻内容"]
        .str.replace("\u3000", "", regex=False)
        .str.replace("\r\n", " ", regex=False)
    )
    return temp_df
def get_news(symbol, max_page = 3):
    """Collect consecutive news pages for ``symbol`` into one DataFrame.

    Pages ``1 .. max_page - 1`` are requested through ``stock_news_em`` (the
    upper bound is exclusive, as in ``range``). Collection stops at the first
    page for which ``stock_news_em`` raises ``KeyError``.

    Args:
        symbol: Keyword passed to ``stock_news_em``.
        max_page: Exclusive upper bound of the page index.

    Returns:
        The fetched pages concatenated with a fresh 0-based index. If no page
        could be fetched, an empty frame with the news columns is returned so
        callers can still select columns.
    """
    df_list = []
    for page in range(1, max_page):
        try:
            df_list.append(stock_news_em(symbol, page))
        except KeyError:
            # `page` is an int; format it rather than concatenating to a str.
            print(f"{page - 1} pages obtained for symbol: {symbol}")
            break
    if not df_list:
        # pd.concat refuses an empty list, so build the empty table directly.
        return pd.DataFrame(
            columns=["关键词", "新闻标题", "新闻内容", "发布时间", "文章来源", "新闻链接"]
        )
    return pd.concat(df_list, ignore_index=True)
def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
    """Build a table of week-over-week closing-price returns for ``symbol``.

    Daily history is loaded through ``ak.stock_zh_a_hist``, the closing price
    ('收盘') is resampled to weekly frequency with forward fill, and each pair
    of consecutive weekly points becomes one row.

    Args:
        symbol: Stock code passed to akshare.
        start_date: Start of the history window, passed to akshare.
        end_date: End of the history window; also used to cap the final
            settlement date.
        adjust: Price adjustment mode passed to akshare.

    Returns:
        DataFrame with columns 起始日期, 起始价, 结算日期, 结算价, 周收益 and
        简化周收益 (the label from ``return_transform``). The frame is empty
        when the window yields fewer than two weekly points.
    """
    # load data
    return_data = ak.stock_zh_a_hist(
        symbol=symbol,
        period="daily",
        start_date=start_date,
        end_date=end_date,
        adjust=adjust,
    )

    # process timestamp
    return_data["日期"] = pd.to_datetime(return_data["日期"])
    return_data.set_index("日期", inplace=True)

    # resample and filled with forward data
    weekly_close = return_data["收盘"].resample("W").ffill()
    weekly_returns = weekly_close.pct_change().iloc[1:]
    weekly_start_prices = weekly_close.iloc[:-1]
    weekly_end_prices = weekly_close.iloc[1:]

    weekly_data = pd.DataFrame({
        '起始日期': weekly_start_prices.index,
        '起始价': weekly_start_prices.values,
        '结算日期': weekly_end_prices.index,
        '结算价': weekly_end_prices.values,
        '周收益': weekly_returns.values
    })
    weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)

    # check enddate: the last weekly label can fall after the requested end,
    # so cap it. Skipped when there is no row to adjust.
    if not weekly_data.empty:
        requested_end = pd.to_datetime(end_date)
        last_row = weekly_data.index[-1]
        if weekly_data.loc[last_row, '结算日期'] > requested_end:
            weekly_data.loc[last_row, '结算日期'] = requested_end
    return weekly_data
def get_basic(symbol, data):
    """Attach a JSON snapshot of quarterly fundamentals to every weekly row.

    Single-quarter figures are loaded through
    ``ak.stock_financial_abstract_ths`` and reduced to the fields in
    ``key_financials`` (fields with falsy values are dropped). For each row of
    ``data``, the first report in list order whose '报告期' sorts before the
    row's '结算日期' is serialized into the new '基本面' column; rows with no
    match get ``"{}"``.

    Args:
        symbol: Stock code passed to akshare.
        data: Weekly frame with a datetime-like '结算日期' column; it is
            modified in place and also returned.

    Returns:
        ``data`` with the added '基本面' column.
    """
    key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']

    # load quarterly basic data
    quarterly = ak.stock_financial_abstract_ths(symbol = symbol, indicator="按单季度")
    records_by_index = quarterly.to_dict("index")
    # Records are looked up by labels 0..n-1, i.e. this assumes the frame
    # carries a default integer index — TODO confirm against akshare.
    basic_fin_list = []
    for i in range(len(records_by_index)):
        record = records_by_index[i]
        basic_fin_list.append(
            {key: val for key, val in record.items() if (key in key_financials) and val}
        )

    # match basic financial data to news dataframe
    matched_basic_fin = []
    for _, row in data.iterrows():
        newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
        # match the most current financial report
        # (presumably the reports arrive newest first — verify)
        matched_basic = next(
            (basic for basic in basic_fin_list if basic["报告期"] < newsweek_enddate),
            {},
        )
        matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))

    data['基本面'] = matched_basic_fin
    return data
# ------------------------------------------------------------------------------
# Structure Data
# ------------------------------------------------------------------------------
def cur_financial_data(symbol, start_date, end_date, with_basics = True):
# get data
data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)
news_df = get_news(symbol=symbol)
news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
news_df.sort_values(by=["发布时间"], inplace=True)
# match weekly news for return data
news_list = []
for a, row in data.iterrows():
week_start_date = row['起始日期'].strftime('%Y-%m-%d')
week_end_date = row['结算日期'].strftime('%Y-%m-%d')
print(symbol, ': ', week_start_date, ' - ', week_end_date)
weekly_news = news_df.loc[(news_df["发布时间"]>week_start_date) & (news_df["发布时间"] row['起始价'] else '下跌'
chg = map_return_label(row['简化周收益'])
head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)
news = json.loads(row["新闻"])
left, right = 0, 0
filtered_news = []
while left < len(news):
n = news[left]
if left == 0:
# check first news quality
if (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容'])=='nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
left += 1
else:
news_check = check_news_quality(n, last_n = news[right], week_end_date= week_end_date, repeat_rate=0.5)
if news_check:
filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
left += 1
right += 1
basics = json.loads(row['基本面'])
if basics:
basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
else:
basics = "[金融基本面]:\n\n 无金融基本面记录"
return head, filtered_news, basics
def get_all_prompts_online(symbol, with_basics=True, max_news_perweek = 3, weeks_before = 2):
    """Assemble the model prompt for ``symbol`` from current online data.

    The window runs from ``weeks_before`` weeks ago until today. For every
    weekly row the price summary is followed by up to ``max_news_perweek``
    sampled news items; the company profile and, optionally, the latest
    fundamentals are added around it.

    Args:
        symbol: Stock code.
        with_basics: Whether to include the fundamentals of the latest week.
        max_news_perweek: Upper bound of news items quoted per week.
        weeks_before: How many weeks of history to include.

    Returns:
        ``(info, prompt)``: ``info`` is the company/news/fundamentals text;
        ``prompt`` is ``info`` wrapped with the system prompt, the instruction
        markers and the forecasting request for the coming week.
    """
    no_basics = "[金融基本面]:\n\n 无金融基本面记录"

    end_date = get_curday()
    start_date = n_weeks_before(end_date, weeks_before)

    company_prompt, stock = get_company_prompt_new(symbol)
    data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)

    prev_rows = []
    for _, row in data.iterrows():
        head, news, basics = get_prompt_by_row_new(symbol, row)
        prev_rows.append((head, news, basics))

    # One section per week, oldest first: header, then news or a placeholder.
    sections = []
    for head, news, _ in prev_rows:
        sections.append("\n" + head)
        sampled_news = sample_news(news, min(max_news_perweek, len(news)))
        if sampled_news:
            sections.append("\n".join(sampled_news))
        else:
            sections.append("No relative news reported.")
    prompt = "".join(sections)

    next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
    end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
    period = "{}至{}".format(end_date, next_date)

    # Use the latest week's fundamentals; fall back to the placeholder when
    # they are disabled or there is no weekly row to take them from.
    if with_basics and prev_rows:
        basics = prev_rows[-1][2]
    else:
        basics = no_basics

    info = company_prompt + '\n' + prompt + '\n' + basics

    # Split the final answer section into a forecast line and a summary line.
    new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
    prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
        f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST
    return info, prompt
def ask(symbol, weeks_before):
    """Run the forecaster for ``symbol`` and return the context and answer.

    Args:
        symbol: Stock code.
        weeks_before: Number of weeks of history to put into the prompt.

    Returns:
        ``(info, answer)``: the information text from
        ``get_all_prompts_online`` and the decoded generation with everything
        up to and including the last "[/INST]" marker removed.
    """
    # load inference data
    info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before)

    # Tokenize and move every tensor to the model's device.
    encoded = tokenizer(pt, return_tensors='pt')
    inputs = {}
    for name, tensor in encoded.items():
        inputs[name] = tensor.to(model.device)
    print("Inputs loaded onto devices.")

    generated = model.generate(**inputs, use_cache=True)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    # The decoded text echoes the prompt; the greedy DOTALL match drops it,
    # together with the whitespace after the closing marker.
    answer = re.sub(r'.*\[/INST\]\s*', '', decoded, flags=re.DOTALL)
    return info, answer
# Gradio front end: a stock symbol and a history length go in; the assembled
# information text and the model's response come out.
server = gr.Interface(
    fn=ask,
    inputs=[
        gr.Textbox(
            value="600519",
            label="Symbol",
            info="Companys from SZ50 are recommended",
        ),
        gr.Slider(
            label="weeks_before",
            minimum=1,
            maximum=3,
            step=1,
            value=2,
            info="Due to the token length constraint, you are recommended to input with 2",
        ),
    ],
    outputs=[
        gr.Textbox(label="Information"),
        gr.Textbox(label="Response"),
    ],
    title="FinGPT-Forecaster-Chinese",
    description="""This version allows the prediction based on the most current date. We will upgrade it to allow customized date soon.""",
)

# Start serving as soon as the module is executed.
server.launch()