import copy
import re
import torch
from threading import Thread
from transformers import TextIteratorStreamer
from config import logger, MAX_INPUT_TOKEN_LENGTH
from prompts import PROMPT_FUNCTIONS
from response_parser import ParserState, parse_response, format_response, remove_tags
from utils import merge_conversation
def generate_response(model_handler, history, temperature, top_p, top_k, max_tokens, seed, active_gen, model_id, auto_clear):
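    """Stream a chat response for the latest user turn.

    Progressively updated copies of ``history`` are yielded as tokens arrive
    from the model. ``active_gen`` is a single-element list used as a mutable
    stop flag shared with the UI; it is reset to False on exit.
    """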
raw_history = copy.deepcopy(history)
# Clean history by removing tags from assistant responses
history = [[item[0], remove_tags(item[1]) if item[1] else None] for item in history]
try:
# Validate history
if not isinstance(history, list) or not history:
logger.error("History is empty or not a list")
history = [[None, "Error: Conversation history is empty or invalid"]]
yield history
return
# Validate last history entry
        if not isinstance(history[-1], (list, tuple)) or len(history[-1]) < 2 or not history[-1][0]:
logger.error("Last history entry is invalid or missing user message")
history = raw_history
history[-1][1] = "Error: No valid user message provided"
yield history
return
# Load model if necessary
if model_handler.model is None or model_handler.tokenizer is None or model_id != model_handler.current_model_id:
status, _ = model_handler.load_model(model_id, history)
if "Error" in status:
logger.error(status)
history[-1][1] = status
yield history
return
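        # Seed CPU and CUDA RNGs so sampling is reproducible for a given seed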
torch.manual_seed(int(seed))
if torch.cuda.is_available():
torch.cuda.manual_seed(int(seed))
torch.cuda.manual_seed_all(int(seed))
# Validate prompt function
if model_id not in PROMPT_FUNCTIONS:
logger.error(f"No prompt function defined for model_id: {model_id}")
history[-1][1] = f"Error: No prompt function defined for model {model_id}"
yield history
return
prompt_fn = PROMPT_FUNCTIONS[model_id]
# Handle specific model prompt formatting
if model_id in [
"Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
"Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
]:
if auto_clear:
text = prompt_fn(model_handler.tokenizer, history[-1][0])
else:
text = prompt_fn(model_handler.tokenizer, merge_conversation(history))
inputs = model_handler.tokenizer(
[text],
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_INPUT_TOKEN_LENGTH
)
else:
# Build conversation for other models
conversation = []
for msg in history:
if msg[0]:
conversation.append({"role": "user", "content": msg[0]})
if msg[1]:
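                    # Strip "✅ Thought for ..." status lines so only the visible answer text is replayed to the model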
clean_text = ' '.join(line for line in msg[1].split('\n') if not line.startswith('✅ Thought for')).strip()
conversation.append({"role": "assistant", "content": clean_text})
                elif msg[0]:  # under the elif, msg[1] is already falsy: a user turn with no reply yet
conversation.append({"role": "assistant", "content": ""})
# Ensure at least one user message
if not any(msg["role"] == "user" for msg in conversation):
logger.error("No valid user messages in conversation history")
history = raw_history
history[-1][1] = "Error: No valid user messages in conversation history"
yield history
return
# Apply auto_clear logic
if auto_clear:
# Keep only the last user message and add an empty assistant response
user_msgs = [msg for msg in conversation if msg["role"] == "user"]
if user_msgs:
conversation = [{"role": "user", "content": user_msgs[-1]["content"]}, {"role": "assistant", "content": ""}]
else:
logger.error("No user messages found after filtering")
history = raw_history
history[-1][1] = "Error: No user messages found in conversation history"
yield history
return
else:
                # Ensure the conversation ends with an assistant placeholder if the last message is from the user
if conversation and conversation[-1]["role"] == "user":
conversation.append({"role": "assistant", "content": ""})
text = prompt_fn(model_handler.tokenizer, conversation)
tokenizer_kwargs = {
"return_tensors": "pt",
"padding": True,
"truncation": True,
"max_length": MAX_INPUT_TOKEN_LENGTH
}
inputs = model_handler.tokenizer(text, **tokenizer_kwargs)
if inputs is None or "input_ids" not in inputs:
logger.error("Tokenizer returned invalid or None output")
history = raw_history
history[-1][1] = "Error: Failed to tokenize input"
yield history
return
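        # Move input tensors onto the model's device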
input_ids = inputs["input_ids"].to(model_handler.model.device)
        attention_mask = inputs["attention_mask"].to(model_handler.model.device) if "attention_mask" in inputs else None
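        # Sampling configuration for model.generate(); num_beams=1 disables beam search in favour of pure sampling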
generate_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"max_new_tokens": max_tokens,
"do_sample": True,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"num_beams": 1,
"repetition_penalty": 1.0,
"pad_token_id": model_handler.tokenizer.pad_token_id,
"eos_token_id": model_handler.tokenizer.eos_token_id,
"use_cache": True,
"cache_implementation": "dynamic",
}
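        # Stream decoded text as it is produced; skip_prompt drops the echoed input,
        # and the timeout guards against a stalled generation thread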
streamer = TextIteratorStreamer(model_handler.tokenizer, timeout=360.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs["streamer"] = streamer
def run_generation():
try:
model_handler.model.generate(**generate_kwargs)
except Exception as e:
logger.error(f"Generation failed: {str(e)}")
raise
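        # Run generation in a background thread so this generator can consume
        # the streamer and yield partial results immediately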
thread = Thread(target=run_generation)
thread.start()
state = ParserState()
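        # The prompt templates for these reasoning models appear to open the <think>
        # block themselves, so pre-seed the buffer with the tag the parser expects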
if model_id in [
"Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
"Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
]:
full_response = "<think>"
else:
full_response = ""
for text in streamer:
if not active_gen[0]:
logger.info("Generation stopped by user")
break
if text:
logger.debug(f"Raw streamer output: {text}")
text = re.sub(r'<\|\w+\|>', '', text)
full_response += text
state, elapsed = parse_response(full_response, state)
collapsible, answer_part = format_response(state, elapsed)
history = raw_history
history[-1][1] = "\n\n".join(collapsible + [answer_part])
yield history
else:
logger.debug("Streamer returned empty text")
thread.join()
thread = None
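        # Final parse once the stream has closed, so the rendered history reflects the complete response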
state, elapsed = parse_response(full_response, state)
collapsible, answer_part = format_response(state, elapsed)
history = raw_history
history[-1][1] = "\n\n".join(collapsible + [answer_part])
        # The buffer is pre-seeded with "<think>" for reasoning models, so compare against the seed value too
        if not full_response or full_response == "<think>":
logger.warning("No response generated by model")
history[-1][1] = "No response generated. Please try again or select a different model."
yield history
except Exception as e:
logger.error(f"Error in generate: {str(e)}")
history = raw_history
if not history or not isinstance(history, list):
history = [[None, f"Error: {str(e)}. Please try again or select a different model."]]
else:
history[-1][1] = f"Error: {str(e)}. Please try again or select a different model."
yield history
finally:
active_gen[0] = False