# Dramaturg / app.py
# K00B404's picture
# Update app.py
# 85ecc49 verified
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings
import gradio as gr
import os
from gradio_client import Client
# Keep console output clean: silence transformers logging and progress bars,
# and suppress Python warnings.
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

# Prefer CUDA when torch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
torch.set_default_device(device)

# Checkpoint used for the image-based character analysis.
model_name = 'qnguyen3/nanoLLaVA-1.5'
print(f"Loading model {model_name} on {device}...")

# Load the model and its tokenizer; both are requested with
# trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
print("Model loaded successfully!")

# gradio_client connection used to talk to the test bot.
chatter = "K00B404/transcript_image_generator"
chatbot_client = Client(chatter)
# Prompt text for each supported analysis type. Any other value (including
# "basic_description") falls back to _DEFAULT_ANALYSIS_PROMPT.
_ANALYSIS_PROMPTS = {
    "full_analysis": (
        "Analyze this character as a dramaturg would. Describe their appearance, "
        "potential personality traits, character archetype, suitable roles, and how they might "
        "function within a dramatic narrative. Consider costume, posture, expression, and visual symbolism."
    ),
    "archetype": (
        "Identify the potential character archetype(s) represented in this image. "
        "Consider both classical archetypes (hero, mentor, trickster, etc.) and modern "
        "interpretations. Explain your reasoning based on visual cues."
    ),
    "historical_context": (
        "Analyze this character's appearance in terms of historical context. "
        "Identify the likely time period, cultural influences, and how these elements "
        "would influence the character's role in a dramatic work. Consider costume details, "
        "props, and stylistic elements."
    ),
}
_DEFAULT_ANALYSIS_PROMPT = "Describe this character in detail for dramatic casting purposes."

# Id spliced into input_ids at the position of the '<image>' placeholder.
# NOTE(review): presumably the image-token index the model's remote code
# expects -- confirm against the model card.
_IMAGE_TOKEN_INDEX = -200


def _decode_new_tokens(output_ids, prompt_length):
    """Decode only the tokens generated after the prompt_length prompt tokens."""
    return tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True).strip()


def analyze_character(image_path, analysis_type):
    """
    Analyze a character image for dramaturgical insights.

    Args:
        image_path: Path to the character image. A missing path (None or "")
            is reported as an error string instead of being passed to PIL.
        analysis_type: One of "full_analysis", "archetype" or
            "historical_context"; any other value selects a generic
            casting-description prompt.

    Returns:
        str: The generated character analysis, or a human-readable error
        message if image processing or generation failed. This function
        does not raise for those failures.
    """
    # The UI passes None when no image has been uploaded yet.
    if not image_path:
        return "Error processing image: no image was provided"

    # Load and preprocess the image.
    try:
        image = Image.open(image_path).convert('RGB')
        # Resize image to 256x256
        image = image.resize((256, 256), Image.Resampling.LANCZOS)
        image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
    except Exception as e:
        return f"Error processing image: {str(e)}"

    prompt = _ANALYSIS_PROMPTS.get(analysis_type, _DEFAULT_ANALYSIS_PROMPT)

    # Format input for the model using the tokenizer's chat template.
    messages = [
        {"role": "system", "content": "You are an expert dramaturg with deep knowledge of character analysis, theatrical traditions, and visual storytelling."},
        {"role": "user", "content": f'<image>\n{prompt}'}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize the text on either side of the image placeholder and join the
    # two halves with the image token id.
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(
        text_chunks[0] + [_IMAGE_TOKEN_INDEX] + text_chunks[1],
        dtype=torch.long,
    ).unsqueeze(0)
    prompt_length = input_ids.shape[1]

    # Sampling settings shared by both generation attempts.
    sampling = {
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
    }

    # First attempt: generate with caching disabled to avoid the cache error.
    try:
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            use_cache=False,
            **sampling)
        return _decode_new_tokens(output_ids, prompt_length)
    except Exception as e:
        # Fallback: default caching, explicit eos/pad ids, inference mode.
        try:
            print(f"First generation method failed with: {str(e)}. Trying fallback method...")
            # Compare against None so that a legitimate pad id of 0 is kept.
            pad_token_id = tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = tokenizer.eos_token_id
            with torch.inference_mode():
                output = model.generate(
                    input_ids,
                    images=image_tensor,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=pad_token_id,
                    **sampling
                )
            return _decode_new_tokens(output, prompt_length)
        except Exception as e2:
            return f"Error generating analysis: {str(e)}\nFallback also failed: {str(e2)}\n\nPlease try a different image or check model compatibility."
