"""FastAPI service that detects whether a posted text contains a reference page.

Wraps the ``Arafath10/reference_page_finder`` sequence-classification model:
the text is split into 512-character chunks which are classified back-to-front
(references usually sit at the end of a document).
"""

import re

import numpy as np
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load tokenizer and model once at import time so request handling stays fast.
tokenizer = AutoTokenizer.from_pretrained("Arafath10/reference_page_finder")
model = AutoModelForSequenceClassification.from_pretrained(
    "Arafath10/reference_page_finder"
)
model.eval()  # inference only — disable dropout/batch-norm training behavior

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _chunk_string(input_string: str, chunk_size: int) -> list:
    """Split *input_string* into consecutive *chunk_size*-character pieces.

    The final piece may be shorter than *chunk_size*.
    """
    return [
        input_string[i:i + chunk_size]
        for i in range(0, len(input_string), chunk_size)
    ]


@app.post("/find_refrence_page")  # NOTE: path spelling kept for existing clients
async def find_refrence_page(request: Request):
    """Classify 512-char chunks of the posted text, scanning from the end.

    Expects a JSON body with a ``text`` field.

    Returns:
        ``"yes reference found"`` as soon as any chunk is classified positive,
        otherwise ``"no reference found"``; ``"error"`` on unexpected failure.

    Raises:
        HTTPException: 400 when the ``text`` field is missing or empty.
    """
    try:
        body = await request.json()
        test_text = body.get("text")
        if not test_text:
            raise HTTPException(
                status_code=400,
                detail="Missing 'text' field in request body",
            )

        # Collapse every run of whitespace (spaces, tabs, newlines) to a
        # single space so chunk boundaries are stable.
        test_text = re.sub(r'\s+', ' ', test_text).strip()

        # References usually appear at the end of a document, so walk the
        # chunks back-to-front and stop at the first positive hit.
        flag = "no reference found"
        for chunk in reversed(_chunk_string(test_text, chunk_size=512)):
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
            )
            with torch.no_grad():  # pure inference: skip autograd bookkeeping
                outputs = model(**inputs)
            prediction = int(np.argmax(outputs.logits.numpy(), axis=-1)[0])
            if prediction == 1:  # label 1 == reference page detected
                flag = "yes reference found"
                break
        return flag
    except HTTPException:
        # Bug fix: the original bare `except` swallowed this 400 and
        # returned HTTP 200 with body "error" — re-raise so FastAPI emits
        # the proper status code and detail.
        raise
    except Exception:
        # Best-effort contract preserved: any other failure yields "error".
        return "error"