import os
import re
import tempfile
from time import time
from typing import Optional

import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

os.environ["HF_HOME"] = "/app/hf_cache"
#os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"

# Register these types as safe so torch.load can deserialize a saved DynamicCache
# without raising errors when caches are saved and loaded below.
torch.serialization.add_safe_globals([DynamicCache, set])

# Minimal greedy, token-by-token generation loop
def generate(model, input_ids, past_key_values, max_new_tokens=50):
    """
    Generate text token by token with a pre-trained causal language model.

    Purpose: produce new text for the given input tokens without re-processing
        the full context on every step.
    Process: feed the model the input IDs together with the cached key/value
        states, then greedily append one token at a time up to max_new_tokens.
    Returns: only the newly generated tokens (the input is stripped off).
    """
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]  # length of the prompt, used to return only the new tokens
    input_ids = input_ids.to(device)  # move inputs to the same device as the model
    output_ids = input_ids.clone()    # running sequence, extended with each generated token
    next_token = input_ids            # tokens to feed on the next forward pass
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]                        # logits for the last position
            token = torch.argmax(logits, dim=-1, keepdim=True)   # greedy pick of the next token
            output_ids = torch.cat([output_ids, token], dim=-1)  # append the newly generated token
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    return output_ids[:, origin_len:]  # return just the newly generated part

def get_kv_cache(model, tokenizer, prompt):
    """
    This function creates a key-value cache for a given prompt.
        Purpose: To pre-compute and store the model's internal representations (key-value states) for a prompt
        Process: Encodes the prompt, runs it through the model, and captures the resulting cache
        Returns: The cache object and the original prompt length for future reference
    """
    # Encode prompt
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()  # it grows as text is generated
    
    # Run the model to populate the KV cache:
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]

def clean_up(cache, origin_len):
    """
    Return a copy of the cache truncated back to the original prompt length,
    discarding any key/value entries appended during a previous generation.
    """
    # Make a deep copy of the cache first
    new_cache = DynamicCache()
    for i in range(len(cache.key_cache)):
        new_cache.key_cache.append(cache.key_cache[i].clone())
        new_cache.value_cache.append(cache.value_cache[i].clone())

    # Remove any tokens appended after the original knowledge/prompt
    for i in range(len(new_cache.key_cache)):
        new_cache.key_cache[i] = new_cache.key_cache[i][:, :, :origin_len, :]
        new_cache.value_cache[i] = new_cache.value_cache[i][:, :, :origin_len, :]
    return new_cache
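
# A minimal usage sketch (not used by the API endpoints below) showing how the three
# helpers above fit together; the document text, question, and max_new_tokens value
# here are illustrative assumptions, not values taken from the running service.
def _example_kv_cache_roundtrip(model, tokenizer):
    doc_prompt = "<|system|>\nYou are a helpful assistant.\n<|user|>\nContext:\nExample document text.\nQuestion:"
    cache, origin_len = get_kv_cache(model, tokenizer, doc_prompt)  # pre-fill the cache once
    fresh_cache = clean_up(cache, origin_len)                       # trim back to the original prompt
    question = "<|user|>\nQuestion: What is this document about?\n<|assistant|>"
    question_ids = tokenizer(question, return_tensors="pt").input_ids
    new_ids = generate(model, question_ids, fresh_cache, max_new_tokens=50)
    return tokenizer.decode(new_ids[0], skip_special_tokens=True)
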
# Optional: force offline mode if the model files are already cached locally
#os.environ["TRANSFORMERS_OFFLINE"] = "1"
#os.environ["HF_HUB_OFFLINE"] = "1"

# Initialize model and tokenizer
def load_model_and_tokenizer():
    model_name = "Locutusque/TinyMistral-248M"
    # Alternative models:
    # "tiiuae/falcon-rw-1b"
    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    # "facebook/opt-125m"

    # Load tokenizer and model (without trust_remote_code)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if torch.cuda.is_available():
        # Load the model on GPU if CUDA is available
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"  # automatically map model layers to the GPU
        )
    else:
        # Load the model on CPU if no GPU is available
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # float32 for CPU compatibility
            low_cpu_mem_usage=True      # reduce peak memory usage on CPU
        )
    return model, tokenizer

# Create FastAPI app
app = FastAPI(title="DeepSeek QA with KV Cache API")


# Initialize the model and tokenizer at startup
model, tokenizer = load_model_and_tokenizer()

# Global in-memory store for KV caches, keyed by cache_id; each entry holds the cache,
# the original prompt length, and a short preview of the document text.
cache_store = {}


class QueryRequest(BaseModel):
    query: str
    max_new_tokens: Optional[int] = 150

def clean_response(response_text):
    """
    Clean up a model response by removing redundant chat tags, repeated lines,
    and formatting artifacts.
    """
    # First, try to extract just the answer content between assistant tags if present
    assistant_pattern = re.compile(r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)', re.DOTALL)
    matches = assistant_pattern.findall(response_text)
    
    if matches:
        # Return the first meaningful assistant response
        for match in matches:
            cleaned = match.strip()
            if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
                return cleaned
    
    # If no proper match found, try more aggressive cleaning
    # Remove all tag markers completely
    cleaned = re.sub(r'<\|.*?\|>', '', response_text)
    cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)
    
    # Remove duplicate lines (common in generated responses)
    lines = cleaned.strip().split('\n')
    unique_lines = []
    for line in lines:
        line = line.strip()
        if line and line not in unique_lines:
            unique_lines.append(line)
    
    result = '\n'.join(unique_lines)
    
    # Final cleanup - remove any trailing system/user markers
    result = re.sub(r'<\/?\|.*?\|>\s*$', '', result)
    
    return result.strip()
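
# Hedged illustration of clean_response on a typical raw generation; the sample string
# below is an assumption for demonstration, not actual model output.
def _example_clean_response():
    raw = "<|assistant|> Paris is the capital of France. <|user|> Thanks!"
    return clean_response(raw)  # expected: "Paris is the capital of France."
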
@app.post("/upload-document_to_create_KV_cache")
async def upload_document(file: UploadFile = File(...)):
    """Upload a document and create KV cache for it"""
    t1 = time()
    
    # Save the uploaded file temporarily
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
        temp_file_path = temp_file.name
        content = await file.read()
        temp_file.write(content)
    
    try:
        # Read the document
        with open(temp_file_path, "r", encoding="utf-8") as f:
            doc_text = f.read()
        
        # Create the system prompt with the document as context.
        # The string is left-aligned so no stray indentation ends up in the prompt tokens.
        system_prompt = f"""<|system|>
You are an assistant who provides concise, factual answers. Answer concisely and precisely.
<|user|>
Context:
{doc_text}
Question:""".strip()
        
        # Create KV cache
        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
        
        # Generate a unique ID for this document/cache
        cache_id = f"cache_{int(time())}"
        
        # Store the cache and origin_len
        cache_store[cache_id] = {
            "cache": cache,
            "origin_len": origin_len,
            "doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        }
        
        # Clean up the temporary file
        os.unlink(temp_file_path)
        
        t2 = time()
        
        return {
            "cache_id": cache_id,
            "message": "Document uploaded and cache created successfully",
            "doc_preview": cache_store[cache_id]["doc_preview"],
            "time_taken": f"{t2 - t1:.4f} seconds"
        }
    
    except Exception as e:
        # Clean up the temporary file in case of error
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")

@app.post("/generate_answer_from_cache/{cache_id}")
async def generate_answer(cache_id: str, request: QueryRequest):
    """Generate an answer to a question based on the uploaded document"""
    t1 = time()
    
    # Check if the document/cache exists
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    
    try:
        # Get a clean copy of the cache
        current_cache = clean_up(
            cache_store[cache_id]["cache"], 
            cache_store[cache_id]["origin_len"]
        )
        
        # Prepare the input with just the query; the document context is already in the cache.
        # The string is left-aligned so no stray indentation ends up in the prompt tokens.
        full_prompt = f"""<|user|>
Question: {request.query}
<|assistant|>""".strip()

        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
        
        # Generate response
        output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens)
        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        rep = clean_response(response)
        t2 = time()
        
        return {
            "query": request.query,
            "answer": rep,
            "time_taken": f"{t2 - t1:.4f} seconds"
        }
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")

@app.post("/save_cache/{cache_id}")
async def save_cache(cache_id: str):
    """Save the cache for a document"""
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    
    try:
        # Clean up the cache and save it
        cleaned_cache = clean_up(
            cache_store[cache_id]["cache"], 
            cache_store[cache_id]["origin_len"]
        )
        
        cache_path = f"{cache_id}_cache.pth"
        torch.save(cleaned_cache, cache_path)
        
        return {
            "message": f"Cache saved successfully as {cache_path}",
            "cache_path": cache_path
        }
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")

@app.post("/load_cache")
async def load_cache(file: UploadFile = File(...)):
    """Load a previously saved cache"""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file:
        temp_file_path = temp_file.name
        content = await file.read()
        temp_file.write(content)
    
    try:
        # Load the cache; DynamicCache was registered via add_safe_globals at the top of
        # this file, so torch.load can deserialize it.
        loaded_cache = torch.load(temp_file_path)
        
        # Generate a unique ID for this cache
        cache_id = f"loaded_cache_{int(time())}"
        
        # Store the cache (we don't have the original document text)
        cache_store[cache_id] = {
            "cache": loaded_cache,
            "origin_len": loaded_cache.key_cache[0].shape[-2],
            "doc_preview": "Loaded from cache file"
        }
        
        # Clean up the temporary file
        os.unlink(temp_file_path)
        
        return {
            "cache_id": cache_id,
            "message": "Cache loaded successfully"
        }
    
    except Exception as e:
        # Clean up the temporary file in case of error
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}")

@app.get("/list_of_caches")
async def list_documents():
    """List all uploaded documents/caches"""
    documents = {}
    for cache_id in cache_store:
        documents[cache_id] = {
            "doc_preview": cache_store[cache_id]["doc_preview"],
            "origin_len": cache_store[cache_id]["origin_len"]
        }
    
    return {"documents": documents}

@app.get("/", include_in_schema=False)
async def root():
    return RedirectResponse(url="/docs")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
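
# Hedged client sketch: example curl calls against the endpoints defined above, assuming the
# server is reachable at http://localhost:7860; "notes.txt" and the <cache_id> value are placeholders.
#
#   curl -X POST -F "file=@notes.txt" http://localhost:7860/upload-document_to_create_KV_cache
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"query": "What is the document about?", "max_new_tokens": 100}' \
#        http://localhost:7860/generate_answer_from_cache/<cache_id>
#   curl -X POST http://localhost:7860/save_cache/<cache_id>
#   curl -X POST -F "file=@<cache_id>_cache.pth" http://localhost:7860/load_cache
#   curl http://localhost:7860/list_of_caches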