from fastapi import FastAPI, HTTPException, Request
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from fastapi.middleware.cors import CORSMiddleware
import torch
import os
import yaml
import transformers

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the fine-tuned model and tokenizer. (The tokenizer is re-created below
# from the base NLLB checkpoint so that the custom language codes can be added.)
tokenizer = AutoTokenizer.from_pretrained("EzekielMW/LuoKslGloss")
model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/LuoKslGloss")

# Where should output files be stored locally
drive_folder = "./quadserverlogs"

if not os.path.exists(drive_folder):
    os.makedirs(drive_folder)

# Large batch sizes generally give good results for translation
effective_train_batch_size = 480
train_batch_size = 6
eval_batch_size = train_batch_size

gradient_accumulation_steps = effective_train_batch_size // train_batch_size
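
# Sanity check of the arithmetic: the effective batch size is the per-device
# batch multiplied by the accumulation steps (assuming a single device),
# i.e. 6 * 80 = 480 examples per optimizer update.
assert train_batch_size * gradient_accumulation_steps == effective_train_batch_size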

# Everything in one yaml string, so that it can all be logged.
yaml_config = '''
training_args:
  output_dir: "{drive_folder}"
  eval_strategy: steps
  eval_steps: 200
  save_steps: 200
  gradient_accumulation_steps: {gradient_accumulation_steps}
  learning_rate: 3.0e-4  # Include decimal point to parse as float
  # optim: adafactor
  per_device_train_batch_size: {train_batch_size}
  per_device_eval_batch_size: {eval_batch_size}
  weight_decay: 0.01
  save_total_limit: 3
  max_steps: 500
  predict_with_generate: True
  fp16: True
  logging_dir: "{drive_folder}"
  load_best_model_at_end: True
  metric_for_best_model: loss
  seed: 123
  push_to_hub: False

max_input_length: 128
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .

# Use the 1.3B-parameter NLLB model here. Smaller models (e.g. the distilled
# 600M variant) are easier to train on a free Colab instance, but the larger
# model gives better results.
model_checkpoint: facebook/nllb-200-1.3B

datasets:
  train:
    huggingface_load:
      # We load three datasets here: English/KSL Gloss, Dholuo/Swahili, and
      # SALT, so that we can try out multi-way translation.
      - path: EzekielMW/Eksl_dataset
        split: train[:-1000]
      - path: EzekielMW/Luo_Swa
        split: train[:-2000]
      - path: sunbird/salt
        name: text-all
        split: train
    source:
      # This is a text translation only, no audio.
      type: text
      # The source text can be any of English, KSL, Swahili or Dholuo.
      language: [eng,ksl,swa,luo]
      preprocessing:
        # The models are case sensitive, so if the training text is all
        # capitals, then it will only learn to translate capital letters and
        # won't understand lower case. Make everything lower case for now.
        - lower_case
        # We can also augment the spelling of the input text, which makes the
        # model more robust to spelling errors.
        - augment_characters
    target:
      type: text
      # The target text can be any of English, KSL, Swahili or Dholuo.
      language: [eng,ksl,swa,luo]
      # The models are case sensitive: make everything lower case for now.
      preprocessing:
        - lower_case

    shuffle: True
    allow_same_src_and_tgt_language: False

  validation:
    huggingface_load:
      # Use the last 1000 of the KSL examples for validation.
      - path: EzekielMW/Eksl_dataset
        split: train[-1000:]
      # Use the last 2000 of the Luo examples for validation.
      - path: EzekielMW/Luo_Swa
        split: train[-2000:]
      # Add some Swahili validation text.
      - path: sunbird/salt
        name: text-all
        split: dev
    source:
      type: text
      language: [swa,ksl,eng,luo]
      preprocessing:
        - lower_case
    target:
      type: text
      language: [swa,ksl,eng,luo]
      preprocessing:
        - lower_case
    allow_same_src_and_tgt_language: False
'''

yaml_config = yaml_config.format(
    drive_folder=drive_folder,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)

config = yaml.safe_load(yaml_config)
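
# Note on the learning-rate comment in the YAML above: PyYAML only parses
# floats written with a decimal point (e.g. "3.0e-4"); "3e-4" would be loaded
# as a string. This check confirms the value came through as a float.
assert isinstance(config['training_args']['learning_rate'], float)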

training_settings = transformers.Seq2SeqTrainingArguments(
    **config["training_args"])

# The pre-trained model that we use has support for some African languages, but
# we need to adapt the tokenizer to languages that it wasn't trained with,
# such as KSL. Here we reuse the token from a different language.
LANGUAGE_CODES = ["eng", "swa", "ksl", "luo"]

code_mapping = {
    # Exact/close mappings: NLLB was trained on these languages.
    'eng': 'eng_Latn',
    'swa': 'swh_Latn',
    'luo': 'luo_Latn',
    # KSL is not in NLLB, so reuse the token of an unrelated language.
    'ksl': 'ace_Latn',
}
tokenizer = transformers.NllbTokenizer.from_pretrained(
    config['model_checkpoint'],
    src_lang='eng_Latn',
    tgt_lang='eng_Latn')

# Id of the first language-code token in the vocabulary (not used below).
offset = tokenizer.sp_model_size + tokenizer.fairseq_offset

# Register each custom code as an added token that points at the id of the
# NLLB token it borrows. (This pokes a private attribute, but avoids resizing
# the model's embedding matrix.)
for code in LANGUAGE_CODES:
    i = tokenizer.convert_tokens_to_ids(code_mapping[code])
    tokenizer._added_tokens_encoder[code] = i
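
# Optional sanity check: each custom code should now resolve to the id of the
# token it borrows, e.g. 'ksl' yields the same id as 'ace_Latn'.
for code in LANGUAGE_CODES:
    assert tokenizer.convert_tokens_to_ids(code) == \
        tokenizer.convert_tokens_to_ids(code_mapping[code])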

# Compatibility alias: some transformers versions expect this class to be
# reachable under transformers.generation.utils.
transformers.generation.utils.ForcedBOSTokenLogitsProcessor = transformers.ForcedBOSTokenLogitsProcessor

# Define a translation function
def translate(text, source_language, target_language):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
    # Overwrite the first input token with the custom source-language code.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)
    translated_tokens = model.to(device).generate(
        **inputs,
        # Force decoding to start with the target-language token.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=100,
        num_beams=5,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    # The model was trained on lower-cased text; KSL gloss is conventionally
    # written in capitals, so restore upper case for KSL output.
    if target_language == 'ksl':
        result = result.upper()

    return result
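
# Example usage (illustrative; actual output depends on the fine-tuned weights):
#   translate("how are you?", "eng", "ksl")   # -> a KSL gloss in capitals
#   translate("habari yako?", "swa", "eng")   # -> an English translation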

@app.post("/translate")
async def translate_text(request: Request):
    data = await request.json()
    text = data.get("text")
    source_language = data.get("source_language")
    target_language = data.get("target_language")
    
    translation = translate(text, source_language, target_language)
    return {"translation": translation}



@app.get("/")
async def root():
    return {"message": "Welcome to the translation API!"}