import os
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, Body
from fastapi.responses import JSONResponse, RedirectResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache, StaticCache
from pydantic import BaseModel
from typing import Optional
import uvicorn
import tempfile
from time import time
from pyngrok import ngrok
os.environ["HF_HOME"] = "/app/hf_cache"
#os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
# Register these classes as safe globals so that PyTorch can serialize and deserialize
# saved cache objects without raising errors during cache saving/loading.
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])
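# With these registrations in place, a cache saved to disk can later be restored even under
# PyTorch's safe-loading defaults. Sketch (path is illustrative):
#   torch.save(cache, "ctx_cache.pth")
#   cache = torch.load("ctx_cache.pth")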
# Minimal generate function for token-by-token generation
def generate(model,
input_ids,
past_key_values,
max_new_tokens=50):
"""
This function performs token-by-token text generation using a pre-trained language model.
Purpose: To generate new text based on input tokens, without loading the full context repeatedly
Process: It takes a model, input IDs, and cached key-values, then generates new tokens one by one up to the specified maximum
Performance: Uses the cached key-values for efficiency and returns only the newly generated tokens
"""
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]  # Prompt length, so we can later return only the newly generated part
    input_ids = input_ids.to(device)  # Move the inputs to the same device as the model
    output_ids = input_ids.clone()  # Running buffer; newly generated tokens are appended to it
    next_token = input_ids  # Tokens fed to the model on the next iteration
with torch.no_grad():
for _ in range(max_new_tokens):
out = model(
input_ids=next_token,
past_key_values=past_key_values,
use_cache=True
)
            logits = out.logits[:, -1, :]  # Logits for the last position only
            token = torch.argmax(logits, dim=-1, keepdim=True)  # Greedy decoding: take the highest-probability token
output_ids = torch.cat([output_ids, token], dim=-1)#add the newly generated token
past_key_values = out.past_key_values
next_token = token.to(device)
if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
break
return output_ids[:, origin_len:] # Return just the newly generated part
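# Minimal usage sketch for `generate` on its own, with a fresh (empty) cache. It assumes the
# module-level `model`/`tokenizer` defined further down; kept as a comment so it does not run on import:
#   ids = tokenizer("Hello", return_tensors="pt").input_ids
#   new_ids = generate(model, ids, DynamicCache(), max_new_tokens=20)
#   print(tokenizer.decode(new_ids[0], skip_special_tokens=True))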
def get_kv_cache(model, tokenizer, prompt):
"""
This function creates a key-value cache for a given prompt.
Purpose: To pre-compute and store the model's internal representations (key-value states) for a prompt
Process: Encodes the prompt, runs it through the model, and captures the resulting cache
Returns: The cache object and the original prompt length for future reference
"""
# Encode prompt
device = model.model.embed_tokens.weight.device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
cache = DynamicCache() # it grows as text is generated
# Run the model to populate the KV cache:
with torch.no_grad():
_ = model(
input_ids=input_ids,
past_key_values=cache,
use_cache=True
)
return cache, input_ids.shape[-1]
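# Note: for a DynamicCache, key_cache[i] / value_cache[i] have shape
# [batch, num_heads, seq_len, head_dim], so the prompt length returned here equals
# cache.key_cache[0].shape[-2] right after this call (the same fact /load_cache relies on below).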
def clean_up(cache, origin_len):
# Make a deep copy of the cache first
new_cache = DynamicCache()
for i in range(len(cache.key_cache)):
new_cache.key_cache.append(cache.key_cache[i].clone())
new_cache.value_cache.append(cache.value_cache[i].clone())
# Remove any tokens appended to the original knowledge
for i in range(len(new_cache.key_cache)):
new_cache.key_cache[i] = new_cache.key_cache[i][:, :, :origin_len, :]
new_cache.value_cache[i] = new_cache.value_cache[i][:, :, :origin_len, :]
return new_cache
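# Usage sketch of the cache-reuse pattern these helpers implement (a sketch, not executed here):
#   ctx_cache, n = get_kv_cache(model, tokenizer, "<|system|> ...document context...")  # build once
#   for q in ["Question 1?", "Question 2?"]:
#       fresh = clean_up(ctx_cache, n)  # truncate back to the original context tokens
#       q_ids = tokenizer(q, return_tensors="pt").input_ids
#       answer_ids = generate(model, q_ids, fresh, max_new_tokens=100)
#       print(tokenizer.decode(answer_ids[0], skip_special_tokens=True))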
#os.environ["TRANSFORMERS_OFFLINE"] = "1"
#os.environ["HF_HUB_OFFLINE"] = "1"
# Initialize the model and tokenizer
def load_model_and_tokenizer():
model_name = "Locutusque/TinyMistral-248M"
#"tiiuae/falcon-rw-1b"
#"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
#"facebook/opt-125m"
    # Load the tokenizer and model (cached under HF_HOME; no trust_remote_code needed)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
if torch.cuda.is_available():
# Load model on GPU if CUDA is available
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto" # Automatically map model layers to GPU
)
else:
# Load model on CPU if no GPU is available
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32, # Use float32 for compatibility with CPU
low_cpu_mem_usage=True # Reduce memory usage on CPU
)
return model, tokenizer
# Create FastAPI app
app = FastAPI(title="DeepSeek QA with KV Cache API")
# Initialize model and tokenizer at startup
model, tokenizer = load_model_and_tokenizer()
# Global in-memory store mapping cache_id -> {"cache", "origin_len", "doc_preview"}
cache_store = {}
class QueryRequest(BaseModel):
query: str
max_new_tokens: Optional[int] = 150
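# Example request body for /generate_answer_from_cache/{cache_id} (values are illustrative):
#   {"query": "What does the document say about X?", "max_new_tokens": 100}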
def clean_response(response_text):
"""
Clean up model response by removing redundant tags, repetitions, and formatting issues.
"""
# First, try to extract just the answer content between tags if they exist
import re
# Try to extract content between assistant tags if present
assistant_pattern = re.compile(r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)', re.DOTALL)
matches = assistant_pattern.findall(response_text)
if matches:
# Return the first meaningful assistant response
for match in matches:
cleaned = match.strip()
if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
return cleaned
# If no proper match found, try more aggressive cleaning
# Remove all tag markers completely
cleaned = re.sub(r'<\|.*?\|>', '', response_text)
cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)
# Remove duplicate lines (common in generated responses)
lines = cleaned.strip().split('\n')
unique_lines = []
for line in lines:
line = line.strip()
if line and line not in unique_lines:
unique_lines.append(line)
result = '\n'.join(unique_lines)
# Final cleanup - remove any trailing system/user markers
result = re.sub(r'<\/?\|.*?\|>\s*$', '', result)
return result.strip()
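# Illustrative input/output for clean_response:
#   clean_response("<|assistant|> Paris is the capital of France. <|user|>")
#   -> "Paris is the capital of France."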
@app.post("/upload-document_to_create_KV_cache")
async def upload_document(file: UploadFile = File(...)):
"""Upload a document and create KV cache for it"""
t1 = time()
# Save the uploaded file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
temp_file_path = temp_file.name
content = await file.read()
temp_file.write(content)
try:
# Read the document
with open(temp_file_path, "r", encoding="utf-8") as f:
doc_text = f.read()
# Create system prompt with document context
system_prompt = f"""
<|system|>
You are an assistant who provides concise, factual answers. Answer precisely.
<|user|>
Context:
{doc_text}
Question:
""".strip()
# Create KV cache
cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
# Generate a unique ID for this document/cache
cache_id = f"cache_{int(time())}"
# Store the cache and origin_len
cache_store[cache_id] = {
"cache": cache,
"origin_len": origin_len,
"doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
}
# Clean up the temporary file
os.unlink(temp_file_path)
t2 = time()
return {
"cache_id": cache_id,
"message": "Document uploaded and cache created successfully",
"doc_preview": cache_store[cache_id]["doc_preview"],
"time_taken": f"{t2 - t1:.4f} seconds"
}
except Exception as e:
# Clean up the temporary file in case of error
if os.path.exists(temp_file_path):
os.unlink(temp_file_path)
raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
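# Client-side sketch for this endpoint (assumes the `requests` package is installed and the API
# is running locally on port 7860 as configured at the bottom of this file; filename is illustrative):
#   import requests
#   r = requests.post(
#       "http://localhost:7860/upload-document_to_create_KV_cache",
#       files={"file": open("my_doc.txt", "rb")},
#   )
#   cache_id = r.json()["cache_id"]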
@app.post("/generate_answer_from_cache/{cache_id}")
async def generate_answer(cache_id: str, request: QueryRequest):
"""Generate an answer to a question based on the uploaded document"""
t1 = time()
# Check if the document/cache exists
if cache_id not in cache_store:
raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
try:
# Get a clean copy of the cache
current_cache = clean_up(
cache_store[cache_id]["cache"],
cache_store[cache_id]["origin_len"]
)
# Prepare input with just the query
full_prompt = f"""
<|user|>
Question: {request.query}
<|assistant|>
""".strip()
input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
# Generate response
output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
rep = clean_response(response)
t2 = time()
return {
"query": request.query,
"answer": rep,
"time_taken": f"{t2 - t1:.4f} seconds"
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
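# Client-side sketch (same assumptions as above: `requests` installed, server on port 7860):
#   r = requests.post(
#       f"http://localhost:7860/generate_answer_from_cache/{cache_id}",
#       json={"query": "What is the document about?", "max_new_tokens": 100},
#   )
#   print(r.json()["answer"])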
@app.post("/save_cache/{cache_id}")
async def save_cache(cache_id: str):
"""Save the cache for a document"""
if cache_id not in cache_store:
raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
try:
# Clean up the cache and save it
cleaned_cache = clean_up(
cache_store[cache_id]["cache"],
cache_store[cache_id]["origin_len"]
)
cache_path = f"{cache_id}_cache.pth"
torch.save(cleaned_cache, cache_path)
return {
"message": f"Cache saved successfully as {cache_path}",
"cache_path": cache_path
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")
@app.post("/load_cache")
async def load_cache(file: UploadFile = File(...)):
"""Load a previously saved cache"""
with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file:
temp_file_path = temp_file.name
content = await file.read()
temp_file.write(content)
try:
# Load the cache
loaded_cache = torch.load(temp_file_path)
# Generate a unique ID for this cache
cache_id = f"loaded_cache_{int(time())}"
# Store the cache (we don't have the original document text)
cache_store[cache_id] = {
"cache": loaded_cache,
"origin_len": loaded_cache.key_cache[0].shape[-2],
"doc_preview": "Loaded from cache file"
}
# Clean up the temporary file
os.unlink(temp_file_path)
return {
"cache_id": cache_id,
"message": "Cache loaded successfully"
}
except Exception as e:
# Clean up the temporary file in case of error
if os.path.exists(temp_file_path):
os.unlink(temp_file_path)
raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}")
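# Round-trip sketch for persisting a cache (same assumptions as above, plus that client and server
# share a filesystem, since /save_cache writes the .pth file on the server side):
#   requests.post(f"http://localhost:7860/save_cache/{cache_id}")  # writes <cache_id>_cache.pth
#   r = requests.post(
#       "http://localhost:7860/load_cache",
#       files={"file": open(f"{cache_id}_cache.pth", "rb")},
#   )
#   new_cache_id = r.json()["cache_id"]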
@app.get("/list_of_caches")
async def list_documents():
"""List all uploaded documents/caches"""
documents = {}
for cache_id in cache_store:
documents[cache_id] = {
"doc_preview": cache_store[cache_id]["doc_preview"],
"origin_len": cache_store[cache_id]["origin_len"]
}
return {"documents": documents}
@app.get("/", include_in_schema=False)
async def root():
return RedirectResponse(url="/docs")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)