"""Gradio app that generates a two-speaker podcast script as JSON using a locally loaded Gemma model."""
# Standard library
import asyncio
import json
import mimetypes
import os
import time
import uuid
from typing import Dict, List

# Third-party
import aiofiles
import edge_tts
import gradio as gr
import torch
from pydub import AudioSegment
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model used for podcast script generation.
MODEL_ID = "tabularisai/german-gemma-3-1b-it"

# Upload limit for user-provided files.
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
class PodcastGenerator:
    """Generates a two-speaker podcast script (as a JSON dict) with a local HF causal LM."""

    def __init__(self):
        # Load tokenizer and model once per instance. bfloat16 on GPU,
        # float32 on CPU; device_map="auto" lets accelerate place weights.
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
        ).eval()

    async def generate_script(
        self,
        prompt: str,
        language: str,
        api_key: str,
        file_obj=None,
        progress=None,
    ) -> Dict:
        """Generate a podcast script from *prompt* and/or an uploaded file.

        Args:
            prompt: Free-form user input describing the podcast topic.
            language: Target language, or "Auto Detect" to mirror the input.
            api_key: Accepted for interface compatibility; unused — generation
                runs on the local model.
            file_obj: Optional uploaded file (read only for size validation).
            progress: Optional callable ``progress(fraction, message)``.

        Returns:
            The parsed JSON script (dict with "topic" and "podcast" keys,
            per the example structure in the system prompt).

        Raises:
            Exception: If generation fails or no JSON object can be parsed
                from the model output.
        """
        example = """
{
    "topic": "AGI",
    "podcast": [
        {
            "speaker": 2,
            "line": "So, AGI, huh? Seems like everyone's talking about it these days."
        },
        {
            "speaker": 1,
            "line": "Yeah, it's definitely having a moment, isn't it?"
        }
    ]
}
"""
        if language == "Auto Detect":
            language_instruction = (
                "- The podcast MUST be in the same language as the user input."
            )
        else:
            language_instruction = f"- The podcast MUST be in {language} language"
        system_prompt = f"""
You are a professional podcast generator. Your task is to generate a professional podcast script based on the user input.
{language_instruction}
- The podcast should have 2 speakers.
- The podcast should be long.
- Do not use names for the speakers.
- The podcast should be interesting, lively, and engaging, and hook the listener from the start.
- The input text might be disorganized or unformatted, originating from sources like PDFs or text files. Ignore any formatting inconsistencies or irrelevant details; your task is to distill the essential points, identify key definitions, and highlight intriguing facts that would be suitable for discussion in a podcast.
- The script must be in JSON format.
Follow this example structure:
{example}
"""
        if prompt and file_obj:
            user_prompt = (
                f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
            )
        elif prompt:
            user_prompt = (
                f"Please generate a podcast script based on the following user input:\n{prompt}"
            )
        else:
            user_prompt = "Please generate a podcast script based on the uploaded file."

        # If a file is provided we still read it for completeness (size check);
        # the content is not currently fed to the local model.
        if file_obj:
            _ = await self._read_file_bytes(file_obj)

        if progress:
            progress(0.3, "Generating podcast script...")

        inputs = self.tokenizer(
            f"{system_prompt}\n\n{user_prompt}", return_tensors="pt"
        ).to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        try:
            output = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                do_sample=True,  # FIX: without sampling, temperature is silently ignored
                temperature=1.0,
            )
            # FIX: decode only the newly generated tokens; decoding output[0]
            # wholesale would echo the entire prompt and break JSON parsing.
            response_text = self.tokenizer.decode(
                output[0][prompt_len:], skip_special_tokens=True
            )
        except Exception as e:
            raise Exception(f"Failed to generate podcast script: {e}")

        print(f"Generated podcast script:\n{response_text}")
        if progress:
            progress(0.4, "Script generated successfully!")

        # Models often wrap JSON in prose or code fences; extract the
        # outermost {...} span before parsing.
        start = response_text.find("{")
        end = response_text.rfind("}")
        if start == -1 or end <= start:
            raise Exception("Failed to generate podcast script: no JSON object in model output")
        try:
            return json.loads(response_text[start : end + 1])
        except json.JSONDecodeError as e:
            raise Exception(f"Failed to generate podcast script: invalid JSON ({e})")

    async def _read_file_bytes(self, file_obj) -> bytes:
        """Return the file's raw bytes, enforcing the MAX_FILE_SIZE_MB limit.

        Raises:
            Exception: If the file exceeds the size limit.
        """
        # Gradio may hand us either a file-like object (with .size/.read)
        # or a NamedString-style object exposing only .name (a path).
        if hasattr(file_obj, "size"):
            file_size = file_obj.size
        else:
            file_size = os.path.getsize(file_obj.name)
        if file_size > MAX_FILE_SIZE_BYTES:
            raise Exception(
                f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file."
            )
        if hasattr(file_obj, "read"):
            return file_obj.read()
        else:
            async with aiofiles.open(file_obj.name, "rb") as f:
                return await f.read()
| def _get_mime_type(filename: str) -> str: | |
| ext = os.path.splitext(filename)[1].lower() | |
| if ext == ".pdf": | |
| return "application/pdf" | |
| elif ext == ".txt": | |
| return "text/plain" | |
| else: | |
| mime_type, _ = mimetypes.guess_type(filename) | |
| return mime_type or "application/octet-stream" | |
| # Re-add UI definition for Gradio | |
| async def generate_interface(prompt, language, api_key, file): | |
| gen = PodcastGenerator() | |
| result = await gen.generate_script(prompt, language, api_key, file) | |
| return json.dumps(result, indent=2) | |
# Gradio UI wiring: maps the four form inputs onto generate_interface.
interface = gr.Interface(
    fn=generate_interface,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Radio(
            ["English", "German", "Auto Detect"],
            label="Language",
            value="Auto Detect",
        ),
        gr.Textbox(label="API Key", type="password"),
        gr.File(label="Upload File (optional)"),
    ],
    outputs=gr.Textbox(label="Generated Podcast JSON"),
    title="Podcast Generator using Gemma",
    description="Generate a lively podcast script from your input text or uploaded file using the tabularisai/german-gemma-3-1b-it model.",
)

if __name__ == "__main__":
    interface.launch()