import torch

from config import Config
from networks import peft_model
# Assumption: the audio transcriber (Whisper-style), the CLIP encoder, and the
# projection layer live in networks.py alongside peft_model; they are imported
# here so chatbot_response can pass them through to prepare_inputs.
from networks import audio_model, clip_model, projection

tokenizer = Config.tokenizer
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_tokens('<question-answer>')
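# Note: '<question-answer>' is a newly added token, so the language model's
# embedding matrix usually has to grow to match the tokenizer. Assuming
# networks.py does not already handle this, something like the following
# would be needed (hypothetical, depends on how the model was set up):
# peft_model.resize_token_embeddings(len(tokenizer))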
def prepare_inputs(peft_model, audio_model, clip_model, projection,
                   text_input=None, image_input=None, audio_input=None):
    text_audio, text_embed, image_embed = None, None, None

    # Transcribe the audio (Whisper-style API) and join the segment texts.
    if audio_input is not None:
        audio_transcribed = audio_model.transcribe(audio_input)
        processed_audio = ''
        for audio_segment in audio_transcribed['segments']:
            processed_audio += audio_segment['text']
        processed_audio = processed_audio.strip()

    # Encode the image with CLIP, drop the CLS token, and project the patch
    # embeddings into the language model's embedding space.
    if image_input is not None:
        image_processed = Config.processor(images=image_input, return_tensors="pt")
        with torch.no_grad():
            outputs = clip_model(**image_processed)
        last_hidden_state = outputs.last_hidden_state[:, 1:, :]
        image_embed = projection(last_hidden_state.to(Config.device)).to(torch.float16)

    # Merge the typed text and the transcription into one prompt string.
    if audio_input is not None and text_input is not None:
        text_audio = f"{text_input} {processed_audio}"
    elif audio_input is not None and text_input is None:
        text_audio = processed_audio
    elif audio_input is None and text_input is not None:
        text_audio = text_input

    # Tokenize the prompt, wrap it in the separator tokens, and embed it.
    if text_audio is not None:
        tokenized_text_audio = tokenizer.encode(text_audio)
        tokenized_text_audio = (Config.IMAGE_SEPARATOR_TOKENS
                                + tokenized_text_audio
                                + [Config.QUESTION_ANSWER_SEPARATOR_ID])
        with torch.no_grad():
            tokenized_text_audio = torch.tensor(tokenized_text_audio)
            text_embed = peft_model.model.model.embed_tokens(
                tokenized_text_audio.to(Config.device)).unsqueeze(0)

    # Concatenate whichever modality embeddings are present along the
    # sequence dimension (image patches first, then text tokens).
    if text_audio is not None and image_input is not None:
        combined_embed = torch.cat([image_embed, text_embed], dim=1)
    elif text_audio is not None:
        combined_embed = text_embed
    else:
        combined_embed = image_embed
    return combined_embed
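# Example of a text-only call (hypothetical values; image and audio inputs
# follow the same pattern, with the unused encoders simply left idle):
# embeds = prepare_inputs(peft_model, audio_model, clip_model, projection,
#                         text_input="Describe the weather today.")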
def chatbot_response(text_input, image_input, audio_input):
    if text_input == '':
        text_input = None
    if text_input is None and image_input is None and audio_input is None:
        return "Please enter text, upload an image, or record audio."
    combined_embeds = prepare_inputs(peft_model, audio_model, clip_model, projection,
                                     text_input=text_input,
                                     image_input=image_input,
                                     audio_input=audio_input)
    generated_tokens = generate_tokens(combined_embeds, max_tokens=60)
    return tokenizer.decode(generated_tokens)
def generate_tokens(combined_embeds, max_tokens=100):
    # Greedy decoding over raw input embeddings: append each predicted
    # token's embedding to the context and re-run the model.
    pred_tokens = []
    combined_embed = combined_embeds
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = peft_model(inputs_embeds=combined_embed).logits[:, -1, :]
            next_token_id = logits.argmax(dim=-1)
            if next_token_id.item() == 50256:  # tokenizer's end-of-sequence id
                break
            pred_tokens.append(next_token_id.item())
            next_token_embed = peft_model.model.model.embed_tokens(next_token_id.unsqueeze(0))
            combined_embed = torch.cat((combined_embed, next_token_embed), dim=1)
    return pred_tokens
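
# Minimal sketch of wiring chatbot_response into a Gradio UI, since its
# three-argument signature matches a text/image/audio Space. The component
# choices below (PIL image, audio passed as a file path so the transcriber
# can read it) are assumptions, not the original app's definition.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Interface(
        fn=chatbot_response,
        inputs=[
            gr.Textbox(label="Text"),
            gr.Image(type="pil", label="Image"),
            gr.Audio(type="filepath", label="Audio"),
        ],
        outputs=gr.Textbox(label="Response"),
    )
    demo.launch()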