# lyfeyvutha's picture
# Initial commit
# d9b8c51
import json
import ssl
import certifi
import logging
import requests
import base64
import os
import sys
import socket
from telegram import Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters
from transformers import VitsModel, AutoTokenizer, WhisperProcessor, WhisperForConditionalGeneration
import torch
import scipy.io.wavfile
import librosa
import asyncio
from telegram.error import NetworkError
import aiodns
import httpx
from httpcore import AsyncConnectionPool
# Configure logging: INFO and above for the whole process, timestamped.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Keep httpx quiet. Recent httpx versions log every request URL at INFO, and
# Telegram Bot API URLs contain the bot token, so anything below WARNING
# would write the secret token into the log output.
logging.getLogger("httpx").setLevel(logging.WARNING)
# Environment switch read by the Hugging Face tokenizers library; set before
# any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class Config:
    """Bot settings loaded from environment variables.

    Attributes:
        telegram_api_key: value of ``TELEGRAM_API_KEY`` (required).
        voice_message_file_path: value of ``VOICE_MESSAGE_FILE_PATH``;
            falls back to ``voice_message.ogg`` when unset.
        translate_kh_username: value of ``TRANSLATE_KH_USERNAME`` (required).
        translate_kh_password: value of ``TRANSLATE_KH_PASSWORD`` (required).

    Raises:
        ValueError: if any required variable is unset or empty.
    """

    def __init__(self):
        env = os.environ
        self.telegram_api_key = env.get('TELEGRAM_API_KEY')
        self.voice_message_file_path = env.get('VOICE_MESSAGE_FILE_PATH', 'voice_message.ogg')
        self.translate_kh_username = env.get('TRANSLATE_KH_USERNAME')
        self.translate_kh_password = env.get('TRANSLATE_KH_PASSWORD')

        # An empty string counts as missing, same as an unset variable.
        required = (
            self.telegram_api_key,
            self.translate_kh_username,
            self.translate_kh_password,
        )
        for value in required:
            if not value:
                raise ValueError("Missing required environment variables.")
class CustomDNSResolver:
    """Resolve hostnames to IPv4 addresses via the 8.8.8.8 / 8.8.4.4 name servers.

    This class used to be defined twice in a row with incompatible ``resolve``
    signatures; the first definition was silently replaced by the second.
    The single definition below keeps the contract of the definition that was
    actually in effect (a list of ``{'host', 'port'}`` dicts, port 443) and
    accepts the port as an optional argument, as the shadowed variant did.
    """

    def __init__(self):
        self.resolver = aiodns.DNSResolver(nameservers=['8.8.8.8', '8.8.4.4'])

    async def resolve(self, hostname, port=443):
        """Look up the ``A`` records for *hostname*.

        Args:
            hostname: name to resolve.
            port: port number attached to every returned entry (default 443).

        Returns:
            A list of ``{'host': <address>, 'port': <port>}`` dicts, one per
            record returned by the resolver, or ``None`` if the lookup raised
            ``aiodns.error.DNSError`` (the error is logged).
        """
        try:
            records = await self.resolver.query(hostname, 'A')
        except aiodns.error.DNSError as e:
            logger.error(f"DNS resolution error for {hostname}: {e}")
            return None
        return [{'host': record.host, 'port': port} for record in records]
