from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict
from app.bpe_tokenizer import BPETokenizer

# NOTE: assumed location of the training helpers used by the endpoints below;
# adjust the module path to wherever load_sample_data/train_and_save_bpe live.
from app.train_bpe import load_sample_data, train_and_save_bpe
import os
import logging

logging.basicConfig(level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s")
app = FastAPI(title="Hindi BPE Tokenizer API")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
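# Note: allow_origins=["*"] together with allow_credentials=True is very
# permissive; for a public deployment you would normally pin the frontend
# origin(s) here instead of using a wildcard.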

# Initialize tokenizer and load model
tokenizer = BPETokenizer()
model_path = "bpe_model_latest.json"
if os.path.exists(model_path):
    tokenizer.load_model(model_path)
else:
    print(f"Warning: Model file {model_path} not found. Starting with empty model.")


class TokenizeRequest(BaseModel):
    text: str


class TokenStats(BaseModel):
    original_chars: int
    token_count: int
    compression_ratio: float
    unique_tokens: int


class TokenizeResponse(BaseModel):
    original_text: str
    original_tokens: List[str]
    bpe_tokens: List[str]
    stats: Dict
    token_details: List[Dict]
    merge_history: List[Dict]
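# Note: TokenizeResponse mirrors the /tokenize payload below but is not bound
# to the endpoint; wiring it up via response_model=TokenizeResponse would make
# FastAPI filter out the extra "original_encoded_tokens" and "token_numbers"
# fields the handler currently returns, so the binding is left off here.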


# Add WebSocket manager
class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def broadcast(self, message: dict):
        for connection in self.active_connections:
            await connection.send_json(message)


manager = ConnectionManager()
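# Sketch of how a training loop might report progress through the manager;
# the message shape here is an assumption, not something BPETokenizer defines:
#
#     await manager.broadcast({"event": "progress", "merges_done": 500, "vocab_size": 3200})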
# Add WebSocket endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
try:
while True:
await websocket.receive_text()
except:
manager.disconnect(websocket)
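# Minimal client sketch for the /ws endpoint (assumes the third-party
# `websockets` package and a server on localhost:8000):
#
#     import asyncio, json
#     import websockets
#
#     async def listen():
#         async with websockets.connect("ws://localhost:8000/ws") as ws:
#             while True:
#                 print(json.loads(await ws.recv()))
#
#     asyncio.run(listen())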


@app.post("/tokenize")
async def tokenize_text(request: TokenizeRequest):
    try:
        result = tokenizer.tokenize_with_details(request.text)
        # Ensure all required fields are present
        return {
            "original_text": result["original_text"],
            "original_tokens": result["original_tokens"],
            "bpe_tokens": result["bpe_tokens"],
            "original_encoded_tokens": result["original_encoded_tokens"],
            "token_numbers": result["token_numbers"],
            "stats": {
                "original_chars": result["stats"]["original_chars"],
                "token_count": len(result["bpe_tokens"]),
                "compression_ratio": result["stats"]["compression_ratio"],
                "unique_tokens": result["stats"]["unique_tokens"],
            },
            "token_details": [
                {
                    "token": token,
                    "type": tokenizer._get_token_type(token),
                    "length": len(token),
                }
                for token in result["bpe_tokens"]
            ],
            "merge_history": tokenizer.merge_history[-10:]
            if tokenizer.merge_history
            else [],
        }
    except Exception as e:
        # Re-enabled error handling: surface tokenizer failures as a 500
        # instead of letting the exception propagate unformatted.
        raise HTTPException(status_code=500, detail=str(e))
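# Example request against a local server (host/port assumed):
#
#     curl -X POST http://localhost:8000/tokenize \
#          -H "Content-Type: application/json" \
#          -d '{"text": "नमस्ते दुनिया"}'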


@app.get("/vocabulary-stats")
async def get_vocab_stats():
    return {
        "vocab_size": len(tokenizer.vocab),
        "most_frequent_tokens": tokenizer.token_usage.most_common(20),
        "most_frequent_pairs": dict(
            sorted(
                tokenizer.pair_frequencies.items(), key=lambda x: x[1], reverse=True
            )[:20]
        ),
    }


@app.get("/training-progress")
async def get_training_progress():
    if hasattr(tokenizer, "training_progress"):
        return tokenizer.training_progress
    return {"message": "No training progress available"}


@app.get("/training-stats")
async def get_training_stats():
    """Get detailed training statistics"""
    if hasattr(tokenizer, "vocab_growth"):
        return {
            "vocab_growth": tokenizer.vocab_growth,
            "base_vocab_stats": tokenizer.base_vocab_stats,
            "current_stats": {
                "total_vocab": len(tokenizer.vocab),
                "learned_vocab": len(tokenizer.learned_vocab),
                "base_vocab": len(tokenizer.BASE_VOCAB),
            },
        }
    return {"message": "No training statistics available"}


@app.post("/resume-training")
async def resume_training(checkpoint_file: str):
    """Resume training from checkpoint"""
    try:
        # `text` was previously undefined here; reload the corpus the same way
        # /start-training does before resuming from the checkpoint.
        text = load_sample_data("data/hindi_wiki_corpus.txt", max_sentences=10000)
        tokenizer.learn_bpe(text, resume_from=checkpoint_file)
        return {"message": "Training resumed successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/start-training")
async def start_training(request: dict):
    # Rebind the module-level tokenizer; without `global` this assignment
    # would create a local variable and the other endpoints would keep
    # serving the old model.
    global tokenizer
    try:
        text = load_sample_data(
            "data/hindi_wiki_corpus.txt",
            max_sentences=request.get("max_sentences", 10000),
        )
        tokenizer = await train_and_save_bpe(
            text, vocab_size=request.get("vocab_size", 10000), manager=manager
        )
        return {"message": "Training completed successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
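

# Local dev entry point (an assumption; a hosted Space typically starts
# uvicorn itself, so this is only a convenience for running the file directly).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)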