MarioCerulo committed on
Commit
54eb26a
·
verified ·
1 Parent(s): bd76728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -57,18 +57,23 @@ collection = chroma_client.get_collection(name="F1-wiki", embedding_function=Cus
57
 
58
  question = get_text()
59
 
 
 
 
60
  if question:
61
- response = collection.query(query_texts=question, include=['documents'], n_results=5)
 
62
 
63
- context = " ".join(response['documents'][0])
64
 
65
- input_text = template.replace("{context}", context).replace("{question}", question)
66
- input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
67
 
68
- output = model.generate(input_ids, max_new_tokens=200, early_stopping=True)
69
- answer = tokenizer.decode(output[0], skip_special_tokens=True).split("ANSWER:")[1]
70
 
71
- st.session_state.past.append(question)
72
- st.session_state.generated.append(answer)
73
 
74
  st.write(answer)
 
57
 
58
  question = get_text()
59
 
60
+ if tokenizer.pad_token_id is None:
61
+ tokenizer.pad_token_id = tokenizer.eos_token_id
62
+
63
  if question:
64
+ with st.spinner("Generating answer... "):
65
+ response = collection.query(query_texts=question, include=['documents'], n_results=5)
66
 
67
+ context = " ".join(response['documents'][0])
68
 
69
+ input_text = template.replace("{context}", context).replace("{question}", question)
70
+ input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
71
+ attention_mask = (input_ids != tokenizer.pad_token_id).to(device)
72
 
73
+ output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=200, early_stopping=True)
74
+ answer = tokenizer.decode(output[0], skip_special_tokens=True).split("ANSWER:")[1].strip()
75
 
76
+ st.session_state.past.append(question)
77
+ st.session_state.generated.append(answer)
78
 
79
  st.write(answer)