Spaces:
Runtime error
Runtime error
import json | |
import ssl | |
import certifi | |
import logging | |
import requests | |
import base64 | |
import os | |
import sys | |
import socket | |
from telegram import Update | |
from telegram.ext import Application, CommandHandler, MessageHandler, filters | |
from transformers import VitsModel, AutoTokenizer, WhisperProcessor, WhisperForConditionalGeneration | |
import torch | |
import scipy.io.wavfile | |
import librosa | |
import asyncio | |
from telegram.error import NetworkError | |
import aiodns | |
import httpx | |
from httpcore import AsyncConnectionPool | |
# Configure logging: timestamped records at INFO level for the whole process.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Keep httpx quiet: its per-request log records contain the full request URL,
# and Telegram Bot API URLs embed the bot token, so verbose levels would write
# the secret into the logs.
logging.getLogger("httpx").setLevel(logging.WARNING)
# Must be set before any tokenizer is used; the value "false" turns off the
# tokenizers library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class Config:
    """Runtime settings for the bot, read from environment variables.

    Attributes set by ``__init__``:
        telegram_api_key: value of ``TELEGRAM_API_KEY`` (required).
        voice_message_file_path: value of ``VOICE_MESSAGE_FILE_PATH``,
            defaulting to ``'voice_message.ogg'``.
        translate_kh_username: value of ``TRANSLATE_KH_USERNAME`` (required).
        translate_kh_password: value of ``TRANSLATE_KH_PASSWORD`` (required).

    Raises:
        ValueError: if any required variable is unset or empty. The message
            lists the names of the missing variables.
    """

    def __init__(self):
        self.telegram_api_key = os.environ.get('TELEGRAM_API_KEY')
        self.voice_message_file_path = os.environ.get('VOICE_MESSAGE_FILE_PATH', 'voice_message.ogg')
        self.translate_kh_username = os.environ.get('TRANSLATE_KH_USERNAME')
        self.translate_kh_password = os.environ.get('TRANSLATE_KH_PASSWORD')

        # Report exactly which variables are absent so a misconfigured
        # deployment can be fixed without guessing.
        required = {
            'TELEGRAM_API_KEY': self.telegram_api_key,
            'TRANSLATE_KH_USERNAME': self.translate_kh_username,
            'TRANSLATE_KH_PASSWORD': self.translate_kh_password,
        }
        missing = [name for name, value in required.items() if not value]
        if missing:
            raise ValueError("Missing required environment variables: " + ", ".join(missing))
class CustomDNSResolver:
    """Resolve hostnames to IPv4 addresses via the nameservers 8.8.8.8 and 8.8.4.4.

    This is a single definition of the class: defining it twice in one module
    leaves only the last definition in effect, so the behaviour kept here is
    that of the effective one (a list of ``{'host', 'port'}`` dicts), with the
    port made configurable.
    """

    def __init__(self):
        self.resolver = aiodns.DNSResolver(nameservers=['8.8.8.8', '8.8.4.4'])

    async def resolve(self, hostname, port=443):
        """Look up the ``A`` records for *hostname*.

        Args:
            hostname: name to resolve.
            port: port number attached to every returned entry (default 443).

        Returns:
            A list of ``{'host': <address>, 'port': port}`` dicts, one per
            record returned by the query, or ``None`` when the lookup raises
            ``aiodns.error.DNSError`` (the error is logged).
        """
        try:
            records = await self.resolver.query(hostname, 'A')
            return [{'host': entry.host, 'port': port} for entry in records]
        except aiodns.error.DNSError as e:
            logger.error(f"DNS resolution error for {hostname}: {e}")
            return None
