|
import gradio as gr |
|
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, TextStreamer |
|
from peft import PeftModel |
|
import re |
|
import os |
|
import akshare as ak |
|
import pandas as pd |
|
import random |
|
import json |
|
import requests |
|
import math |
|
from datetime import date |
|
from datetime import date, datetime, timedelta |
|
|
|
|
|
# Hugging Face access token, read from the TOKEN environment variable.
# A missing variable raises KeyError at import time.
access_token = os.environ["TOKEN"]

# Base checkpoint id and the PEFT adapter id applied on top of it.
# NOTE: ``model`` first holds the checkpoint id string and is rebound to the
# loaded model object further down.
model = "meta-llama/Llama-2-7b-chat-hf"
peft_model = "FinGPT/fingpt-forecaster_sz50_llama2-7B_lora"

tokenizer = AutoTokenizer.from_pretrained(
    model,
    token=access_token,
    trust_remote_code=True,
)
# Reuse the EOS token as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Streams generated text while ``model.generate`` runs (used by ``ask``).
streamer = TextStreamer(tokenizer)

model = AutoModelForCausalLM.from_pretrained(
    model,
    trust_remote_code=True,
    token=access_token,
    device_map="cuda",
    load_in_8bit=True,
    offload_folder="offload/",
)
model = PeftModel.from_pretrained(
    model,
    peft_model,
    offload_folder="offload/",
)

# Switch to evaluation mode for inference.
model = model.eval()
|
|
|
|
|
|
|
|
|
# Instruction and system delimiters wrapped around the prompt that
# ``get_all_prompts_online`` assembles.
B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# System prompt (Chinese): describes the analyst role and the expected answer
# layout. ``get_all_prompts_online`` rewrites the final ":\n..." section, so the
# exact characters matter.
SYSTEM_PROMPT = (
    "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。"
    "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
)
|
|
|
|
|
|
|
|
|
def get_curday():
    """Return today's date (as given by ``date.today()``) as a ``YYYYMMDD`` string."""
    today = date.today()
    return f"{today:%Y%m%d}"
|
|
|
def n_weeks_before(date_string, n, format="%Y%m%d"):
    """Shift a ``YYYYMMDD`` date back by ``n`` weeks and format the result.

    Args:
        date_string: Date to start from; always parsed as ``%Y%m%d``.
        n: Number of weeks to subtract. A negative value moves forward in time.
        format: ``strftime`` pattern applied to the shifted date only.

    Returns:
        The shifted date rendered with ``format``.
    """
    anchor = datetime.strptime(date_string, "%Y%m%d")
    shifted = anchor - timedelta(weeks=n)
    return shifted.strftime(format)
|
|
|
def check_news_quality(n, last_n, week_end_date, repeat_rate=0.6):
    """Decide whether news item ``n`` should be kept, given the previous item.

    An item is rejected when its content is empty, starts with a digit, is the
    string ``'nan'``, or when the first 8 characters of its ``发布时间`` sort after
    ``week_end_date`` (dashes removed). It is also rejected when it looks like a
    repeat of ``last_n``: the first 20 characters of both contents share at
    least ``20 * repeat_rate`` distinct characters, or the fraction of shared
    title characters (relative to the previous title's length) exceeds
    ``repeat_rate``.

    Args:
        n: Mapping with keys ``新闻内容``, ``新闻标题`` and ``发布时间``.
        last_n: The preceding news item, same keys as ``n``.
        week_end_date: Week end date string; ``'-'`` characters are stripped
            before comparison.
        repeat_rate: Similarity threshold in ``[0, 1]``.

    Returns:
        ``True`` if the item should be kept, ``False`` otherwise.

    Raises:
        Exception: ``"Check Error"`` when a field has an unexpected type
            (the original ``TypeError`` is attached as the cause).
    """
    try:
        content = str(n['新闻内容'])
        # Empty content used to crash on ``content[0]``; treat it as low quality.
        if (
            not content
            or content[0].isdigit()
            or content == 'nan'
            or n['发布时间'][:8] > week_end_date.replace('-', '')
        ):
            return False

        # Nothing to compare against when the previous item has no content.
        if str(last_n['新闻内容']) == 'nan':
            return True

        shared_content = set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])
        if len(shared_content) >= 20 * repeat_rate:
            return False

        # An empty previous title cannot be a duplicate; skipping the ratio
        # also avoids a division by zero.
        last_title_len = len(last_n['新闻标题'])
        if last_title_len:
            shared_title = set(n['新闻标题']) & set(last_n['新闻标题'])
            if len(shared_title) / last_title_len > repeat_rate:
                return False

        return True
    except TypeError as err:
        # Dump both items to help diagnose malformed records.
        print(n)
        print(last_n)
        raise Exception("Check Error") from err
