Update utils.py
utils.py CHANGED
@@ -388,9 +388,17 @@ def query(api_llm, payload):
 def llm_chain2(prompt, context):
     full_prompt = RAG_CHAIN_PROMPT.format(context=context, question=prompt)
     inputs = tokenizer_rag(full_prompt, return_tensors="pt", max_length=1024, truncation=True)
-
+    attention_mask = (inputs["input_ids"] != tokenizer_rag.pad_token_id).long()
     # Generate the answer
-    outputs = modell_rag.generate(inputs['input_ids'], max_new_tokens=1024, num_beams=2, early_stopping=True)
+    outputs = modell_rag.generate(
+        inputs["input_ids"],
+        attention_mask=attention_mask,
+        max_new_tokens=1024,
+        do_sample=True,
+        temperature=0.9,
+        pad_token_id=tokenizer_rag.eos_token_id
+    )
+    #outputs = modell_rag.generate(inputs['input_ids'], max_new_tokens=1024, num_beams=2, early_stopping=True)
     answer = tokenizer_rag.decode(outputs[0], skip_special_tokens=True)
 
     return answer
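For reference, a minimal, self-contained sketch of how the updated generation path behaves. This is an illustration only: `distilgpt2` stands in for the repository's `modell_rag`/`tokenizer_rag`, the `RAG_CHAIN_PROMPT` template and the example question/context are placeholders not taken from this commit, and `max_new_tokens` is reduced so the tiny stand-in model stays inside its context window (the commit itself uses 1024).

```python
# Sketch only: distilgpt2 and RAG_CHAIN_PROMPT below are assumptions standing in
# for the real modell_rag / tokenizer_rag / prompt template defined in utils.py.
from transformers import AutoModelForCausalLM, AutoTokenizer

RAG_CHAIN_PROMPT = "Context:\n{context}\n\nQuestion: {question}\nAnswer:"

tokenizer_rag = AutoTokenizer.from_pretrained("distilgpt2")
modell_rag = AutoModelForCausalLM.from_pretrained("distilgpt2")
# GPT-2 style tokenizers have no pad token; reuse EOS so pad_token_id is defined.
if tokenizer_rag.pad_token_id is None:
    tokenizer_rag.pad_token = tokenizer_rag.eos_token

def llm_chain2(prompt, context):
    full_prompt = RAG_CHAIN_PROMPT.format(context=context, question=prompt)
    inputs = tokenizer_rag(full_prompt, return_tensors="pt", max_length=1024, truncation=True)
    # Mask out padding positions so generate() only attends to real tokens.
    attention_mask = (inputs["input_ids"] != tokenizer_rag.pad_token_id).long()
    # Sampled generation, as in the commit; 64 new tokens here instead of 1024
    # only because the placeholder model has a 1024-token position limit.
    outputs = modell_rag.generate(
        inputs["input_ids"],
        attention_mask=attention_mask,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.9,
        pad_token_id=tokenizer_rag.eos_token_id,
    )
    answer = tokenizer_rag.decode(outputs[0], skip_special_tokens=True)
    return answer

print(llm_chain2("What does the update change?", "The commit adds an explicit attention mask."))
```

Note that with `return_tensors="pt"` the tokenizer already returns an `attention_mask` entry in `inputs`, so passing `**inputs` to `generate()` (or `inputs["attention_mask"]` directly) would be an equivalent, slightly simpler way to supply the mask.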