"""Telegram bot: English voice message -> text -> Khmer text -> Khmer speech.

Voice messages are transcribed with Whisper, translated through the
Translate KH HTTP API and read back with the MMS Khmer TTS model.
"""

import asyncio
import base64
import json
import logging
import os
import socket
import ssl
import sys

import aiodns
import certifi
import httpx
import librosa
import requests
import scipy.io.wavfile
import torch
from httpcore import AsyncConnectionPool
from telegram import Update
from telegram.error import NetworkError
from telegram.ext import Application, CommandHandler, MessageHandler, filters
from transformers import (
    AutoTokenizer,
    VitsModel,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.DEBUG)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class Config:
    """Runtime settings read from environment variables.

    Raises:
        ValueError: if TELEGRAM_API_KEY, TRANSLATE_KH_USERNAME or
            TRANSLATE_KH_PASSWORD is unset or empty.
    """

    def __init__(self):
        self.telegram_api_key = os.environ.get('TELEGRAM_API_KEY')
        # Where the incoming voice message is stored before transcription.
        self.voice_message_file_path = os.environ.get(
            'VOICE_MESSAGE_FILE_PATH', 'voice_message.ogg'
        )
        self.translate_kh_username = os.environ.get('TRANSLATE_KH_USERNAME')
        self.translate_kh_password = os.environ.get('TRANSLATE_KH_PASSWORD')

        required = [
            self.telegram_api_key,
            self.translate_kh_username,
            self.translate_kh_password,
        ]
        if not all(required):
            raise ValueError("Missing required environment variables.")


class CustomDNSResolver:
    """Resolve host names to A records through fixed public name servers.

    The file previously defined this class twice; the later definition
    replaced the earlier one at import time, so only that one is kept.
    """

    def __init__(self):
        self.resolver = aiodns.DNSResolver(nameservers=['8.8.8.8', '8.8.4.4'])

    async def resolve(self, hostname):
        """Return ``[{'host': ip, 'port': 443}, ...]`` for *hostname*.

        Returns ``None`` (after logging) when the DNS query fails.
        """
        try:
            result = await self.resolver.query(hostname, 'A')
            return [{'host': entry.host, 'port': 443} for entry in result]
        except aiodns.error.DNSError as e:
            logger.error(f"DNS resolution error for {hostname}: {e}")
            return None