|
|
|
def sample_news(news, k=5):
    """Randomly pick up to ``k`` items from ``news``, keeping their original order.

    Args:
        news: Sequence of news entries.
        k: Maximum number of entries to return. Values larger than
            ``len(news)`` (or negative) are clamped instead of raising
            ``ValueError``.

    Returns:
        A new list with the sampled entries in their original relative order.
    """
    # random.sample raises ValueError when asked for more items than exist.
    k = max(0, min(k, len(news)))
    chosen = random.sample(range(len(news)), k)
    return [news[i] for i in sorted(chosen)]
|
|
|
def return_transform(ret):
    """Bucket a fractional return into a short Chinese label.

    The absolute return in percent is rounded up to an integer. Zero gives
    ``"平"``; otherwise the label is ``"涨"`` (``ret >= 0``) or ``"跌"`` followed
    by the integer, or by ``"5+"`` once it exceeds 5.
    """
    direction = '涨' if ret >= 0 else '跌'
    percent = math.ceil(abs(100 * ret))

    if percent == 0:
        return "平"
    if percent <= 5:
        return direction + str(percent)
    return direction + '5+'
|
|
|
def map_return_label(return_lb):
    """Expand a compact label from ``return_transform`` into a readable phrase.

    For example ``'涨2'`` becomes ``'上涨1-2%'`` and ``'跌5+'`` becomes
    ``'下跌超过5%'``.
    """
    # Order matters: each digit is replaced before a later replacement
    # re-introduces that digit inside a range such as "2-3%".
    ordered_replacements = (
        ('涨', '上涨'),
        ('跌', '下跌'),
        ('平', '股价持平'),
        ('1', '0-1%'),
        ('2', '1-2%'),
        ('3', '2-3%'),
        ('4', '3-4%'),
    )
    label = return_lb
    for old, new in ordered_replacements:
        label = label.replace(old, new)

    if label.endswith('+'):
        return label.replace('5+', '超过5%')
    return label.replace('5', '4-5%')
