# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
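"""
Batched inference with a pretrained SALM model over Lhotse-compatible inputs,
writing the generated responses to a JSONL manifest.

Example invocation (a sketch only; the script filename and all values below are
illustrative placeholders, not taken from this repository):

    python salm_eval.py \
        pretrained_name=<pretrained_model_name_or_path> \
        inputs=<path_to_cuts_or_manifest> \
        batch_size=64 \
        max_new_tokens=128 \
        output_manifest=generations.jsonl

All options are defined by ``SalmEvalConfig`` below and may be overridden as
Hydra-style ``key=value`` arguments on the command line.
"""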
from dataclasses import dataclass
from functools import partial
from time import perf_counter
from typing import Optional

import lhotse.dataset
import torch
from lhotse import CutSet, fastcopy
from lhotse.dataset import IterableDatasetWrapper
from lhotse.serialization import SequentialJsonlWriter
from omegaconf import OmegaConf
from transformers import GenerationConfig

from nemo.collections.common.data.lhotse import NeMoMultimodalConversation
from nemo.collections.common.data.lhotse.cutset import cut_to_conversation, guess_parse_cutset
from nemo.collections.common.data.lhotse.dataloader import tokenize_with_prompt
from nemo.collections.common.data.lhotse.text_adapters import TextTurn
from nemo.collections.speechlm2 import SALM, SALMDataset
from nemo.core.config import hydra_runner
from nemo.utils import logging


@dataclass
class SalmEvalConfig:
    pretrained_name: str
    inputs: str
    batch_size: int = 64
    max_new_tokens: int = 128
    output_manifest: str = "generations.jsonl"
    verbose: bool = True
    device: str = "cuda"
    dtype: str = "bfloat16"
    extra_eos_tokens: Optional[list[str]] = None
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None


@hydra_runner(config_name="SalmEvalConfig", schema=SalmEvalConfig)
def main(cfg: SalmEvalConfig):
    # Hydra entrypoint: cfg is built from SalmEvalConfig plus CLI overrides.
    logging.info(f"Hydra config:\n{OmegaConf.to_yaml(cfg)}")
    model = SALM.from_pretrained(cfg.pretrained_name).eval().to(getattr(torch, cfg.dtype)).to(cfg.device)
    # Build the evaluation set: parse the input cuts, convert them to multimodal
    # conversations, attach optional system/user prompts, drop any pre-existing
    # assistant responses, and tokenize with the model's prompt format.
    conversations = (
        guess_parse_cutset(cfg.inputs)
        .map(
            partial(
                cut_to_conversation,
                audio_locator_tag=model.audio_locator_tag,
                token_equivalent_duration=model.token_equivalent_duration,
            )
        )
        .map(
            partial(attach_system_and_user_turns, system_prompt=cfg.system_prompt, user_prompt=cfg.user_prompt),
            apply_fn=None,
        )
        .map(strip_response_if_any, apply_fn=None)
        .map(
            partial(
                tokenize_with_prompt,
                tokenizer=model.tokenizer,
                prompt_format=model.cfg.prompt_format,
            ),
            apply_fn=None,
        )
    )
    conversations = sort_by_length(conversations)
    dloader = torch.utils.data.DataLoader(
        dataset=IterableDatasetWrapper(
            dataset=SALMDataset(model.tokenizer),
            sampler=lhotse.dataset.DynamicCutSampler(conversations, max_cuts=cfg.batch_size),
        ),
        num_workers=1,
        batch_size=None,
    )
    eos_tokens = [model.text_eos_id]
    if cfg.extra_eos_tokens is not None:
        for t in cfg.extra_eos_tokens:
            tid = model.tokenizer.token_to_id(t)
            assert tid is not None, f"Token '{t}' is not in the model's vocabulary."
            eos_tokens.append(tid)
    writer = SequentialJsonlWriter(cfg.output_manifest)
    num_answer_tokens = []
    infer_durations = []
    for batch_idx, batch in enumerate(dloader):
        ts = perf_counter()
        answer_ids = model.generate(
            prompts=batch["input_ids"].to(model.device, non_blocking=True),
            audios=batch["audios"].to(model.device, non_blocking=True),
            audio_lens=batch["audio_lens"].to(model.device, non_blocking=True),
            generation_config=GenerationConfig(
                max_new_tokens=cfg.max_new_tokens,
                bos_token_id=model.text_bos_id,
                eos_token_id=eos_tokens,
                pad_token_id=model.text_pad_id,
            ),
        )
        answer_ids = answer_ids.cpu()
        batch_infer_duration = perf_counter() - ts
        batch_contexts = [model.tokenizer.ids_to_text(example) for example in batch["input_ids"]]
        answer_ids = [parse_hyp(ans, eos_tokens) for ans in answer_ids]
        batch_num_answer_tokens = [len(ans) for ans in answer_ids]
        batch_answers = [model.tokenizer.ids_to_text(ans) for ans in answer_ids]
        for conv, ctx, ans in zip(batch["conversations"], batch_contexts, batch_answers):
            conv.turns.append(TextTurn(role="assistant", value=ans))
            # Tensors attached to the conversation are not JSON-serializable; drop them.
            for k, v in list(conv.custom.items()):
                if isinstance(v, torch.Tensor):
                    del conv.custom[k]
            writer.write(conv.to_dict())
        num_answer_tokens.extend(batch_num_answer_tokens)
        infer_durations.append(batch_infer_duration)
        if cfg.verbose:
            batch_token_per_second = sum(batch_num_answer_tokens) / batch_infer_duration
            logging.info(f"Batch {batch_idx}: TPS={batch_token_per_second:.2f}")
    writer.close()
    tps = sum(num_answer_tokens) / sum(infer_durations)
    logging.info(f"TPS: {tps:.2f}")


def attach_system_and_user_turns(
    conversation: NeMoMultimodalConversation, system_prompt: str | None = None, user_prompt: str | None = None
) -> NeMoMultimodalConversation:
    if system_prompt is None and user_prompt is None:
        return conversation
    turns = conversation.turns
    # Attach user prompt only when no user turn with a text prompt exists.
    if user_prompt is not None and not any(isinstance(t, TextTurn) and t.role == "user" for t in turns):
        turns = [TextTurn(role="user", value=user_prompt)] + turns
    # Attach system prompt only when no system prompt already exists.
    if system_prompt is not None and not any(t.role == "system" for t in turns):
        turns = [TextTurn(role="system", value=system_prompt)] + turns
    return fastcopy(conversation, turns=turns)


def strip_response_if_any(
    conversation: NeMoMultimodalConversation,
) -> NeMoMultimodalConversation:
    # Drop any trailing assistant turns so the model generates a fresh response.
    turns = conversation.turns
    while turns and turns[-1].role == "assistant":
        turns = turns[:-1]
    return fastcopy(conversation, turns=turns)


def sort_by_length(conversations: CutSet) -> CutSet:
    return CutSet(sorted(conversations, key=lambda c: c.total_length, reverse=True))


def parse_hyp(answer: torch.Tensor, eos_tokens: list[int]) -> torch.Tensor:
    # Truncate the generated token sequence at the first EOS token, if any.
    end = torch.isin(answer, torch.tensor(eos_tokens)).nonzero(as_tuple=True)[0]
    if end.numel() == 0:
        return answer
    end = end[0]
    return answer[:end]


if __name__ == '__main__':
    main()