class TranscribeKHBot:
    """Bot that wires Telegram handlers to the speech and translation steps."""

    def __init__(self):
        self.config = Config()
        self.custom_resolver = CustomDNSResolver()
        self.application = (
            Application.builder().token(self.config.telegram_api_key).build()
        )
        self.ssl_context = ssl.create_default_context(cafile=certifi.where())
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialize_mms_model()
        self.initialize_whisper_model()
        # One long-lived client shared by every translation request; it is
        # closed once, in _shutdown().
        self.http_client = httpx.AsyncClient(
            transport=httpx.AsyncHTTPTransport(
                retries=3,
                verify=self.ssl_context,
            ),
            timeout=30.0,
        )

    def initialize_mms_model(self):
        """Load the ``facebook/mms-tts-khm`` model and tokenizer."""
        logger.info("Initializing MMS model for Khmer TTS...")
        self.mms_model = VitsModel.from_pretrained(
            "facebook/mms-tts-khm"
        ).to(self.device)
        self.mms_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-khm")
        logger.info("MMS model initialized successfully.")

    def initialize_whisper_model(self):
        """Load ``openai/whisper-medium`` in eval mode; log and re-raise on failure."""
        logger.info("Initializing Whisper model...")
        try:
            self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
                "openai/whisper-medium"
            ).to(self.device)
            self.whisper_processor = WhisperProcessor.from_pretrained(
                "openai/whisper-medium"
            )
            self.whisper_model.eval()
            logger.info("Whisper model initialized successfully.")
        except Exception as e:
            logger.error(f"Error initializing Whisper model: {e}", exc_info=True)
            raise

    def setup_handlers(self):
        """Register the /start, /help and voice-message handlers."""
        self.application.add_handler(CommandHandler("start", self.start_command))
        self.application.add_handler(CommandHandler("help", self.help_command))
        self.application.add_handler(
            MessageHandler(filters.VOICE, self.handle_voice)
        )

    async def _start_with_retries(self, max_retries, initial_delay):
        """Bring the application and its updater up, retrying on NetworkError.

        Returns True once polling has started, False when *max_retries*
        allows no attempt at all.
        """
        app = self.application
        for attempt in range(max_retries):
            try:
                await app.initialize()
                # A previous attempt may have got past start() before failing;
                # starting an already running application raises.
                if not app.running:
                    await app.start()
                await app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
                return True
            except NetworkError as e:
                if attempt < max_retries - 1:
                    delay = initial_delay * (2 ** attempt)
                    logger.error(
                        f"Network error occurred: {e}. Retrying in {delay} seconds..."
                    )
                    await asyncio.sleep(delay)
                else:
                    logger.error("Max retries reached. Exiting.")
                    raise
            except Exception as e:
                logger.error(f"An unexpected error occurred: {e}")
                raise
        return False

    async def _shutdown(self):
        """Stop polling, stop the application and release the HTTP client."""
        app = self.application
        if app.updater is not None and app.updater.running:
            await app.updater.stop()
        if app.running:
            await app.stop()
        await app.shutdown()
        await self.http_client.aclose()

    async def run_with_retries(self, max_retries=5, initial_delay=1):
        """Start polling with exponential back-off, then serve until cancelled.

        ``Application.run_polling`` is a blocking helper that runs its own
        event loop, so it cannot be awaited from inside ``asyncio.run``.
        The application is therefore started step by step and kept alive
        by waiting on an event that is never set.

        Args:
            max_retries: number of start-up attempts before giving up.
            initial_delay: delay in seconds before the second attempt; it
                doubles on every further attempt.

        Raises:
            NetworkError: when the last start-up attempt fails.
            Exception: any other start-up error is logged and re-raised.
        """
        try:
            if await self._start_with_retries(max_retries, initial_delay):
                # Block here until the task is cancelled (e.g. Ctrl+C).
                await asyncio.Event().wait()
        finally:
            await self._shutdown()

    def run(self):
        """Register the handlers and run the bot until interrupted."""
        self.setup_handlers()
        logger.info("Bot is running...")
        asyncio.run(self.run_with_retries())

    async def start_command(self, update: Update, context):
        """Reply to /start with the welcome text."""
        # NOTE(review): the text advertises /select, but setup_handlers()
        # registers no handler for it.
        await update.message.reply_text(
            "Welcome to Transcribe KH Bot\n"
            "Send a voice message, and the bot will convert it to text, translate to Khmer, and generate Khmer speech for you.\n"
            "Commands: \n"
            "/help - help information\n"
            "/select - select language"
        )

    async def help_command(self, update: Update, context):
        """Reply to /help with a short usage hint."""
        await update.message.reply_text("Send a voice message to get started")

    def preprocess_audio(self, audio_path):
        """Load *audio_path* at 16 kHz and trim leading/trailing silence."""
        audio, sr = librosa.load(audio_path, sr=16000)
        audio = librosa.effects.trim(audio, top_db=20)[0]
        return audio

    async def handle_voice(self, update: Update, context):
        """Transcribe a voice message, translate it and answer with Khmer speech.

        Any failure is logged and reported to the user instead of being
        propagated to the dispatcher.
        """
        message = await update.message.reply_text("Processing your voice message")
        try:
            file = await update.message.voice.get_file()
            await file.download_to_drive(self.config.voice_message_file_path)

            # NOTE(review): the model calls below are synchronous and run on
            # the event loop thread.
            preprocessed_audio = self.preprocess_audio(
                self.config.voice_message_file_path
            )
            input_features = self.whisper_processor(
                preprocessed_audio, sampling_rate=16000, return_tensors="pt"
            ).input_features
            attention_mask = torch.ones_like(input_features)
            forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(
                language="en", task="transcribe"
            )
            predicted_ids = self.whisper_model.generate(
                input_features.to(self.device),
                attention_mask=attention_mask.to(self.device),
                forced_decoder_ids=forced_decoder_ids,
            )
            transcription = self.whisper_processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )[0]
            logger.info(f"Whisper transcription: {transcription}")

            try:
                khmer_text = await self.translate_to_khmer(transcription)
            except Exception as e:
                logger.error(f"Error translating text: {str(e)}")
                await update.message.reply_text(
                    "An error occurred while translating the text. Please try again later."
                )
                return

            khmer_speech = self.khmer_text_to_speech(khmer_text)
            output_file = "khmer_speech.wav"
            scipy.io.wavfile.write(
                output_file, self.mms_model.config.sampling_rate, khmer_speech
            )

            await context.bot.delete_message(
                chat_id=update.effective_chat.id, message_id=message.message_id
            )
            await update.message.reply_text(
                f"Original (English): {transcription}\n\nTranslated (Khmer): {khmer_text}"
            )
            # Close the file handle once the upload is done.
            with open(output_file, "rb") as voice_file:
                await update.message.reply_voice(voice=voice_file)
        except Exception as e:
            logger.error(f"Error processing voice message: {str(e)}", exc_info=True)
            await update.message.reply_text(
                "An error occurred while processing your voice message. Please try again later."
            )

    async def translate_to_khmer(self, text):
        """Translate English *text* to Khmer through the Translate KH API.

        Returns:
            The first entry of the ``translate_text`` list in the response.

        Raises:
            httpx.RequestError: on transport failures (logged, re-raised).
            httpx.HTTPStatusError: on a non-success status code.
            ValueError: when the response carries no translation.
        """
        url = "https://translatekh.mptc.gov.kh/api"
        username = self.config.translate_kh_username
        password = self.config.translate_kh_password

        credentials = f"{username}:{password}"
        encoded_credentials = base64.b64encode(
            credentials.encode('utf-8')
        ).decode('utf-8')

        headers = {
            "Authorization": f"Basic {encoded_credentials}",
            "Content-Type": "application/json",
        }
        data = {
            "input_text": [text],
            "src_lang": "eng",
            "tgt_lang": "kh",
        }

        try:
            # Use the shared client directly: wrapping it in ``async with``
            # would close it after the first request and break later ones.
            response = await self.http_client.post(url, headers=headers, json=data)
            response.raise_for_status()
            result = response.json()
            if "translate_text" in result and len(result["translate_text"]) > 0:
                return result["translate_text"][0]
            logger.error(f"Empty or invalid translation result: {result}")
            raise ValueError("Translation result is empty or invalid")
        except httpx.RequestError as e:
            logger.error(f"Error calling Translate KH API: {str(e)}")
            raise

    @torch.no_grad()
    def khmer_text_to_speech(self, text):
        """Synthesize *text* with the MMS model and return the waveform as a NumPy array."""
        inputs = self.mms_tokenizer(text, return_tensors="pt").to(self.device)
        output = self.mms_model(**inputs).waveform
        return output.squeeze().cpu().numpy()


def check_internet_connection():
    """Return True when a TCP connection to 8.8.8.8:53 succeeds within 5 s."""
    try:
        # Close the probe socket immediately instead of leaving it to the GC.
        socket.create_connection(("8.8.8.8", 53), timeout=5).close()
        return True
    except OSError:
        return False


def main():
    """Check connectivity, then build and run the bot."""
    if not check_internet_connection():
        logger.error(
            "No internet connection available. Please check your network settings."
        )
        sys.exit(1)

    bot = TranscribeKHBot()
    bot.run()


if __name__ == "__main__":
    main()