|
|
|
|
|
|
|
def stock_news_em(symbol: str = "300059", page=1) -> pd.DataFrame:
    """Fetch one page of Eastmoney search results for ``symbol``.

    Args:
        symbol: Search keyword (a stock code in this application).
        page: 1-based result page index sent as ``pageIndex``.

    Returns:
        DataFrame with columns ``关键词``, ``新闻标题``, ``新闻内容``, ``发布时间``,
        ``文章来源`` and ``新闻链接``. ``<em>`` highlight markup is removed from
        title and content; ideographic spaces are removed from the content and
        CRLF sequences are replaced with a space.

    Raises:
        KeyError: If the payload lacks ``result``/``cmsArticleWebOld`` or the
            expected columns (e.g. an empty result page). ``get_news`` relies
            on this to stop paging.
        ValueError: If the response is not a ``callback(...)`` JSONP wrapper or
            does not contain valid JSON.
    """
    url = "https://search-api-web.eastmoney.com/search/jsonp"
    # Serialise the query with json.dumps so a keyword containing quotes or
    # backslashes cannot corrupt the JSON payload. Compact separators keep the
    # output identical to the previous hand-built string for ordinary symbols.
    query = {
        "uid": "",
        "keyword": str(symbol),
        "type": ["cmsArticleWebOld"],
        "client": "web",
        "clientType": "web",
        "clientVersion": "curr",
        "param": {
            "cmsArticleWebOld": {
                "searchScope": "default",
                "sort": "default",
                "pageIndex": page,
                "pageSize": 100,
                "preTag": "<em>",
                "postTag": "</em>",
            }
        },
    }
    params = {
        "cb": "jQuery3510875346244069884_1668256937995",
        "param": json.dumps(query, ensure_ascii=False, separators=(",", ":")),
        "_": "1668256937996",
    }
    # A timeout keeps a stalled connection from blocking the request forever.
    r = requests.get(url, params=params, timeout=30)
    data_text = r.text

    # The body is JSONP: ``callback({...})``. Slice between the first "(" and
    # the last ")"; ``str.strip(chars)`` removes a *set of characters*, not a
    # prefix, and only worked by coincidence.
    start = data_text.find("(")
    end = data_text.rfind(")")
    if start == -1 or end <= start:
        raise ValueError("Unexpected JSONP response from Eastmoney search API")
    data_json = json.loads(data_text[start + 1:end])

    temp_df = pd.DataFrame(data_json["result"]["cmsArticleWebOld"])
    temp_df.rename(
        columns={
            "date": "发布时间",
            "mediaName": "文章来源",
            "code": "-",
            "title": "新闻标题",
            "content": "新闻内容",
            "url": "新闻链接",
            "image": "-",
        },
        inplace=True,
    )
    temp_df["关键词"] = symbol
    # ``.copy()`` so the clean-up below writes to an independent frame rather
    # than to a slice of ``temp_df``.
    temp_df = temp_df[
        [
            "关键词",
            "新闻标题",
            "新闻内容",
            "发布时间",
            "文章来源",
            "新闻链接",
        ]
    ].copy()

    # Strip search-highlight markup; parenthesised forms first, then bare tags.
    highlight_markers = ("(<em>", "</em>)", "<em>", "</em>")
    for column in ("新闻标题", "新闻内容"):
        cleaned = temp_df[column]
        for marker in highlight_markers:
            cleaned = cleaned.str.replace(marker, "", regex=False)
        temp_df[column] = cleaned

    temp_df["新闻内容"] = (
        temp_df["新闻内容"]
        .str.replace("\u3000", "", regex=False)
        .str.replace("\r\n", " ", regex=False)
    )
    return temp_df
|
|
|
def get_news(symbol, max_page=3):
    """Collect several pages of news for ``symbol`` into one DataFrame.

    Pages ``1 .. max_page - 1`` are requested (the upper bound is exclusive).
    Paging stops early when ``stock_news_em`` raises ``KeyError``.

    Args:
        symbol: Keyword passed to ``stock_news_em``.
        max_page: Exclusive upper bound of the page index.

    Returns:
        The concatenated pages with a fresh integer index, or an empty
        DataFrame with the expected news columns when no page was obtained.
    """
    df_list = []
    for page in range(1, max_page):
        try:
            df_list.append(stock_news_em(symbol, page))
        except KeyError:
            # The old message concatenated str and int, raising TypeError
            # inside the handler. ``page - 1`` pages succeeded before this one.
            print(f"{page - 1} pages obtained for symbol: {symbol}")
            break

    if not df_list:
        # pd.concat raises ValueError on an empty list; return an empty frame
        # with the usual columns so callers can still select them.
        return pd.DataFrame(
            columns=["关键词", "新闻标题", "新闻内容", "发布时间", "文章来源", "新闻链接"]
        )

    return pd.concat(df_list, ignore_index=True)