class TranscribeKHBot:
    """Telegram bot: voice message -> English text -> Khmer text -> Khmer speech.

    A voice message is downloaded, transcribed with the
    ``openai/whisper-medium`` model, translated through the Translate KH HTTP
    API and synthesised with the ``facebook/mms-tts-khm`` model. The text and
    the generated audio are sent back to the chat.
    """

    def __init__(self):
        """Load configuration, build the Telegram application and load both models."""
        self.config = Config()
        self.custom_resolver = CustomDNSResolver()
        self.application = Application.builder().token(self.config.telegram_api_key).build()
        self.ssl_context = ssl.create_default_context(cafile=certifi.where())
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialize_mms_model()
        self.initialize_whisper_model()
        # One shared client for the lifetime of the bot; it is closed when
        # run_with_retries() shuts down.
        self.http_client = httpx.AsyncClient(
            transport=httpx.AsyncHTTPTransport(
                retries=3,
                verify=self.ssl_context,
            ),
            timeout=30.0,
        )

    def initialize_mms_model(self):
        """Load the MMS Khmer TTS model and tokenizer onto ``self.device``."""
        logger.info("Initializing MMS model for Khmer TTS...")
        self.mms_model = VitsModel.from_pretrained("facebook/mms-tts-khm").to(self.device)
        self.mms_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-khm")
        logger.info("MMS model initialized successfully.")

    def initialize_whisper_model(self):
        """Load the Whisper model and processor; log and re-raise on failure."""
        logger.info("Initializing Whisper model...")
        try:
            self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to(self.device)
            self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
            self.whisper_model.eval()
            logger.info("Whisper model initialized successfully.")
        except Exception as e:
            logger.error(f"Error initializing Whisper model: {e}", exc_info=True)
            raise

    def setup_handlers(self):
        """Register the /start and /help commands and the voice-message handler."""
        self.application.add_handler(CommandHandler("start", self.start_command))
        self.application.add_handler(CommandHandler("help", self.help_command))
        self.application.add_handler(MessageHandler(filters.VOICE, self.handle_voice))

    async def run_with_retries(self, max_retries=5, initial_delay=1):
        """Start polling, retrying start-up on network errors, then run until cancelled.

        Args:
            max_retries: maximum number of start-up attempts.
            initial_delay: delay in seconds before the first retry; doubled
                after every failed attempt.

        Raises:
            NetworkError: when the last start-up attempt still fails.
            Exception: any other start-up error is logged and re-raised.
        """
        for attempt in range(max_retries):
            try:
                # Application.run_polling() is a blocking call that drives its
                # own event loop, so it cannot be used from inside a running
                # loop. The explicit initialize/start/start_polling sequence is
                # the coroutine-based equivalent.
                await self.application.initialize()
                if not self.application.running:
                    await self.application.start()
                await self.application.updater.start_polling(allowed_updates=Update.ALL_TYPES)
                break
            except NetworkError as e:
                if attempt < max_retries - 1:
                    delay = initial_delay * (2 ** attempt)
                    logger.error(f"Network error occurred: {e}. Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)
                else:
                    logger.error("Max retries reached. Exiting.")
                    raise
            except Exception as e:
                logger.error(f"An unexpected error occurred: {e}")
                raise
        else:
            # No attempt was made (max_retries <= 0): nothing was started.
            return

        try:
            # start_polling() returns as soon as polling is running; park here
            # until the task is cancelled (e.g. Ctrl-C under asyncio.run).
            await asyncio.Event().wait()
        finally:
            if self.application.updater.running:
                await self.application.updater.stop()
            if self.application.running:
                await self.application.stop()
            await self.application.shutdown()
            await self.http_client.aclose()

    def run(self):
        """Register handlers and run the bot until it is stopped."""
        self.setup_handlers()
        logger.info("Bot is running...")
        asyncio.run(self.run_with_retries())

    async def start_command(self, update: Update, context):
        """Reply to /start with the welcome text and command list."""
        await update.message.reply_text(
            "Welcome to Transcribe KH Bot\n"
            "Send a voice message, and the bot will convert it to text, translate to Khmer, and generate Khmer speech for you.\n"
            "Commands: \n"
            "/help - help information\n"
            "/select - select language"
        )

    async def help_command(self, update: Update, context):
        """Reply to /help with a short usage hint."""
        await update.message.reply_text("Send a voice message to get started")

    def preprocess_audio(self, audio_path):
        """Load *audio_path* resampled to 16 kHz and trim leading/trailing silence.

        Returns:
            The trimmed audio samples as returned by ``librosa.effects.trim``.
        """
        audio, _ = librosa.load(audio_path, sr=16000)
        return librosa.effects.trim(audio, top_db=20)[0]

    async def handle_voice(self, update: Update, context):
        """Transcribe, translate and re-synthesise an incoming voice message.

        Sends a temporary "processing" message, then replies with the English
        transcription, the Khmer translation and the generated Khmer audio.
        Any failure is logged and reported to the user with a generic message.
        """
        message = await update.message.reply_text("Processing your voice message")
        try:
            file = await update.message.voice.get_file()
            await file.download_to_drive(self.config.voice_message_file_path)

            # NOTE(review): the model calls below are synchronous and run
            # inside this coroutine, so other updates wait while they execute.
            preprocessed_audio = self.preprocess_audio(self.config.voice_message_file_path)
            input_features = self.whisper_processor(preprocessed_audio, sampling_rate=16000, return_tensors="pt").input_features
            attention_mask = torch.ones_like(input_features)
            forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="en", task="transcribe")
            predicted_ids = self.whisper_model.generate(
                input_features.to(self.device),
                attention_mask=attention_mask.to(self.device),
                forced_decoder_ids=forced_decoder_ids
            )
            transcription = self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            logger.info(f"Whisper transcription: {transcription}")

            try:
                khmer_text = await self.translate_to_khmer(transcription)
            except Exception as e:
                logger.error(f"Error translating text: {str(e)}")
                await update.message.reply_text("An error occurred while translating the text. Please try again later.")
                return

            khmer_speech = self.khmer_text_to_speech(khmer_text)
            output_file = "khmer_speech.wav"
            scipy.io.wavfile.write(output_file, self.mms_model.config.sampling_rate, khmer_speech)

            await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=message.message_id)
            await update.message.reply_text(f"Original (English): {transcription}\n\nTranslated (Khmer): {khmer_text}")
            # Close the audio file once it has been sent instead of leaking
            # the handle on every voice message.
            with open(output_file, "rb") as voice_file:
                await update.message.reply_voice(voice=voice_file)
        except Exception as e:
            logger.error(f"Error processing voice message: {str(e)}", exc_info=True)
            await update.message.reply_text("An error occurred while processing your voice message. Please try again later.")

    async def translate_to_khmer(self, text):
        """Translate English *text* to Khmer through the Translate KH API.

        Returns:
            The first entry of the ``translate_text`` list in the JSON reply.

        Raises:
            httpx.HTTPError: on transport failures or a non-success HTTP
                status (logged, then re-raised).
            ValueError: when the reply has no usable ``translate_text`` entry.
        """
        url = "https://translatekh.mptc.gov.kh/api"
        username = self.config.translate_kh_username
        password = self.config.translate_kh_password
        credentials = f"{username}:{password}"
        encoded_credentials = base64.b64encode(credentials.encode('utf-8')).decode('utf-8')
        headers = {
            "Authorization": f"Basic {encoded_credentials}",
            "Content-Type": "application/json"
        }
        data = {
            "input_text": [text],
            "src_lang": "eng",
            "tgt_lang": "kh"
        }
        try:
            # Use the shared client directly. Entering it with ``async with``
            # would close it on exit and break every later translation.
            response = await self.http_client.post(url, headers=headers, json=data)
            response.raise_for_status()
            result = response.json()
        except httpx.HTTPError as e:
            logger.error(f"Error calling Translate KH API: {str(e)}")
            raise

        if "translate_text" in result and len(result["translate_text"]) > 0:
            return result["translate_text"][0]
        logger.error(f"Empty or invalid translation result: {result}")
        raise ValueError("Translation result is empty or invalid")

    def khmer_text_to_speech(self, text):
        """Synthesise Khmer *text* with the MMS model and return the waveform as a NumPy array."""
        inputs = self.mms_tokenizer(text, return_tensors="pt").to(self.device)
        # Inference only: without no_grad the output tensor requires grad and
        # cannot be converted with .numpy().
        with torch.no_grad():
            output = self.mms_model(**inputs).waveform
        return output.squeeze().cpu().numpy()
def check_internet_connection(host="8.8.8.8", port=53, timeout=5):
    """Return True if a TCP connection to ``(host, port)`` can be opened.

    Args:
        host: address to probe (default ``8.8.8.8``).
        port: TCP port to probe (default 53).
        timeout: connection timeout in seconds (default 5).

    Returns:
        ``True`` when the connection succeeds, ``False`` when it raises
        ``OSError``.
    """
    try:
        # The context manager closes the probe socket; it is only needed to
        # prove that a connection can be established.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
def main():
    """Start the bot, or log an error and exit with status 1 when offline."""
    online = check_internet_connection()
    if online:
        # Connectivity confirmed: build the bot and hand control to it.
        TranscribeKHBot().run()
        return
    logger.error("No internet connection available. Please check your network settings.")
    sys.exit(1)


if __name__ == "__main__":
    main()