adinarayana commited on
Commit
6f020bb
·
verified ·
1 Parent(s): 9149630

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -61,16 +61,16 @@ def answer_question(text, question, max_length=512):
61
 
62
  start_logits, end_logits = qa_model(inputs)
63
 
 
 
 
64
  # Check and ensure data type of start_logits:
65
  if start_logits.dtype not in (tf.float32, tf.int32):
66
  start_logits = tf.cast(start_logits, tf.float32) # Example casting to float32
67
 
68
  # Verify axis type:
69
- if not isinstance(start_logits, tf.Tensor):
70
- raise ValueError("start_logits is not a TensorFlow tensor")
71
-
72
- if start_logits.dtype not in (tf.float32, tf.int32):
73
- start_logits = tf.cast(start_logits, tf.float32)
74
 
75
  # Ensure compatibility for argmax (e.g., non-empty tensor):
76
  if start_logits.shape[0] == 0:
@@ -84,6 +84,7 @@ def answer_question(text, question, max_length=512):
84
  return answer if answer else "No answer found."
85
 
86
 
 
87
  ## Streamlit app
88
 
89
  st.set_page_config(page_title="Enhanced PDF Summarizer")
 
61
 
62
  start_logits, end_logits = qa_model(inputs)
63
 
64
+ # Ensure start_logits is a tensor
65
+ start_logits = tf.convert_to_tensor(start_logits)
66
+
67
  # Check and ensure data type of start_logits:
68
  if start_logits.dtype not in (tf.float32, tf.int32):
69
  start_logits = tf.cast(start_logits, tf.float32) # Example casting to float32
70
 
71
  # Verify axis type:
72
+ if not isinstance(axis, tf.Tensor) or axis.dtype not in (tf.int32, tf.int64):
73
+ axis = tf.constant(1, dtype=tf.int32) # Replace with correct axis if needed
 
 
 
74
 
75
  # Ensure compatibility for argmax (e.g., non-empty tensor):
76
  if start_logits.shape[0] == 0:
 
84
  return answer if answer else "No answer found."
85
 
86
 
87
+
88
  ## Streamlit app
89
 
90
  st.set_page_config(page_title="Enhanced PDF Summarizer")