|
|
|
def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
    """Build a table of week-over-week closing prices and returns.

    Daily history is requested from ``ak.stock_zh_a_hist`` and the ``收盘``
    column is resampled to weekly frequency with forward fill. Each output row
    pairs one weekly point with the next one.

    Args:
        symbol: Stock code passed to akshare.
        start_date: Start date passed to akshare, also parseable by pandas.
        end_date: End date passed to akshare; also used to cap the last
            settlement date.
        adjust: Price adjustment mode passed to akshare.

    Returns:
        DataFrame with columns ``起始日期``, ``起始价``, ``结算日期``, ``结算价``,
        ``周收益`` and ``简化周收益``. The frame is empty when fewer than two
        weekly points are available.
    """
    return_data = ak.stock_zh_a_hist(
        symbol=symbol,
        period="daily",
        start_date=start_date,
        end_date=end_date,
        adjust=adjust,
    )

    return_data["日期"] = pd.to_datetime(return_data["日期"])
    return_data.set_index("日期", inplace=True)

    weekly_close = return_data["收盘"].resample("W").ffill()
    weekly_returns = weekly_close.pct_change()[1:]
    weekly_start_prices = weekly_close[:-1]
    weekly_end_prices = weekly_close[1:]
    weekly_data = pd.DataFrame({
        '起始日期': weekly_start_prices.index,
        '起始价': weekly_start_prices.values,
        '结算日期': weekly_end_prices.index,
        '结算价': weekly_end_prices.values,
        '周收益': weekly_returns.values,
    })
    weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)

    # Cap the last settlement date at ``end_date``. Guard against an empty
    # frame (previously an IndexError) and address the column by name instead
    # of by position.
    if not weekly_data.empty:
        last_row = weekly_data.index[-1]
        end_ts = pd.to_datetime(end_date)
        if weekly_data.at[last_row, "结算日期"] > end_ts:
            weekly_data.at[last_row, "结算日期"] = end_ts

    return weekly_data
|
|
|
def get_basic(symbol, data):
    """Attach a JSON string of quarterly financial indicators to every row.

    Single-quarter figures come from ``ak.stock_financial_abstract_ths``. For
    each row of ``data`` the first report whose ``报告期`` string sorts before the
    row's ``结算日期`` is serialised into the new ``基本面`` column; ``"{}"`` is
    stored when nothing matches.

    Args:
        symbol: Stock code passed to akshare.
        data: Weekly DataFrame with a datetime-like ``结算日期`` column. It is
            modified in place.

    Returns:
        The same ``data`` object with the ``基本面`` column added.
    """
    key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']

    quarterly = ak.stock_financial_abstract_ths(symbol=symbol, indicator="按单季度")
    rows_by_index = quarterly.to_dict("index")
    # Keep only the selected indicators and drop falsy values.
    basic_fin_list = [
        {
            key: val
            for key, val in rows_by_index[i].items()
            if (key in key_financials) and val
        }
        for i in range(len(rows_by_index))
    ]

    matched_basic_fin = []
    for _, row in data.iterrows():
        newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
        # NOTE(review): the first match wins, so this is the most recent
        # report only if the upstream table is ordered newest-first - confirm.
        matched_basic = next(
            (basic for basic in basic_fin_list if basic["报告期"] < newsweek_enddate),
            {},
        )
        matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))

    data['基本面'] = matched_basic_fin
    return data
