from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware


from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np

# Load the tokenizer for the reference-page classifier checkpoint.
tokenizer = AutoTokenizer.from_pretrained("Arafath10/reference_page_finder")

# Load the matching sequence-classification model; the endpoint below
# treats predicted class 1 as "reference found".
model = AutoModelForSequenceClassification.from_pretrained("Arafath10/reference_page_finder")


app = FastAPI()

# Allow cross-origin requests from any origin.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers for credentialed requests (the wildcard cannot be used
# with credentials) — confirm whether credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/find_refrence_page")
async def find_refrence_page(request: Request):
    """Detect whether the posted text contains a reference section.

    Expects a JSON body of the form ``{"text": "..."}``. The text is
    normalized, split into 512-character chunks, and the chunks are
    classified back-to-front (references usually sit at the end of a
    document); scanning stops at the first positive chunk.

    Returns:
        "yes reference found" if any chunk is classified positive,
        otherwise "no reference found".

    Raises:
        HTTPException: 400 when the 'text' field is missing or empty,
            500 on any unexpected failure.
    """
    try:
        # Extract the JSON body.
        body = await request.json()
        test_text = body.get("text")

        if not test_text:
            raise HTTPException(status_code=400, detail="Missing 'text' field in request body")

        import re

        # Collapse all runs of whitespace (spaces, tabs, newlines) so that
        # chunk boundaries are stable regardless of the input formatting.
        test_text = re.sub(r'\s+', ' ', test_text).strip()

        def chunk_string(input_string, chunk_size):
            # Fixed-size slices; the final chunk may be shorter.
            return [input_string[i:i + chunk_size] for i in range(0, len(input_string), chunk_size)]

        # Scan from the end of the document backwards and stop early on
        # the first positive classification.
        chunks = reversed(chunk_string(test_text, chunk_size=512))

        flag = "no reference found"
        for idx, chunk in enumerate(chunks):
            print(f"Chunk {idx + 1} {chunk}")
            inputs = tokenizer(chunk, return_tensors="pt", truncation=True, padding="max_length")
            outputs = model(**inputs)
            predictions = np.argmax(outputs.logits.detach().numpy(), axis=-1)
            if predictions[0] == 1:
                flag = "yes reference found"
                break
        return flag
    except HTTPException:
        # Re-raise FastAPI errors unchanged: the original bare `except:`
        # swallowed the 400 above and returned HTTP 200 with body "error".
        raise
    except Exception as exc:
        # Surface unexpected failures as a proper 500 with detail instead
        # of silently returning the string "error" with status 200.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

#print(main("https://www.keells.com/", "Please analyse reports"))