import spaces import gradio as gr from transformers import AutoTokenizer, pipeline import torch import logging import asyncio from functools import partial # ロギング設定 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # モデル定義 classification_model_name = "unitary/toxic-bert" generation_model_name = "distilgpt2" # 軽量なテキスト生成モデル logger.info("Starting model loading...") # 分類モデルのロード logger.info(f"Loading classification model: {classification_model_name}") classification_tokenizer = AutoTokenizer.from_pretrained(classification_model_name) classification_pipeline = pipeline( "text-classification", model=classification_model_name, tokenizer=classification_tokenizer, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto" ) logger.info(f"Classification model loaded successfully: {classification_model_name}") # 生成モデルのロード logger.info(f"Loading generation model: {generation_model_name}") generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name) generation_pipeline = pipeline( "text-generation", model=generation_model_name, tokenizer=generation_tokenizer, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto" ) logger.info(f"Generation model loaded successfully: {generation_model_name}") # 非同期で分類を実行する関数 async def classify_text_async(prompt): logger.info(f"Running classification for: {prompt[:50]}...") # CPUバウンドな処理を非同期実行するためにループの外で実行 loop = asyncio.get_event_loop() classification_result = await loop.run_in_executor( None, lambda: classification_pipeline(prompt) ) logger.info(f"Classification complete: {classification_result}") return classification_result # 非同期で生成を実行する関数 async def generate_text_async(prompt): logger.info(f"Running text generation for: {prompt[:50]}...") loop = asyncio.get_event_loop() generation_result = await loop.run_in_executor( None, lambda: generation_pipeline( prompt, max_new_tokens=50, do_sample=True, temperature=0.7, num_return_sequences=1 ) ) generated_text = generation_result[0]["generated_text"] logger.info(f"Text generation complete, generated: {len(generated_text)} chars") return generated_text # GPUを利用する非同期推論関数 @spaces.GPU(duration=120) async def process_text_async(prompt): logger.info(f"Processing input asynchronously: {prompt[:50]}...") # 両方のタスクを並行して実行 classification_task = classify_text_async(prompt) generation_task = generate_text_async(prompt) # 両方のタスクが完了するのを待つ classification_result, generated_text = await asyncio.gather( classification_task, generation_task ) # 結果を組み合わせて返す combined_result = f"分類結果: {classification_result}\n\n生成されたテキスト: {generated_text}" return combined_result # Gradioは非同期関数にも対応しているので、そのまま渡す demo = gr.Interface( fn=process_text_async, # 非同期関数を使用 inputs=gr.Textbox(lines=3, label="入力テキスト"), outputs=gr.Textbox(label="処理結果", lines=8), title="テキスト分類 & 生成デモ (非同期版)", description="入力テキストに対して分類と生成を非同期で並行実行します。" ) # アプリの起動 logger.info("Starting application...") demo.launch()