Manju017 committed
Commit b0ea0b7 · verified · 1 Parent(s): bce1941

Fix model loading and device map inference

Files changed (1): app.py +8 -9
app.py CHANGED
@@ -8,15 +8,14 @@ model_name = "ai4bharat/Airavata"
 # Load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# Automatically determine the device map
-device_map = infer_auto_device_map(model_name)
-
-# Load the model with the device map
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    device_map=device_map,
-    load_in_8bit=True  # Use 8-bit precision for reduced memory usage
-)
+# Load the model first
+model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True)
+
+# Now infer the device map
+device_map = infer_auto_device_map(model)
+
+# Move model to the appropriate device based on device_map
+model.to(device_map)
 
 # Define the inference function
 def generate_text(prompt):
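
For reference, a minimal sketch of the more common one-step pattern, assuming the accelerate and bitsandbytes packages are installed alongside transformers: passing device_map="auto" to from_pretrained lets the library infer the device map and dispatch the weights internally. This avoids the separate infer_auto_device_map / model.to() step; note that nn.Module.to() expects a device or dtype, not the module-name-to-device dict that infer_auto_device_map returns.

    # Sketch only: one-step loading with an automatically inferred device map.
    # Assumes `accelerate` and `bitsandbytes` are installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "ai4bharat/Airavata"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",   # infer and apply the device map in one step
        load_in_8bit=True,   # 8-bit precision for reduced memory usage
    )

With device_map="auto", the weights are already placed across the available devices as they load, so no subsequent .to() call is needed; bitsandbytes 8-bit models generally cannot be moved after loading anyway.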