from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Any, Dict, List, Tuple
import torch

def load_model(
    model_name: str,
    dtype: torch.dtype = torch.float32,
) -> Tuple[AutoModelForCausalLM, Any]:
    """
    Load and initialize the language model for CPU-only inference.
    
    Args:
        model_name (str): Name of the pre-trained model to load
        dtype (torch.dtype): Data type for model weights (default: torch.float32)
    
    Returns:
        Tuple[AutoModelForCausalLM, Any]: Tuple containing the model and tokenizer
    """
    kwargs = {
        "device_map": "cpu",  # Explicitly set to CPU
        "torch_dtype": dtype,
        "low_cpu_mem_usage": True,  # Optimize memory usage for CPU
    }

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_name,
        **kwargs
    )

    model.eval()  # Set model to evaluation mode
    return model, tokenizer
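
# A quick usage sketch, commented out so importing this module stays
# side-effect free. bfloat16 roughly halves weight memory, assuming the host
# CPU supports it (an assumption; the float32 default above is the safe choice):
# model, tokenizer = load_model("CodeTranslatorLLM/LinguistLLM", dtype=torch.bfloat16)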

def prepare_input(
    messages: List[Dict[str, str]],
    tokenizer: Any,
) -> torch.Tensor:
    """
    Prepare input for the model by concatenating message contents and tokenizing.
    
    Args:
        messages (List[Dict[str, str]]): List of message dictionaries
        tokenizer: The tokenizer instance
    
    Returns:
        torch.Tensor: Prepared input tensor
    """
    # Combine messages into a single string (simple concatenation for this example)
    input_text = " ".join([msg["content"] for msg in messages])
    # Tokenize the input (no padding needed for a single sequence; padding=True
    # would raise on tokenizers that define no pad token, which is common for
    # causal LMs)
    return tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
    )["input_ids"]

def generate_response(
    model: AutoModelForCausalLM,
    inputs: torch.Tensor,
    tokenizer: Any,
    max_new_tokens: int = 200,
) -> str:
    """
    Generate response using the model.
    
    Args:
        model (AutoModelForCausalLM): The language model
        inputs (torch.Tensor): Prepared input tensor
        tokenizer: The tokenizer instance
        max_new_tokens (int): Maximum number of tokens to generate
    
    Returns:
        str: Generated response text (the echoed prompt is excluded)
    """
    with torch.no_grad():  # No gradients are needed for inference
        outputs = model.generate(
            inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Deterministic (greedy) decoding for reproducibility
            pad_token_id=tokenizer.eos_token_id,  # Avoid the missing-pad-token warning
        )
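    # For more varied output, sampling can be enabled instead of greedy
    # decoding; these temperature/top_p values are illustrative assumptions:
    # outputs = model.generate(inputs, max_new_tokens=max_new_tokens,
    #                          do_sample=True, temperature=0.7, top_p=0.9)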
    # Decode only the newly generated tokens; generate() returns the prompt first
    return tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)

def main(
    user_input_code: str,
    user_input_explanation: str,
    model_path: str,
):
    """
    Main function to demonstrate the inference pipeline.
    """
    # Example messages
    messages = [
        {
            "role": "user",
            "content": f"[Fortran Code]\n{USER_INPUT_CODE}\n[Fortran Code Explain]\n{USER_INPUT_EXPLANATION}"
        }
    ]
    
    # Load model
    model, tokenizer = load_model(model_path)
    
    # Prepare input
    inputs = prepare_input(messages, tokenizer)
    
    # Generate response
    response = generate_response(model, inputs, tokenizer)
    print("Generated Response:\n", response)

if __name__ == "__main__":
    # Define your Fortran code and explanation
    USER_INPUT_CODE = """
    program sum_of_numbers
        implicit none
        integer :: n, i, sum

        ! Initialize variables
        sum = 0

        ! Get user input
        print *, "Enter a positive integer:"
        read *, n

        ! Calculate the sum of numbers from 1 to n
        do i = 1, n
            sum = sum + i
        end do

        ! Print the result
        print *, "The sum of numbers from 1 to", n, "is", sum
    end program sum_of_numbers
    """
    USER_INPUT_EXPLANATION = """
    The provided Fortran code snippet is a program that calculates the sum of integers from 1 to n, where n is provided by the user. 
    It uses a simple procedural approach, including variable declarations, input handling, and a loop for the summation.

    The program starts by initializing variables and prompting the user for input. 
    It then calculates the sum using a do loop, iterating from 1 to n, and accumulating the result in a variable. 
    Finally, it prints the computed sum to the console.

    This program demonstrates a straightforward application of Fortran's capabilities for handling loops and basic arithmetic operations.
    """
    # Path to your model
    MODEL_PATH = "lora_model"

    # Run the main function
    main(USER_INPUT_CODE, USER_INPUT_EXPLANATION, MODEL_PATH)