Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						c2d0dc7
	
1
								Parent(s):
							
							59e3ffd
								
fixing app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -2,27 +2,33 @@ import os | |
| 2 | 
             
            import torch
         | 
| 3 | 
             
            from fastapi import FastAPI, HTTPException
         | 
| 4 | 
             
            from pydantic import BaseModel
         | 
| 5 | 
            -
            from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
         | 
| 6 |  | 
| 7 | 
            -
            # Set  | 
| 8 | 
             
            os.environ["HF_HOME"] = "/tmp/huggingface"
         | 
| 9 | 
             
            os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
         | 
| 10 |  | 
| 11 | 
             
            # Model setup
         | 
| 12 | 
             
            MODEL_NAME = "deepseek-ai/deepseek-llm-7b-base"
         | 
| 13 | 
             
            DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 14 | 
            -
            DTYPE = torch.float16 if DEVICE == "cuda" else torch.bfloat16
         | 
| 15 |  | 
| 16 | 
            -
            # Load model  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 17 | 
             
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
         | 
| 18 | 
             
            model = AutoModelForCausalLM.from_pretrained(
         | 
| 19 | 
            -
                MODEL_NAME, | 
|  | |
|  | |
|  | |
| 20 | 
             
            )
         | 
| 21 |  | 
| 22 | 
            -
            #  | 
| 23 | 
            -
             | 
| 24 | 
            -
            generation_config.pad_token_id = generation_config.eos_token_id
         | 
| 25 | 
            -
            generation_config.use_cache = True  # Speed up decoding
         | 
| 26 |  | 
| 27 | 
             
            # FastAPI app
         | 
| 28 | 
             
            app = FastAPI()
         | 
| @@ -30,28 +36,26 @@ app = FastAPI() | |
| 30 | 
             
            # Request payload
         | 
| 31 | 
             
            class TextGenerationRequest(BaseModel):
         | 
| 32 | 
             
                prompt: str
         | 
| 33 | 
            -
                max_tokens: int = 512  # Default to 512 | 
| 34 |  | 
| 35 | 
             
            @app.post("/generate")
         | 
| 36 | 
             
            async def generate_text(request: TextGenerationRequest):
         | 
| 37 | 
             
                try:
         | 
| 38 | 
            -
                    # Tokenize input and move tensors to the correct device
         | 
| 39 | 
             
                    inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
         | 
| 40 |  | 
| 41 | 
            -
                    # Use no_grad() for faster inference
         | 
| 42 | 
             
                    with torch.no_grad():
         | 
| 43 | 
             
                        outputs = model.generate(
         | 
| 44 | 
             
                            **inputs,
         | 
| 45 | 
             
                            max_new_tokens=request.max_tokens,
         | 
| 46 | 
            -
                            do_sample=True, | 
| 47 | 
            -
                            temperature=0.7, | 
| 48 | 
            -
                            top_k=50, | 
| 49 | 
            -
                            top_p=0.9, | 
| 50 | 
            -
                            repetition_penalty=1. | 
|  | |
| 51 | 
             
                        )
         | 
| 52 |  | 
| 53 | 
            -
                     | 
| 54 | 
            -
                    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
         | 
| 55 | 
             
                    return {"generated_text": result}
         | 
| 56 |  | 
| 57 | 
             
                except Exception as e:
         | 
|  | |
| 2 | 
             
            import torch
         | 
| 3 | 
             
            from fastapi import FastAPI, HTTPException
         | 
| 4 | 
             
            from pydantic import BaseModel
         | 
| 5 | 
            +
            from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
         | 
| 6 |  | 
| 7 | 
            +
            # Set cache directory
         | 
| 8 | 
             
            os.environ["HF_HOME"] = "/tmp/huggingface"
         | 
| 9 | 
             
            os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
         | 
| 10 |  | 
| 11 | 
             
            # Model setup
         | 
| 12 | 
             
            MODEL_NAME = "deepseek-ai/deepseek-llm-7b-base"
         | 
| 13 | 
             
            DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
         | 
|  | |
| 14 |  | 
| 15 | 
            +
            # Load 4-bit quantized model (for speed & efficiency)
         | 
| 16 | 
            +
            bnb_config = BitsAndBytesConfig(
         | 
| 17 | 
            +
                load_in_4bit=True,  # Enable 4-bit inference
         | 
| 18 | 
            +
                bnb_4bit_compute_dtype=torch.float16,
         | 
| 19 | 
            +
                bnb_4bit_use_double_quant=True,
         | 
| 20 | 
            +
            )
         | 
| 21 | 
            +
             | 
| 22 | 
             
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
         | 
| 23 | 
             
            model = AutoModelForCausalLM.from_pretrained(
         | 
| 24 | 
            +
                MODEL_NAME,
         | 
| 25 | 
            +
                quantization_config=bnb_config,
         | 
| 26 | 
            +
                device_map="auto",
         | 
| 27 | 
            +
                attn_implementation="flash_attention_2"  # Enables Flash Attention
         | 
| 28 | 
             
            )
         | 
| 29 |  | 
| 30 | 
            +
            # Compile for even faster inference (PyTorch 2.0+)
         | 
| 31 | 
            +
            model = torch.compile(model)
         | 
|  | |
|  | |
| 32 |  | 
| 33 | 
             
            # FastAPI app
         | 
| 34 | 
             
            app = FastAPI()
         | 
|  | |
| 36 | 
             
            # Request payload
         | 
| 37 | 
             
            class TextGenerationRequest(BaseModel):
         | 
| 38 | 
             
                prompt: str
         | 
| 39 | 
            +
                max_tokens: int = 512  # Default to 512
         | 
| 40 |  | 
| 41 | 
             
            @app.post("/generate")
         | 
| 42 | 
             
            async def generate_text(request: TextGenerationRequest):
         | 
| 43 | 
             
                try:
         | 
|  | |
| 44 | 
             
                    inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
         | 
| 45 |  | 
|  | |
| 46 | 
             
                    with torch.no_grad():
         | 
| 47 | 
             
                        outputs = model.generate(
         | 
| 48 | 
             
                            **inputs,
         | 
| 49 | 
             
                            max_new_tokens=request.max_tokens,
         | 
| 50 | 
            +
                            do_sample=True,
         | 
| 51 | 
            +
                            temperature=0.7,
         | 
| 52 | 
            +
                            top_k=50,
         | 
| 53 | 
            +
                            top_p=0.9,
         | 
| 54 | 
            +
                            repetition_penalty=1.05,
         | 
| 55 | 
            +
                            use_cache=True,
         | 
| 56 | 
             
                        )
         | 
| 57 |  | 
| 58 | 
            +
                    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
         | 
|  | |
| 59 | 
             
                    return {"generated_text": result}
         | 
| 60 |  | 
| 61 | 
             
                except Exception as e:
         |