K00B404 commited on
Commit
58ac2a8
·
verified ·
1 Parent(s): 2445f7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -85,18 +85,38 @@ def analyze_character(image_path, analysis_type):
85
 
86
  # Generate response
87
  try:
 
88
  output_ids = model.generate(
89
  input_ids,
90
  images=image_tensor,
91
  max_new_tokens=1024,
92
  temperature=0.7,
93
  top_p=0.9,
94
- use_cache=True)[0]
 
95
 
96
- response = tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
97
  return response
98
  except Exception as e:
99
- return f"Error generating analysis: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # Create Gradio interface
102
  def create_ui():
 
85
 
86
  # Generate response
87
  try:
88
+ # Modified generation approach to avoid the cache issue
89
  output_ids = model.generate(
90
  input_ids,
91
  images=image_tensor,
92
  max_new_tokens=1024,
93
  temperature=0.7,
94
  top_p=0.9,
95
+ use_cache=False, # Disable caching to avoid the error
96
+ do_sample=True) # Enable sampling for more creative outputs
97
 
98
+ response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
99
  return response
100
  except Exception as e:
101
+ # Add fallback generation method if the first method fails
102
+ try:
103
+ print(f"First generation method failed with: {str(e)}. Trying fallback method...")
104
+ # Alternate generation approach
105
+ with torch.inference_mode():
106
+ output = model.generate(
107
+ input_ids,
108
+ images=image_tensor,
109
+ max_new_tokens=1024,
110
+ do_sample=True,
111
+ top_p=0.9,
112
+ temperature=0.7,
113
+ eos_token_id=tokenizer.eos_token_id,
114
+ pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
115
+ )
116
+ response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
117
+ return response
118
+ except Exception as e2:
119
+ return f"Error generating analysis: {str(e)}\nFallback also failed: {str(e2)}\n\nPlease try a different image or check model compatibility."
120
 
121
  # Create Gradio interface
122
  def create_ui():