|
|
|
|
|
|
|
def cur_financial_data(symbol, start_date, end_date, with_basics=True):
    """Assemble weekly returns, per-week news and fundamentals for ``symbol``.

    Args:
        symbol: Stock code.
        start_date: Start of the price window, passed to ``get_cur_return``.
        end_date: End of the price window, passed to ``get_cur_return``.
        with_basics: When true, ``get_basic`` fills the ``基本面`` column;
            otherwise that column is filled with ``"{}"``.

    Returns:
        The weekly DataFrame from ``get_cur_return`` extended with a ``新闻``
        column (JSON list of ``发布时间``/``新闻标题``/``新闻内容`` dicts) and a
        ``基本面`` column (JSON object string).
    """
    data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)

    news_df = get_news(symbol=symbol)
    news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
    news_df.sort_values(by=["发布时间"], inplace=True)

    news_list = []
    for _, row in data.iterrows():
        week_start_date = row['起始日期'].strftime('%Y-%m-%d')
        week_end_date = row['结算日期'].strftime('%Y-%m-%d')
        print(symbol, ': ', week_start_date, ' - ', week_end_date)

        # Both bounds are exclusive, so news dated exactly on a boundary is
        # left out.
        in_week = (news_df["发布时间"] > week_start_date) & (news_df["发布时间"] < week_end_date)
        weekly_news = [
            {
                "发布时间": n["发布时间"].strftime('%Y%m%d'),
                "新闻标题": n['新闻标题'],
                "新闻内容": n['新闻内容'],
            }
            for _, n in news_df.loc[in_week].iterrows()
        ]
        news_list.append(json.dumps(weekly_news, ensure_ascii=False))

    data["新闻"] = news_list

    if with_basics:
        data = get_basic(symbol=symbol, data=data)
    else:
        # Previously this branch overwrote the ``新闻`` column and never
        # created ``基本面``, which ``get_prompt_by_row_new`` reads.
        data['基本面'] = [json.dumps({})] * len(data)

    return data
|
|
|
|
|
|
|
def get_company_prompt_new(symbol):
    """Build the company-introduction paragraph for ``symbol``.

    The profile is read from ``ak.stock_individual_info_em`` and converted to
    a dict from its two-column values; the listing date is reformatted as
    ``YYYY年MM月DD日`` before the template is filled.

    Args:
        symbol: Stock code passed to akshare.

    Returns:
        Tuple ``(formatted_profile, stockname)``, where ``stockname`` is the
        profile's ``股票简称`` entry.

    Raises:
        RuntimeError: If the profile request fails. Previously the error was
            swallowed and the function crashed on an unbound local.
    """
    try:
        company_profile = dict(ak.stock_individual_info_em(symbol).values)
    except Exception as err:
        message = "Company Info Request Time Out! Please wait and retry."
        print(message)
        raise RuntimeError(message) from err

    company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")

    template = (
        "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。"
        "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"
    )

    formatted_profile = template.format(**company_profile)
    stockname = company_profile['股票简称']
    return formatted_profile, stockname
|
|
|
def get_prompt_by_row_new(stock, row):
    """Turn one weekly row into prompt fragments.

    Args:
        stock: Name or code inserted into the generated text.
        row: Mapping/Series with ``起始日期``, ``结算日期``, ``起始价``, ``结算价``,
            ``简化周收益``, ``新闻`` (JSON string) and optionally ``基本面``
            (JSON string).

    Returns:
        Tuple ``(head, filtered_news, basics)``: the price-movement sentence,
        the list of formatted news entries that passed the quality filter, and
        the fundamentals paragraph.
    """
    week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
    week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
    term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
    chg = map_return_label(row['简化周收益'])
    head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
        week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)

    news = json.loads(row["新闻"])

    filtered_news = []
    for position, n in enumerate(news):
        if position == 0:
            # The first item has no predecessor, so only the basic checks
            # apply. Empty content is rejected instead of raising IndexError.
            content = str(n['新闻内容'])
            keep = (
                bool(content)
                and not content[0].isdigit()
                and content != 'nan'
                and n['发布时间'][:8] <= week_end_date.replace('-', '')
            )
        else:
            # Every later item is also compared with its immediate predecessor.
            keep = check_news_quality(
                n, last_n=news[position - 1], week_end_date=week_end_date, repeat_rate=0.5)
        if keep:
            filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))

    # Tolerate rows without a ``基本面`` column by treating them as "no record".
    basics = json.loads(row['基本面']) if '基本面' in row else {}
    if basics:
        # NOTE(review): the 'period' filter matches no key in the Chinese
        # data, so ``报告期`` is listed again below. Left unchanged because the
        # prompt layout may need to match what the model was trained on.
        basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
            stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    return head, filtered_news, basics