class TranscribeKHBot:
    """Telegram bot: voice message -> English text -> Khmer text -> Khmer speech.

    Voice messages are transcribed with the ``openai/whisper-medium`` model,
    the transcription is sent to the Translate KH HTTP API, and the Khmer
    result is synthesised with the ``facebook/mms-tts-khm`` model.
    """

    def __init__(self):
        """Load configuration, build the Telegram application, load both models
        and create the shared HTTP client used for translation requests."""
        self.config = Config()
        # Stored for callers; nothing else in this class references it.
        self.custom_resolver = CustomDNSResolver()
        self.application = Application.builder().token(self.config.telegram_api_key).build()
        self.ssl_context = ssl.create_default_context(cafile=certifi.where())
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialize_mms_model()
        self.initialize_whisper_model()
        # One long-lived client, reused by every translate_to_khmer() call.
        self.http_client = httpx.AsyncClient(
            transport=httpx.AsyncHTTPTransport(
                retries=3,
                verify=self.ssl_context,
            ),
            timeout=30.0,
        )

    def initialize_mms_model(self):
        """Load the MMS Khmer TTS model and tokenizer onto ``self.device``."""
        logger.info("Initializing MMS model for Khmer TTS...")
        self.mms_model = VitsModel.from_pretrained("facebook/mms-tts-khm").to(self.device)
        self.mms_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-khm")
        logger.info("MMS model initialized successfully.")

    def initialize_whisper_model(self):
        """Load the Whisper model and processor and put the model in eval mode.

        Raises:
            Exception: any loading error is logged with traceback and re-raised.
        """
        logger.info("Initializing Whisper model...")
        try:
            self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to(self.device)
            self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
            self.whisper_model.eval()
            logger.info("Whisper model initialized successfully.")
        except Exception as e:
            logger.error(f"Error initializing Whisper model: {e}", exc_info=True)
            raise

    def setup_handlers(self):
        """Register the /start and /help commands and the voice-message handler."""
        self.application.add_handler(CommandHandler("start", self.start_command))
        self.application.add_handler(CommandHandler("help", self.help_command))
        self.application.add_handler(MessageHandler(filters.VOICE, self.handle_voice))

    async def run_with_retries(self, max_retries=5, initial_delay=1):
        """Start polling for updates, retrying start-up on network errors.

        ``Application.run_polling`` is a blocking, synchronous convenience
        method that drives its own event loop, so it cannot be awaited from
        inside ``asyncio.run``. This coroutine therefore performs the manual
        lifecycle instead: initialize, start, start the updater's polling,
        then wait until the task is cancelled (e.g. Ctrl+C) and tear down.

        Args:
            max_retries: maximum number of start-up attempts.
            initial_delay: seconds to wait after the first failed attempt;
                doubled after every further failure.

        Raises:
            NetworkError: if the last start-up attempt fails with a network error.
            Exception: any other start-up error is logged and re-raised at once.
        """
        app = self.application
        try:
            for attempt in range(max_retries):
                try:
                    await app.initialize()
                    # A previous attempt may have got past start() before
                    # failing; start() refuses to run twice.
                    if not app.running:
                        await app.start()
                    await app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
                    break
                except NetworkError as e:
                    if attempt >= max_retries - 1:
                        logger.error("Max retries reached. Exiting.")
                        raise
                    delay = initial_delay * (2 ** attempt)
                    logger.error(f"Network error occurred: {e}. Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)
                except Exception as e:
                    logger.error(f"An unexpected error occurred: {e}")
                    raise
            else:
                # No attempt was made (max_retries <= 0): nothing to run.
                return

            # Polling runs in background tasks; keep this coroutine alive
            # until it is cancelled from outside.
            await asyncio.Event().wait()
        finally:
            await self._shutdown_application()

    async def _shutdown_application(self):
        """Stop polling and the application, then release their resources."""
        app = self.application
        updater = app.updater
        if updater is not None and updater.running:
            await updater.stop()
        if app.running:
            await app.stop()
        await app.shutdown()

    def run(self):
        """Register handlers and run the bot until it is stopped."""
        self.setup_handlers()
        logger.info("Bot is running...")
        asyncio.run(self.run_with_retries())

    async def start_command(self, update: "Update", context):
        """Reply to /start with the welcome text and command list."""
        # NOTE(review): the text advertises /select, but setup_handlers()
        # registers no handler for it.
        await update.message.reply_text(
            "Welcome to Transcribe KH Bot\n"
            "Send a voice message, and the bot will convert it to text, translate to Khmer, and generate Khmer speech for you.\n"
            "Commands: \n"
            "/help - help information\n"
            "/select - select language"
        )

    async def help_command(self, update: "Update", context):
        """Reply to /help with a one-line usage hint."""
        await update.message.reply_text("Send a voice message to get started")

    def preprocess_audio(self, audio_path):
        """Load *audio_path* resampled to 16 kHz and trim leading/trailing silence.

        Returns:
            The trimmed audio samples as returned by ``librosa.effects.trim``.
        """
        audio, _ = librosa.load(audio_path, sr=16000)
        trimmed, _ = librosa.effects.trim(audio, top_db=20)
        return trimmed

    def _transcribe(self, audio_path):
        """Return the Whisper English transcription of the audio at *audio_path*."""
        preprocessed_audio = self.preprocess_audio(audio_path)
        input_features = self.whisper_processor(
            preprocessed_audio, sampling_rate=16000, return_tensors="pt"
        ).input_features
        # NOTE(review): all-ones mask with the same shape as the feature
        # tensor, kept as in the original; confirm this is the shape
        # generate() expects.
        attention_mask = torch.ones_like(input_features)
        forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="en", task="transcribe")
        predicted_ids = self.whisper_model.generate(
            input_features.to(self.device),
            attention_mask=attention_mask.to(self.device),
            forced_decoder_ids=forced_decoder_ids,
        )
        return self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    async def handle_voice(self, update: "Update", context):
        """Handle a voice message: transcribe, translate, synthesise and reply.

        Any failure is logged and reported to the user with a generic error
        message; nothing is raised to the caller.
        """
        message = await update.message.reply_text("Processing your voice message")
        try:
            file = await update.message.voice.get_file()
            await file.download_to_drive(self.config.voice_message_file_path)

            transcription = self._transcribe(self.config.voice_message_file_path)
            logger.info(f"Whisper transcription: {transcription}")

            try:
                khmer_text = await self.translate_to_khmer(transcription)
            except Exception as e:
                logger.error(f"Error translating text: {str(e)}")
                await update.message.reply_text("An error occurred while translating the text. Please try again later.")
                return

            khmer_speech = self.khmer_text_to_speech(khmer_text)
            output_file = "khmer_speech.wav"
            scipy.io.wavfile.write(output_file, self.mms_model.config.sampling_rate, khmer_speech)

            # Replace the "Processing..." placeholder with the real results.
            await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=message.message_id)
            await update.message.reply_text(f"Original (English): {transcription}\n\nTranslated (Khmer): {khmer_text}")
            # Close the audio file once it has been uploaded.
            with open(output_file, "rb") as voice_file:
                await update.message.reply_voice(voice=voice_file)
        except Exception as e:
            logger.error(f"Error processing voice message: {str(e)}", exc_info=True)
            await update.message.reply_text("An error occurred while processing your voice message. Please try again later.")

    async def translate_to_khmer(self, text):
        """Translate English *text* to Khmer through the Translate KH API.

        Args:
            text: English text to translate.

        Returns:
            The first entry of the ``translate_text`` list in the API response.

        Raises:
            httpx.HTTPError: on transport failures or non-success HTTP status
                (logged, then re-raised).
            ValueError: if the response has no non-empty ``translate_text``.
        """
        url = "https://translatekh.mptc.gov.kh/api"
        credentials = f"{self.config.translate_kh_username}:{self.config.translate_kh_password}"
        encoded_credentials = base64.b64encode(credentials.encode('utf-8')).decode('utf-8')
        headers = {
            "Authorization": f"Basic {encoded_credentials}",
            "Content-Type": "application/json",
        }
        data = {
            "input_text": [text],
            "src_lang": "eng",
            "tgt_lang": "kh",
        }

        try:
            # Use the shared client directly. Entering it with `async with`
            # would close it on exit and break every later translation.
            response = await self.http_client.post(url, headers=headers, json=data)
            response.raise_for_status()
        except httpx.HTTPError as e:
            # HTTPError covers both transport errors and bad status codes.
            logger.error(f"Error calling Translate KH API: {str(e)}")
            raise

        result = response.json()
        translations = result.get("translate_text") if isinstance(result, dict) else None
        if translations:
            return translations[0]
        logger.error(f"Empty or invalid translation result: {result}")
        raise ValueError("Translation result is empty or invalid")

    def khmer_text_to_speech(self, text):
        """Synthesise Khmer *text* with the MMS model.

        Returns:
            The generated waveform as a NumPy array (squeezed, moved to CPU).
        """
        inputs = self.mms_tokenizer(text, return_tensors="pt").to(self.device)
        # Inference only: no gradients needed.
        with torch.no_grad():
            output = self.mms_model(**inputs).waveform
        return output.squeeze().cpu().numpy()
def check_internet_connection():
    """Report whether a TCP connection to 8.8.8.8 port 53 can be opened.

    Returns:
        True if the connection succeeds within 5 seconds, False if it fails
        with ``OSError`` (which includes timeouts).
    """
    try:
        # The context manager closes the probe socket; only reachability matters.
        with socket.create_connection(("8.8.8.8", 53), timeout=5):
            return True
    except OSError:
        return False
def main():
    """Entry point: verify connectivity, then build and run the bot.

    Exits the process with status 1 (after logging an error) when
    ``check_internet_connection()`` reports no connectivity.
    """
    if check_internet_connection():
        TranscribeKHBot().run()
        return

    logger.error("No internet connection available. Please check your network settings.")
    sys.exit(1)


# Start the bot only when this file is executed directly.
if __name__ == "__main__":
    main()