import os
import re
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, Body
from fastapi.responses import JSONResponse, RedirectResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache, StaticCache
from pydantic import BaseModel
from typing import Optional
import uvicorn
import tempfile
from time import time
from pyngrok import ngrok
os.environ["HF_HOME"] = "/app/hf_cache"
#os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
# Add necessary serialization safety
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])
# These registrations let PyTorch serialize and deserialize these objects without
# raising errors, ensuring compatibility during cache saving/loading.
# Minimal generate function for token-by-token generation
def generate(model,
             input_ids,
             past_key_values,
             max_new_tokens=50):
    """
    Performs token-by-token text generation using a pre-trained language model.
    Purpose: generate new text from the input tokens without re-processing the full context on every step.
    Process: takes a model, input IDs, and cached key-values, then generates new tokens one by one up to the specified maximum.
    Performance: reuses the cached key-values for efficiency and returns only the newly generated tokens.
    """
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]  # length of the input sequence before generation; used to return only the new tokens
    input_ids = input_ids.to(device)  # move the input to the same device as the model
    output_ids = input_ids.clone()    # extended with newly generated tokens during the loop
    next_token = input_ids            # the token(s) to process in the next iteration
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]  # logits for the last position
            token = torch.argmax(logits, dim=-1, keepdim=True)  # greedy pick: highest-probability next token
            output_ids = torch.cat([output_ids, token], dim=-1)  # append the newly generated token
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    return output_ids[:, origin_len:]  # return just the newly generated part
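# Usage sketch (illustrative only; `question_ids` and `prefilled_cache` are
# placeholder names): once a cache has been prefilled with the document prompt,
# only the new question tokens need to pass through the model, e.g.
#   new_ids = generate(model, question_ids, prefilled_cache, max_new_tokens=50)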
def get_kv_cache(model, tokenizer, prompt):
    """
    This function creates a key-value cache for a given prompt.
    Purpose: To pre-compute and store the model's internal representations (key-value states) for a prompt
    Process: Encodes the prompt, runs it through the model, and captures the resulting cache
    Returns: The cache object and the original prompt length for future reference
    """
    # Encode prompt
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()  # it grows as text is generated
    # Run the model to populate the KV cache:
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]
def clean_up(cache, origin_len):
    # Make a deep copy of the cache first
    new_cache = DynamicCache()
    for i in range(len(cache.key_cache)):
        new_cache.key_cache.append(cache.key_cache[i].clone())
        new_cache.value_cache.append(cache.value_cache[i].clone())
    # Remove any tokens appended after the original knowledge prompt
    for i in range(len(new_cache.key_cache)):
        new_cache.key_cache[i] = new_cache.key_cache[i][:, :, :origin_len, :]
        new_cache.value_cache[i] = new_cache.value_cache[i][:, :, :origin_len, :]
    return new_cache
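# Shape note: for a Hugging Face DynamicCache, each key_cache/value_cache entry
# is a tensor of shape [batch, num_heads, seq_len, head_dim], so slicing dim 2
# to :origin_len drops tokens appended by a previous generate() call while
# keeping the cached document context intact.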
#os.environ["TRANSFORMERS_OFFLINE"] = "1"
#os.environ["HF_HUB_OFFLINE"] = "1"
# Path to your local model
# Initialize model and tokenizer
def load_model_and_tokenizer():
model_name = "Locutusque/TinyMistral-248M"
#"tiiuae/falcon-rw-1b"
#"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
#"facebook/opt-125m"
# Load tokenizer and model from disk (without trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(model_name )
if torch.cuda.is_available():
# Load model on GPU if CUDA is available
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto" # Automatically map model layers to GPU
)
else:
# Load model on CPU if no GPU is available
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32, # Use float32 for compatibility with CPU
low_cpu_mem_usage=True # Reduce memory usage on CPU
)
return model, tokenizer
# Create FastAPI app
app = FastAPI(title="DeepSeek QA with KV Cache API")
# Initialize model and tokenizer at startup
model, tokenizer = load_model_and_tokenizer()
# Global store mapping cache_id -> {"cache", "origin_len", "doc_preview"}
cache_store = {}
class QueryRequest(BaseModel):
    query: str
    max_new_tokens: Optional[int] = 150
def clean_response(response_text):
    """
    Clean up a model response by removing redundant tags, repetitions, and formatting issues.
    """
    # First, try to extract just the answer content between assistant tags if they exist
    assistant_pattern = re.compile(r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)', re.DOTALL)
    matches = assistant_pattern.findall(response_text)
    if matches:
        # Return the first meaningful assistant response
        for match in matches:
            cleaned = match.strip()
            if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
                return cleaned
    # If no proper match was found, fall back to more aggressive cleaning:
    # remove all tag markers completely
    cleaned = re.sub(r'<\|.*?\|>', '', response_text)
    cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)
    # Remove duplicate lines (common in generated responses)
    lines = cleaned.strip().split('\n')
    unique_lines = []
    for line in lines:
        line = line.strip()
        if line and line not in unique_lines:
            unique_lines.append(line)
    result = '\n'.join(unique_lines)
    # Final cleanup: remove any trailing system/user markers
    result = re.sub(r'<\/?\|.*?\|>\s*$', '', result)
    return result.strip()
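# Illustrative example (hypothetical model output):
#   clean_response("<|assistant|> Paris is the capital of France. <|user|>")
#   returns "Paris is the capital of France."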
@app.post("/upload-document_to_create_KV_cache")
async def upload_document(file: UploadFile = File(...)):
"""Upload a document and create KV cache for it"""
t1 = time()
# Save the uploaded file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
temp_file_path = temp_file.name
content = await file.read()
temp_file.write(content)
try:
# Read the document
with open(temp_file_path, "r", encoding="utf-8") as f:
doc_text = f.read()
# Create system prompt with document context
system_prompt = f"""
<|system|>
Answer concisely and precisely, You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()
# Create KV cache
cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
# Generate a unique ID for this document/cache
cache_id = f"cache_{int(time())}"
# Store the cache and origin_len
cache_store[cache_id] = {
"cache": cache,
"origin_len": origin_len,
"doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
}
# Clean up the temporary file
os.unlink(temp_file_path)
t2 = time()
return {
"cache_id": cache_id,
"message": "Document uploaded and cache created successfully",
"doc_preview": cache_store[cache_id]["doc_preview"],
"time_taken": f"{t2 - t1:.4f} seconds"
}
except Exception as e:
# Clean up the temporary file in case of error
if os.path.exists(temp_file_path):
os.unlink(temp_file_path)
raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
@app.post("/generate_answer_from_cache/{cache_id}")
async def generate_answer(cache_id: str, request: QueryRequest):
"""Generate an answer to a question based on the uploaded document"""
t1 = time()
# Check if the document/cache exists
if cache_id not in cache_store:
raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
try:
# Get a clean copy of the cache
current_cache = clean_up(
cache_store[cache_id]["cache"],
cache_store[cache_id]["origin_len"]
)
# Prepare input with just the query
full_prompt = f"""
<|user|>
Question: {request.query}
<|assistant|>
""".strip()
input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
# Generate response
output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
rep = clean_response(response)
t2 = time()
return {
"query": request.query,
"answer": rep,
"time_taken": f"{t2 - t1:.4f} seconds"
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
@app.post("/save_cache/{cache_id}")
async def save_cache(cache_id: str):
"""Save the cache for a document"""
if cache_id not in cache_store:
raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
try:
# Clean up the cache and save it
cleaned_cache = clean_up(
cache_store[cache_id]["cache"],
cache_store[cache_id]["origin_len"]
)
cache_path = f"{cache_id}_cache.pth"
torch.save(cleaned_cache, cache_path)
return {
"message": f"Cache saved successfully as {cache_path}",
"cache_path": cache_path
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")
@app.post("/load_cache")
async def load_cache(file: UploadFile = File(...)):
"""Load a previously saved cache"""
with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file:
temp_file_path = temp_file.name
content = await file.read()
temp_file.write(content)
try:
# Load the cache
loaded_cache = torch.load(temp_file_path)
# Generate a unique ID for this cache
cache_id = f"loaded_cache_{int(time())}"
# Store the cache (we don't have the original document text)
cache_store[cache_id] = {
"cache": loaded_cache,
"origin_len": loaded_cache.key_cache[0].shape[-2],
"doc_preview": "Loaded from cache file"
}
# Clean up the temporary file
os.unlink(temp_file_path)
return {
"cache_id": cache_id,
"message": "Cache loaded successfully"
}
except Exception as e:
# Clean up the temporary file in case of error
if os.path.exists(temp_file_path):
os.unlink(temp_file_path)
raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}")
@app.get("/list_of_caches")
async def list_documents():
"""List all uploaded documents/caches"""
documents = {}
for cache_id in cache_store:
documents[cache_id] = {
"doc_preview": cache_store[cache_id]["doc_preview"],
"origin_len": cache_store[cache_id]["origin_len"]
}
return {"documents": documents}
@app.get("/", include_in_schema=False)
async def root():
return RedirectResponse(url="/docs")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860) |