def chat_with_persona(message, history, system_message, max_tokens, temperature, top_p):
    """Send one message to the chatbot API, with the persona as system prompt.

    ``history`` is accepted but not forwarded to the API call.

    Returns:
        The value returned by the API, or an error string if the call raised.
    """
    request = {
        "message": message,
        "system_message": system_message,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    try:
        # Current message plus the persona go to the "/chat" endpoint.
        reply = chatbot_client.predict(api_name="/chat", **request)
    except Exception as exc:
        return f"Error communicating with the chatbot API: {str(exc)}"
    return reply
# Create Gradio interface
def create_ui():
    """Build the two-tab Gradio app: character analysis and persona test bot.

    The first tab runs analyze_character on an uploaded image. Its
    "Copy to Test Bot" button copies the analysis into the second tab's
    system prompt and switches to that tab, where messages are sent through
    chat_with_persona.

    Returns:
        gr.Blocks: The assembled interface (not yet launched).
    """
    # Explicit tab ids so the tab switch can target the test tab by id.
    analysis_tab_id = "character_analysis"
    test_bot_tab_id = "test_bot"

    def run_analysis(image_path, selected_type):
        # analyze_character returns a single string, while the click handler
        # drives two outputs: the analysis text and the copy button state.
        result = analyze_character(image_path, selected_type)
        # Enable the copy button once we have a result.
        return result, gr.update(interactive=True)

    def copy_to_test(result):
        # Keep the analysis in state, prefill the system prompt with it and
        # switch to the test tab.
        return result, result, gr.update(selected=test_bot_tab_id)

    def respond(message, history, system_message, max_tokens_val, temperature_val, top_p_val):
        # Work on a copy; the chatbot value may be None before the first turn.
        history = list(history or [])
        # Nothing to send: leave the conversation unchanged.
        if not message or not message.strip():
            return "", history
        # Add the user message to history with a pending (empty) reply.
        history.append((message, ""))
        # Get response from the test bot.
        response = chat_with_persona(
            message=message,
            history=history,
            system_message=system_message,
            max_tokens=max_tokens_val,
            temperature=temperature_val,
            top_p=top_p_val
        )
        # Fill in the reply for the turn just added.
        history[-1] = (message, response)
        # Clear the input box and show the updated conversation.
        return "", history

    with gr.Blocks(title="Dramaturg Character Analyzer") as demo:
        # Store the current analysis result for sharing between tabs.
        analysis_result = gr.State("")

        with gr.Tabs() as tabs:
            # First tab: Character analysis
            with gr.TabItem("Character Analysis", id=analysis_tab_id):
                gr.Markdown("# Dramaturg Character Analyzer")
                gr.Markdown("Upload a character image to receive a dramaturgical analysis")
                with gr.Row():
                    with gr.Column():
                        input_image = gr.Image(type="filepath", label="Upload Character Image")
                        analysis_type = gr.Radio(
                            ["full_analysis", "archetype", "historical_context", "basic_description"],
                            label="Analysis Type",
                            value="full_analysis"
                        )
                        analyze_btn = gr.Button("Analyze Character")
                    with gr.Column():
                        output_text = gr.Textbox(label="Character Analysis", lines=20)
                        copy_to_test_btn = gr.Button("Copy to Test Bot", interactive=False)

            # Second tab: Test bot integration
            with gr.TabItem("Test Bot", id=test_bot_tab_id):
                gr.Markdown("# Test Your Character Persona")
                gr.Markdown("The character analysis will be used as the system prompt for the test bot.")
                with gr.Row():
                    with gr.Column():
                        system_prompt = gr.Textbox(label="System Prompt (Character Persona)", lines=10)
                        with gr.Row():
                            max_tokens = gr.Slider(minimum=100, maximum=4000, value=1000, step=100, label="Max Tokens")
                            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
                            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")
                        user_input = gr.Textbox(label="Your message", placeholder="Ask something about the character...")
                        send_btn = gr.Button("Send Message")
                    with gr.Column():
                        chatbot = gr.Chatbot(label="Conversation")

        # Events are wired after both tabs exist because the copy button
        # writes to a component that lives in the second tab.
        analyze_btn.click(
            fn=run_analysis,
            inputs=[input_image, analysis_type],
            outputs=[output_text, copy_to_test_btn]
        )
        copy_to_test_btn.click(
            fn=copy_to_test,
            inputs=[output_text],
            outputs=[analysis_result, system_prompt, tabs]
        )

        chat_inputs = [user_input, chatbot, system_prompt, max_tokens, temperature, top_p]
        chat_outputs = [user_input, chatbot]
        send_btn.click(fn=respond, inputs=chat_inputs, outputs=chat_outputs)
        # Also trigger on pressing Enter in the input box.
        user_input.submit(fn=respond, inputs=chat_inputs, outputs=chat_outputs)

    return demo
# Main function
if __name__ == "__main__":
    demo = create_ui()
    # launch() blocks the main thread while the server runs when executed as a
    # script, so the start-up message is printed first; placed after launch()
    # it would only appear once the server had stopped.
    print("Dramaturg Character Analyzer is now running with Test Bot integration!")
    demo.launch(share=True)