|
|
|
def get_all_prompts_online(symbol, with_basics=True, max_news_perweek=3, weeks_before=2):
    """Build the information block and the full model prompt for ``symbol``.

    Data is collected for the window from ``weeks_before`` weeks ago until
    today. Each weekly row contributes its price sentence and up to
    ``max_news_perweek`` randomly sampled news entries.

    Args:
        symbol: Stock code.
        with_basics: Whether to include the latest fundamentals paragraph.
        max_news_perweek: Maximum number of news entries sampled per week.
        weeks_before: Number of weeks of history to include.

    Returns:
        Tuple ``(info, prompt)``: the human-readable context, and the same
        context wrapped in the instruction/system markers with the task
        description.
    """
    end_date = get_curday()
    start_date = n_weeks_before(end_date, weeks_before)

    company_prompt, stock = get_company_prompt_new(symbol)
    data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)

    # One (head, news, basics) tuple per weekly row, oldest first.
    prev_rows = [get_prompt_by_row_new(symbol, row) for _, row in data.iterrows()]

    prompt = ""
    for head, news, _ in prev_rows:
        prompt += "\n" + head
        sampled_news = sample_news(news, min(max_news_perweek, len(news)))
        if sampled_news:
            prompt += "\n".join(sampled_news)
        else:
            prompt += "No relative news reported."

    next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
    end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
    period = "{}至{}".format(end_date, next_date)

    # Use the most recent week's fundamentals. Fall back to the "no record"
    # text when disabled or when there are no weekly rows at all (the latter
    # used to raise IndexError on ``prev_rows[-1]``).
    if with_basics and prev_rows:
        basics = prev_rows[-1][2]
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    info = company_prompt + '\n' + prompt + '\n' + basics

    # Expand the final answer section of the system prompt into two parts.
    new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
    task = (
        f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。"
        f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。"
    )
    prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + task + E_INST

    return info, prompt
|
|
|
|
|
def ask(symbol, weeks_before):
    """Gradio handler: build the prompt for ``symbol`` and generate a forecast.

    Args:
        symbol: Stock code from the UI.
        weeks_before: Number of weeks of history to include in the prompt.

    Returns:
        Tuple ``(info, output_cur)``: the context shown to the user and the
        generated text with everything up to the last ``[/INST]`` removed.
    """
    info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before)

    # Tokenise the prompt and move every tensor to the model's device.
    encoded = tokenizer(pt, return_tensors='pt')
    inputs = {}
    for key, value in encoded.items():
        inputs[key] = value.to(model.device)
    print("Inputs loaded onto devices.")

    res = model.generate(
        **inputs,
        use_cache=True,
        max_length=4096,
        streamer=streamer,
    )
    output = tokenizer.decode(res[0], skip_special_tokens=True)

    # Greedy DOTALL match: drop the echoed prompt through the final [/INST].
    prompt_echo = re.compile(r'.*\[/INST\]\s*', flags=re.DOTALL)
    output_cur = prompt_echo.sub('', output)
    return info, output_cur
|
|
|
# Input widgets, in the positional order ``ask(symbol, weeks_before)`` expects.
_inputs = [
    gr.Textbox(
        label="Symbol",
        value="600519",
        info="Companys from SZ50 are recommended",
    ),
    gr.Slider(
        minimum=1,
        maximum=3,
        value=2,
        step=1,
        label="weeks_before",
        info="Due to the token length constraint, you are recommended to input with 2",
    ),
]

# Output widgets, matching the ``(info, output_cur)`` tuple returned by ``ask``.
_outputs = [
    gr.Textbox(label="Information"),
    gr.Textbox(label="Response"),
]

server = gr.Interface(
    ask,
    inputs=_inputs,
    outputs=_outputs,
    title="FinGPT-Forecaster-Chinese",
    description="""This version allows the prediction based on the most current date. We will upgrade it to allow customized date soon.""",
)

# Start the web UI as soon as the module is executed.
server.launch()