gobeldan committed on
Commit c2df3e1 · verified · 1 Parent(s): 8b02fc8

Update app.py

Files changed (1)
  1. app.py +11 -2
app.py CHANGED
@@ -27,7 +27,16 @@ MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+# Pick attention backend based on device availability
+if torch.cuda.is_available():
+    device = "cuda"
+    attn_impl = "flash_attention_2"  # or "flash", depending on the library
+    torch_dtype = torch.bfloat16  # or torch.float16
+else:
+    device = "cpu"
+    attn_impl = "eager"
+    torch_dtype = torch.bfloat16  # or float32; bfloat16 is supported on CPUs with AVX512-BF16 or AMX (e.g., Intel Ice Lake / Sapphire Rapids, some newer AMD), but many ops may still fall back to float32.
 
 # model_id = "google/gemma-3-270m-it"
 model_id = "unsloth/gemma-3-270m-it"
@@ -36,7 +45,7 @@ model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype=torch.bfloat16,
-    attn_implementation="flash_attention_2",
+    attn_implementation=attn_impl,
     trust_remote_code=True,
 )
 model.config.sliding_window = 4096
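
Read end to end, the patched section amounts to the following device-aware loading pattern. This is a minimal sketch, assuming the surrounding imports and tokenizer setup from app.py (not shown in the diff); it also wires the computed torch_dtype into from_pretrained, whereas the commit itself still passes torch.bfloat16 directly.

    import os

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

    # Pick device, attention backend, and dtype based on what is available.
    if torch.cuda.is_available():
        device = "cuda"
        attn_impl = "flash_attention_2"  # needs the flash-attn package and a CUDA GPU
        torch_dtype = torch.bfloat16
    else:
        device = "cpu"
        attn_impl = "eager"              # safe fallback when flash-attn is unavailable
        torch_dtype = torch.bfloat16     # many CPU ops may still fall back to float32

    model_id = "unsloth/gemma-3-270m-it"
    tokenizer = AutoTokenizer.from_pretrained(model_id)  # assumed, not part of this diff
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch_dtype,         # the commit keeps torch.bfloat16 here
        attn_implementation=attn_impl,
        trust_remote_code=True,
    )
    model.config.sliding_window = 4096

Falling back to attn_implementation="eager" on CPU means the Space no longer requires the flash-attn kernels, which only run on CUDA GPUs, so the app can still start on CPU-only hardware.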