adinarayana commited on
Commit
a2d5f2f
·
verified ·
1 Parent(s): 99a4a27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -68,15 +68,12 @@ def answer_question(text, question, max_length=512):
68
  start_logits = tf.convert_to_tensor(start_logits)
69
  end_logits = tf.convert_to_tensor(end_logits)
70
 
71
- # Verify axis type:
72
- axis = 1 # Assuming axis 1 for argmax
73
- if not isinstance(axis, tf.Tensor) or axis.dtype not in (tf.int32, tf.int64):
74
- axis = tf.constant(1, dtype=tf.int32) # Replace with correct axis if needed
75
 
76
- answer_start = tf.math.argmax(start_logits, axis=axis)
77
- answer_end = tf.math.argmax(end_logits, axis=1) + 1 # Get predicted end index (exclusive)
78
-
79
- answer = tf.gather(text, answer_start, axis=1).numpy()[0][answer_start[0]:answer_end[0]]
80
 
81
  return answer if answer else "No answer found."
82
 
 
68
  start_logits = tf.convert_to_tensor(start_logits)
69
  end_logits = tf.convert_to_tensor(end_logits)
70
 
71
+ # Find the indices of the start and end positions
72
+ answer_start = tf.argmax(start_logits, axis=1).numpy()[0]
73
+ answer_end = (tf.argmax(end_logits, axis=1) + 1).numpy()[0] # Increment by 1 for exclusive end index
 
74
 
75
+ # Extract the answer text from the original text
76
+ answer = text[answer_start:answer_end].strip()
 
 
77
 
78
  return answer if answer else "No answer found."
79