import copy
import re
from threading import Thread

import torch
from transformers import TextIteratorStreamer

from config import logger, MAX_INPUT_TOKEN_LENGTH
from prompts import PROMPT_FUNCTIONS
from response_parser import ParserState, parse_response, format_response, remove_tags
from utils import merge_conversation

def generate_response(model_handler, history, temperature, top_p, top_k, max_tokens, seed, active_gen, model_id, auto_clear):
    """Stream a model response for the latest user turn, yielding the updated chat history."""
    # Keep an untouched copy so streamed updates and error messages are written
    # back onto the original (tag-preserving) history.
    raw_history = copy.deepcopy(history)

    # Strip parser tags from previous assistant turns before building the prompt.
    history = [[item[0], remove_tags(item[1]) if item[1] else None] for item in history]

    try:
        if not isinstance(history, list) or not history:
            logger.error("History is empty or not a list")
            history = [[None, "Error: Conversation history is empty or invalid"]]
            yield history
            return

        if not isinstance(history[-1], (list, tuple)) or len(history[-1]) < 1 or not history[-1][0]:
            logger.error("Last history entry is invalid or missing user message")
            history = raw_history
            history[-1][1] = "Error: No valid user message provided"
            yield history
            return

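        # Lazily (re)load the model when none is loaded or a different model was selected.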
        if model_handler.model is None or model_handler.tokenizer is None or model_id != model_handler.current_model_id:
            status, _ = model_handler.load_model(model_id, history)
            if "Error" in status:
                logger.error(status)
                history[-1][1] = status
                yield history
                return

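        # Seed CPU and CUDA RNGs so sampling is reproducible for a given seed.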
        torch.manual_seed(int(seed))
        if torch.cuda.is_available():
            torch.cuda.manual_seed(int(seed))
            torch.cuda.manual_seed_all(int(seed))

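        # Look up the prompt-building function registered for this model.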
        if model_id not in PROMPT_FUNCTIONS:
            logger.error(f"No prompt function defined for model_id: {model_id}")
            history[-1][1] = f"Error: No prompt function defined for model {model_id}"
            yield history
            return
        prompt_fn = PROMPT_FUNCTIONS[model_id]

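        # The reasoning LoRA models are prompted with plain text; all other models
        # take a chat-style conversation list.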
        if model_id in [
            "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
            "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA",
        ]:
            if auto_clear:
                # Single-turn mode: prompt with only the latest user message.
                text = prompt_fn(model_handler.tokenizer, history[-1][0])
            else:
                # Multi-turn mode: prompt with the merged conversation.
                text = prompt_fn(model_handler.tokenizer, merge_conversation(history))

            inputs = model_handler.tokenizer(
                [text],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_INPUT_TOKEN_LENGTH,
            )
        else:
            # Build a chat-format conversation, dropping the "✅ Thought for ..." status
            # lines from earlier assistant replies.
            conversation = []
            for msg in history:
                if msg[0]:
                    conversation.append({"role": "user", "content": msg[0]})
                if msg[1]:
                    clean_text = ' '.join(line for line in msg[1].split('\n') if not line.startswith('✅ Thought for')).strip()
                    conversation.append({"role": "assistant", "content": clean_text})
                elif msg[0] and not msg[1]:
                    conversation.append({"role": "assistant", "content": ""})

            if not any(msg["role"] == "user" for msg in conversation):
                logger.error("No valid user messages in conversation history")
                history = raw_history
                history[-1][1] = "Error: No valid user messages in conversation history"
                yield history
                return

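            # With auto_clear, only the latest user turn is sent; otherwise the full
            # conversation is used, with an empty assistant slot appended for generation.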
            if auto_clear:
                user_msgs = [msg for msg in conversation if msg["role"] == "user"]
                if user_msgs:
                    conversation = [{"role": "user", "content": user_msgs[-1]["content"]}, {"role": "assistant", "content": ""}]
                else:
                    logger.error("No user messages found after filtering")
                    history = raw_history
                    history[-1][1] = "Error: No user messages found in conversation history"
                    yield history
                    return
            else:
                if conversation and conversation[-1]["role"] == "user":
                    conversation.append({"role": "assistant", "content": ""})

            text = prompt_fn(model_handler.tokenizer, conversation)
            tokenizer_kwargs = {
                "return_tensors": "pt",
                "padding": True,
                "truncation": True,
                "max_length": MAX_INPUT_TOKEN_LENGTH,
            }

            inputs = model_handler.tokenizer(text, **tokenizer_kwargs)

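        # Validate tokenizer output before moving tensors to the model device.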
        if inputs is None or "input_ids" not in inputs:
            logger.error("Tokenizer returned invalid or None output")
            history = raw_history
            history[-1][1] = "Error: Failed to tokenize input"
            yield history
            return

        input_ids = inputs["input_ids"].to(model_handler.model.device)
        attention_mask = inputs["attention_mask"].to(model_handler.model.device) if "attention_mask" in inputs else None

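        # Sampling settings handed directly to model.generate().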
        generate_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "num_beams": 1,
            "repetition_penalty": 1.0,
            "pad_token_id": model_handler.tokenizer.pad_token_id,
            "eos_token_id": model_handler.tokenizer.eos_token_id,
            "use_cache": True,
            "cache_implementation": "dynamic",
        }

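        # Stream decoded tokens as they are produced; generation runs in a
        # background thread so this generator can yield partial replies.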
        streamer = TextIteratorStreamer(model_handler.tokenizer, timeout=360.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs["streamer"] = streamer

        def run_generation():
            try:
                model_handler.model.generate(**generate_kwargs)
            except Exception as e:
                logger.error(f"Generation failed: {str(e)}")
                raise

        thread = Thread(target=run_generation)
        thread.start()

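        # Seed the buffer with "<think>" for the reasoning LoRA models so the parser
        # treats the streamed text as the inside of an already-open think block.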
        state = ParserState()
        if model_id in [
            "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
            "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA",
        ]:
            full_response = "<think>"
        else:
            full_response = ""

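        # Consume the stream, re-parsing the accumulated response after every chunk
        # and yielding the partially formatted reply.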
        for text in streamer:
            if not active_gen[0]:
                logger.info("Generation stopped by user")
                break

            if text:
                logger.debug(f"Raw streamer output: {text}")
                # Strip any residual special-token markers of the form <|name|>.
                text = re.sub(r'<\|\w+\|>', '', text)
                full_response += text
                state, elapsed = parse_response(full_response, state)

                collapsible, answer_part = format_response(state, elapsed)
                history = raw_history
                history[-1][1] = "\n\n".join(collapsible + [answer_part])
                yield history
            else:
                logger.debug("Streamer returned empty text")

        # Wait for the generation thread, then emit the final formatted reply.
        thread.join()
        thread = None
        state, elapsed = parse_response(full_response, state)
        collapsible, answer_part = format_response(state, elapsed)
        history = raw_history
        history[-1][1] = "\n\n".join(collapsible + [answer_part])

        if not full_response:
            logger.warning("No response generated by model")
            history[-1][1] = "No response generated. Please try again or select a different model."

        yield history

    except Exception as e:
        logger.error(f"Error in generate: {str(e)}")
        history = raw_history
        if not history or not isinstance(history, list):
            history = [[None, f"Error: {str(e)}. Please try again or select a different model."]]
        else:
            history[-1][1] = f"Error: {str(e)}. Please try again or select a different model."

        yield history
    finally:
        # Always clear the "generation in progress" flag, even on error or user stop.
        active_gen[0] = False