adinarayana commited on
Commit
09d71b7
·
verified ·
1 Parent(s): 6f020bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -12
app.py CHANGED
@@ -59,23 +59,20 @@ def answer_question(text, question, max_length=512):
59
  question, text, return_tensors="tf", padding="max_length", truncation=True
60
  )
61
 
62
- start_logits, end_logits = qa_model(inputs)
63
 
64
- # Ensure start_logits is a tensor
65
- start_logits = tf.convert_to_tensor(start_logits)
66
 
67
- # Check and ensure data type of start_logits:
68
- if start_logits.dtype not in (tf.float32, tf.int32):
69
- start_logits = tf.cast(start_logits, tf.float32) # Example casting to float32
70
 
71
  # Verify axis type:
 
72
  if not isinstance(axis, tf.Tensor) or axis.dtype not in (tf.int32, tf.int64):
73
  axis = tf.constant(1, dtype=tf.int32) # Replace with correct axis if needed
74
 
75
- # Ensure compatibility for argmax (e.g., non-empty tensor):
76
- if start_logits.shape[0] == 0:
77
- raise ValueError("start_logits tensor is empty")
78
-
79
  answer_start = tf.math.argmax(start_logits, axis=axis)
80
  answer_end = tf.math.argmax(end_logits, axis=1) + 1 # Get predicted end index (exclusive)
81
 
@@ -84,7 +81,6 @@ def answer_question(text, question, max_length=512):
84
  return answer if answer else "No answer found."
85
 
86
 
87
-
88
  ## Streamlit app
89
 
90
  st.set_page_config(page_title="Enhanced PDF Summarizer")
@@ -119,4 +115,3 @@ if uploaded_file is not None:
119
  st.write(answer)
120
  else:
121
  st.error("No text found in the PDF.")
122
-
 
59
  question, text, return_tensors="tf", padding="max_length", truncation=True
60
  )
61
 
62
+ outputs = qa_model(inputs)
63
 
64
+ start_logits = outputs.start_logits
65
+ end_logits = outputs.end_logits
66
 
67
+ # Ensure start_logits and end_logits are tensors
68
+ start_logits = tf.convert_to_tensor(start_logits)
69
+ end_logits = tf.convert_to_tensor(end_logits)
70
 
71
  # Verify axis type:
72
+ axis = 1 # Assuming axis 1 for argmax
73
  if not isinstance(axis, tf.Tensor) or axis.dtype not in (tf.int32, tf.int64):
74
  axis = tf.constant(1, dtype=tf.int32) # Replace with correct axis if needed
75
 
 
 
 
 
76
  answer_start = tf.math.argmax(start_logits, axis=axis)
77
  answer_end = tf.math.argmax(end_logits, axis=1) + 1 # Get predicted end index (exclusive)
78
 
 
81
  return answer if answer else "No answer found."
82
 
83
 
 
84
  ## Streamlit app
85
 
86
  st.set_page_config(page_title="Enhanced PDF Summarizer")
 
115
  st.write(answer)
116
  else:
117
  st.error("No text found in the PDF.")