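"""Streamlit demo that generates YouTube-style comments for a video.

The user pastes a YouTube URL; the app fetches the video title through the
oEmbed API and feeds it to a locally stored fine-tuned GPT-2 model that
samples author/comment pairs. Run with: streamlit run app.py
"""
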
import json
import os
import sys
import urllib.parse
import urllib.request

import numpy as np
import streamlit as st
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
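

# The parsing in generate() below assumes the fine-tuned model emits sequences
# shaped like the prompt it was trained on, roughly:
#   <|startoftext|><|titlestart|>TITLE<|titleend|><|authornamebegin|>AUTHOR<|authornameend|>COMMENT<|endoftext|>
# (this exact format is an assumption inferred from the special tokens
# registered in load_model(), not stated anywhere else in the code).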
def generate(tokenizer, model, text, features):
    # Build the prompt from the video title and sample comments one at a time.
    prompt = "<|startoftext|><|titlestart|>{}<|titleend|><|authornamebegin|>".format(text)
    generated = tokenizer(prompt, return_tensors="pt").input_ids
    count = 0
    while count < features['num']:
        sample_outputs = model.generate(
            generated, do_sample=True, top_k=50,
            max_length=features['max_length'],
            top_p=features['top_p'] / 100.0,  # the slider is in percent, generate() expects [0, 1]
            temperature=features['t'] / 100.0,
            num_return_sequences=1,
        )
        decoded = tokenizer.decode(sample_outputs[0], skip_special_tokens=False)
        print(decoded, file=sys.stderr)
        # Discard samples that do not contain a complete author/comment pair.
        if '<|authornamebegin|>' not in decoded:
            continue
        raw = decoded.split('<|authornamebegin|>')[-1]
        if '<|authornameend|>' not in raw:
            continue
        author, comment = raw.split('<|authornameend|>', 1)
        count += 1
        st.markdown('**' + author.strip() + '**: '
                    + comment.replace('<|endoftext|>', '').replace('<|pad|>', '').strip())


def load_model():
    # Tokenizer with the special tokens used to delimit titles, author names and comments.
    additional_special_tokens = ['<|titlestart|>', '<|titleend|>', '<|authornamebegin|>', '<|authornameend|>']
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium', bos_token='<|startoftext|>',
                                              eos_token='<|endoftext|>', pad_token='<|pad|>',
                                              additional_special_tokens=additional_special_tokens)
    # Build the model from the local config and load the fine-tuned weights on CPU.
    config = GPT2Config.from_json_file('./config.json')
    model = GPT2LMHeadModel(config)
    state_dict = torch.load('./pytorch_model.bin', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # inference only: disable dropout before sampling
    return tokenizer, model
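

# Note: load_model() reloads the full checkpoint on every Streamlit rerun.
# If that becomes slow, the call can optionally be wrapped in Streamlit's
# resource cache (st.cache_resource in recent releases, st.cache in older ones).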
def main():
    tokenizer, model = load_model()
    st.title("YouTube comment generation project")
    st.header('YouTube comment generator')
    st.sidebar.title("Features")
    seed = 27834096
    torch.manual_seed(seed)  # fix the sampling seed so reruns with the same inputs are reproducible
    # Read the user-controlled generation parameters from the sidebar sliders.
    features = {
        "num": st.sidebar.slider("Number of comments", 0, 20, 1, 1),
        "t": st.sidebar.slider("Temperature", 0, 300, 180, 1),
        "top_p": st.sidebar.slider("Top-p", 0, 100, 95, 5),
        "max_length": st.sidebar.slider("Maximum comment length", 0, 300, 100, 5),
    }
st.sidebar.title("Note")
st.sidebar.write(
"""
Изменяя значения, можно получить различные выводы модели
"""
)
st.sidebar.write(
"""
Значение температуры делится на 100
"""
)
st.sidebar.caption(f"Streamlit version `{st.__version__}`")
    with st.form(key='my_form'):
        url = st.text_input('Enter the URL of a YouTube video')
        st.form_submit_button('Done!')

    if url:
        # Look up the video title through YouTube's oEmbed endpoint.
        params = {"format": "json", "url": url}
        base_url = "https://www.youtube.com/oembed"
        query_string = urllib.parse.urlencode(params)
        base_url = base_url + "?" + query_string
        with urllib.request.urlopen(base_url) as response:
            response_text = response.read()
        data = json.loads(response_text.decode())
        st.write('Video Title: ' + data['title'])
        st.video(url)
        generate(tokenizer, model, data['title'], features)


if __name__ == "__main__":
    main()