import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import streamlit as st


@st.cache_resource
def _load_model(model_name, device_name):
    """Load the tokenizer and model for *model_name*, placed on *device_name*.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the weights would be reloaded on every button click.
    ``st.cache_resource`` keeps one shared instance across reruns.

    Args:
        model_name: Hugging Face model id or local path for ``from_pretrained``.
        device_name: Device string such as ``"cpu"`` or ``"cuda"`` (a plain
            string so it is a simple, hashable cache key).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device_name)
    return tokenizer, model


def generate_blog(title, model_name='gpt2', max_length=500):
    """Generate a blog post for *title* with a GPT-2 language model.

    Args:
        title: Blog title inserted into the generation prompt.
        model_name: Hugging Face model id or local path for ``from_pretrained``.
        max_length: Maximum total sequence length in tokens, prompt included.
            Clamped to the model's context window (``config.n_positions``).

    Returns:
        The generated continuation as text, without the instruction prompt.
    """
    # Check if a GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.write(f"Using device: {device}")

    tokenizer, model = _load_model(model_name, str(device))

    prompt = f"Write a blog post based on this Title: {title}"

    # Calling the tokenizer directly (instead of ``encode``) also yields the
    # attention mask, so generate() does not have to infer it.
    encoded = tokenizer(prompt, return_tensors='pt').to(device)
    prompt_len = encoded["input_ids"].shape[-1]

    # GPT-2 only has position embeddings for n_positions tokens; requesting a
    # longer sequence makes generation fail instead of stopping gracefully.
    max_length = min(max_length, model.config.n_positions)

    output = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # GPT-2 defines no pad token; reuse EOS explicitly rather than letting
        # generate() guess and emit a warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens so the instruction prompt is not
    # echoed back as the start of the blog post.
    return tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True).strip()


def main():
    """Render the Streamlit UI: title input, generate button, and result."""
    st.title("AI Blog Writer")
    st.write("Enter a blog title, and the AI will generate a blog post for you!")

    # Whitespace-only input is treated the same as empty input.
    title = st.text_input("Enter the blog title:").strip()

    if st.button("Generate Blog"):
        if title:
            with st.spinner("Generating blog post..."):
                blog_post = generate_blog(title)
            st.subheader("Generated Blog Post")
            st.write(blog_post)
        else:
            st.warning("Please enter a blog title.")


# `streamlit run` executes this script as ``__main__`` on every rerun.
if __name__ == "__main__":
    main()