Chamin09 committed on
Commit 89efbe0 · verified · 1 Parent(s): e13d87a

Update models/llm_setup.py

Files changed (1)
  1. models/llm_setup.py +65 -64
models/llm_setup.py CHANGED
@@ -1,64 +1,65 @@
- from typing import Optional
- from llama_index.llms import HuggingFaceLLM
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-
- def setup_llm(model_name: str = "microsoft/phi-3-mini-4k-instruct",
-               device: str = None,
-               context_window: int = 4096,
-               max_new_tokens: int = 512) -> HuggingFaceLLM:
-     """
-     Set up the language model for the CSV chatbot.
-
-     Args:
-         model_name: Name of the Hugging Face model to use
-         device: Device to run the model on ('cuda', 'cpu', etc.)
-         context_window: Maximum context window size
-         max_new_tokens: Maximum number of new tokens to generate
-
-     Returns:
-         Configured LLM instance
-     """
-     # Determine device
-     if device is None:
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     # Configure quantization for memory efficiency
-     if device == "cuda":
-         quantization_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_compute_dtype=torch.float16
-         )
-     else:
-         quantization_config = None
-
-     # Configure tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(
-         model_name,
-         trust_remote_code=True
-     )
-
-     # Configure model with appropriate parameters for HF Spaces
-     model_kwargs = {
-         "trust_remote_code": True,
-         "torch_dtype": torch.float16,
-     }
-
-     if quantization_config:
-         model_kwargs["quantization_config"] = quantization_config
-
-     # Initialize LLM
-     llm = HuggingFaceLLM(
-         model_name=model_name,
-         tokenizer_name=model_name,
-         context_window=context_window,
-         max_new_tokens=max_new_tokens,
-         generate_kwargs={"temperature": 0.7, "top_p": 0.95},
-         device_map=device,
-         tokenizer_kwargs={"trust_remote_code": True},
-         model_kwargs=model_kwargs,
-         # Cache the model to avoid reloading
-         cache_folder="./model_cache"
-     )
-
-     return llm
 
 
+ from typing import Optional
+ #from llama_index.llms import HuggingFaceLLM
+ from llama_index.llms.huggingface import HuggingFaceLLM
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+ def setup_llm(model_name: str = "microsoft/phi-3-mini-4k-instruct",
+               device: str = None,
+               context_window: int = 4096,
+               max_new_tokens: int = 512) -> HuggingFaceLLM:
+     """
+     Set up the language model for the CSV chatbot.
+
+     Args:
+         model_name: Name of the Hugging Face model to use
+         device: Device to run the model on ('cuda', 'cpu', etc.)
+         context_window: Maximum context window size
+         max_new_tokens: Maximum number of new tokens to generate
+
+     Returns:
+         Configured LLM instance
+     """
+     # Determine device
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Configure quantization for memory efficiency
+     if device == "cuda":
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16
+         )
+     else:
+         quantization_config = None
+
+     # Configure tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         trust_remote_code=True
+     )
+
+     # Configure model with appropriate parameters for HF Spaces
+     model_kwargs = {
+         "trust_remote_code": True,
+         "torch_dtype": torch.float16,
+     }
+
+     if quantization_config:
+         model_kwargs["quantization_config"] = quantization_config
+
+     # Initialize LLM
+     llm = HuggingFaceLLM(
+         model_name=model_name,
+         tokenizer_name=model_name,
+         context_window=context_window,
+         max_new_tokens=max_new_tokens,
+         generate_kwargs={"temperature": 0.7, "top_p": 0.95},
+         device_map=device,
+         tokenizer_kwargs={"trust_remote_code": True},
+         model_kwargs=model_kwargs,
+         # Cache the model to avoid reloading
+         cache_folder="./model_cache"
+     )
+
+     return llm
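
Note on the changed import: in llama-index 0.10 and later the Hugging Face LLM integration lives in the separate llama-index-llms-huggingface package, which provides the new import path used above (the old top-level llama_index.llms import no longer resolves on recent releases). A minimal usage sketch under that assumption, and assuming the updated file is importable as models.llm_setup per the repo layout shown in this commit:

# Assumes: pip install llama-index-llms-huggingface (plus torch, transformers,
# bitsandbytes for the 4-bit path) and that models/llm_setup.py is on the path.
from models.llm_setup import setup_llm

# Build the LLM with the defaults from this commit:
# phi-3-mini-4k-instruct, auto-detected device, 4k context window.
llm = setup_llm()

# HuggingFaceLLM follows the standard LlamaIndex LLM interface,
# so a plain completion call returns a response with a .text field.
response = llm.complete("In one sentence, what does this module configure?")
print(response.text)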