danhtran2mind committed on
Commit 9bd8a79 · verified · 1 Parent(s): 8a17ca8

Update gradio_app/model_handler.py

Files changed (1):
  1. gradio_app/model_handler.py +83 -64
gradio_app/model_handler.py CHANGED
@@ -1,65 +1,84 @@
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from peft import PeftModel
-import gc
-from config import logger, LORA_CONFIGS
-
-class ModelHandler:
-    def __init__(self):
-        self.model = None
-        self.tokenizer = None
-        self.current_model_id = None
-
-    def load_model(self, model_id, chatbot_state):
-        """Load the model, tokenizer, and apply LoRA adapter for the given model ID."""
-        try:
-            logger.info(f"Loading model: {model_id}")
-            print(f"Changing to model: {model_id}")
-            self.clear_model()
-
-            if model_id not in LORA_CONFIGS:
-                raise ValueError(f"Invalid model ID: {model_id}")
-
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            base_model_name = LORA_CONFIGS[model_id]["base_model"]
-            lora_adapter_name = LORA_CONFIGS[model_id]["lora_adapter"]
-
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                base_model_name,
-                trust_remote_code=True
-            )
-            self.tokenizer.use_default_system_prompt = False
-
-            if self.tokenizer.pad_token is None or self.tokenizer.pad_token == self.tokenizer.eos_token:
-                self.tokenizer.pad_token = self.tokenizer.unk_token or "<pad>"
-                logger.info(f"Set pad_token to {self.tokenizer.pad_token}")
-
-            self.model = AutoModelForCausalLM.from_pretrained(
-                base_model_name,
-                torch_dtype=torch.float16,
-                device_map=device,
-                trust_remote_code=True
-            )
-
-            self.model = PeftModel.from_pretrained(self.model, lora_adapter_name)
-            self.model.eval()
-            self.model.config.pad_token_id = self.tokenizer.pad_token_id
-
-            self.current_model_id = model_id
-            chatbot_state = []
-            return f"Successfully loaded model: {model_id} with LoRA adapter {lora_adapter_name}", chatbot_state
-        except Exception as e:
-            logger.error(f"Failed to load model or tokenizer: {str(e)}")
-            return f"Error: Failed to load model {model_id}: {str(e)}", chatbot_state
-
-    def clear_model(self):
-        """Clear the current model and tokenizer from memory."""
-        if self.model is not None:
-            print("Clearing previous model from RAM/VRAM...")
-            del self.model
-            del self.tokenizer
-            self.model = None
-            self.tokenizer = None
-            gc.collect()
-            torch.cuda.empty_cache() if torch.cuda.is_available() else None
+# import torch
+# from transformers import AutoModelForCausalLM, AutoTokenizer
+# from peft import PeftModel
+# import gc
+# from config import logger, LORA_CONFIGS
+
+import os
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+from huggingface_hub import login
+import gc
+from config import logger, LORA_CONFIGS
+
+# Check for Hugging Face API token
+if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
+    logger.error("Hugging Face API token is not set. Please set the HUGGINGFACEHUB_API_TOKEN environment variable.")
+    raise ValueError("Hugging Face API token is not set. Please set the HUGGINGFACEHUB_API_TOKEN environment variable.")
+
+# Set the Hugging Face API token
+os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+
+# Initialize API
+login(os.environ.get("HUGGINGFACEHUB_API_TOKEN"))
+
+class ModelHandler:
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
+        self.current_model_id = None
+
+    def load_model(self, model_id, chatbot_state):
+        """Load the model, tokenizer, and apply LoRA adapter for the given model ID."""
+        try:
+            logger.info(f"Loading model: {model_id}")
+            print(f"Changing to model: {model_id}")
+            self.clear_model()
+
+            if model_id not in LORA_CONFIGS:
+                raise ValueError(f"Invalid model ID: {model_id}")
+
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            base_model_name = LORA_CONFIGS[model_id]["base_model"]
+            lora_adapter_name = LORA_CONFIGS[model_id]["lora_adapter"]
+
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                base_model_name,
+                trust_remote_code=True
+            )
+            self.tokenizer.use_default_system_prompt = False
+
+            if self.tokenizer.pad_token is None or self.tokenizer.pad_token == self.tokenizer.eos_token:
+                self.tokenizer.pad_token = self.tokenizer.unk_token or "<pad>"
+                logger.info(f"Set pad_token to {self.tokenizer.pad_token}")
+
+            self.model = AutoModelForCausalLM.from_pretrained(
+                base_model_name,
+                torch_dtype=torch.float16,
+                device_map=device,
+                trust_remote_code=True
+            )
+
+            self.model = PeftModel.from_pretrained(self.model, lora_adapter_name)
+            self.model.eval()
+            self.model.config.pad_token_id = self.tokenizer.pad_token_id
+
+            self.current_model_id = model_id
+            chatbot_state = []
+            return f"Successfully loaded model: {model_id} with LoRA adapter {lora_adapter_name}", chatbot_state
+        except Exception as e:
+            logger.error(f"Failed to load model or tokenizer: {str(e)}")
+            return f"Error: Failed to load model {model_id}: {str(e)}", chatbot_state
+
+    def clear_model(self):
+        """Clear the current model and tokenizer from memory."""
+        if self.model is not None:
+            print("Clearing previous model from RAM/VRAM...")
+            del self.model
+            del self.tokenizer
+            self.model = None
+            self.tokenizer = None
+            gc.collect()
+            torch.cuda.empty_cache() if torch.cuda.is_available() else None
         print("Memory cleared successfully.")