betterdataai committed 39860ea · verified · 1 parent: d31e35c

Upload 2 files

Files changed (2):
  1. handler.py +86 -0
  2. requirements.txt +4 -3
handler.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ import json
+ import torch
+ from peft import PeftModel
+ from transformers import (
+     AutoTokenizer,
+     GenerationConfig,
+     LlamaForCausalLM,
+ )
+
+ class EndpointHandler:
+     def __init__(self, model_dir: str = ".", **kwargs):
+         """
+         Runs once when the Endpoint first starts.
+         - model_dir is the local directory of *this* repository,
+           which contains your LoRA adapter weights (e.g. adapter_model.safetensors).
+         """
+
+         # 1) Base model from Hugging Face.
+         # Use the EXACT base you trained on, or it won't match your LoRA.
+         self.base_model_id = "unsloth/Llama-3.2-3B-Instruct"
+
+         # If your base model is gated/private, you'll need a token:
+         # hf_token = os.getenv("HF_TOKEN", None)
+
+         # 2) Load the tokenizer (AutoTokenizer: Llama 3.x ships a fast tokenizer, not SentencePiece)
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.base_model_id,
+             trust_remote_code=True,
+             # token=hf_token,  # if needed
+         )
+
+         # 3) Load the base model
+         self.base_model = LlamaForCausalLM.from_pretrained(
+             self.base_model_id,
+             device_map="auto",          # or "cuda:0"
+             torch_dtype=torch.float16,  # or torch.bfloat16
+             trust_remote_code=True,
+             # token=hf_token,  # if needed
+         )
+
+         # 4) Attach your LoRA adapter to the base model
+         self.model = PeftModel.from_pretrained(
+             self.base_model,
+             model_dir,  # the local directory of this repo
+             torch_dtype=torch.float16,
+         ).eval()
+
+     def __call__(self, data):
+         """
+         Called for every request to the endpoint.
+         `data` is a dictionary (or JSON string) containing the user inputs.
+         Returns a dictionary (serialized as JSON by the endpoint).
+         """
+         # If data arrives as a JSON string, parse it first:
+         if isinstance(data, str):
+             data = json.loads(data)
+
+         # Extract the user prompt from the request payload
+         prompt = data.get("inputs", "")
+         if not isinstance(prompt, str):
+             raise ValueError("`inputs` must be a string.")
+
+         # Optionally extract generation parameters (max_new_tokens, temperature, ...);
+         # fall back to defaults if none are provided:
+         gen_params = data.get("parameters", {})
+         generation_config = GenerationConfig(
+             max_new_tokens=gen_params.get("max_new_tokens", 128),
+             temperature=gen_params.get("temperature", 0.7),
+             top_p=gen_params.get("top_p", 0.9),
+             do_sample=gen_params.get("do_sample", True),
+             # any other GenerationConfig fields can be forwarded the same way
+         )
+
+         # Tokenize the prompt and move it to the model's device
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+         # Generate text
+         with torch.no_grad():
+             output_ids = self.model.generate(**inputs, generation_config=generation_config)
+
+         # Decode the output (prompt + completion)
+         output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+         # Return the generated text in a JSON-friendly format
+         return {"generated_text": output_text}
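
For reference, the handler can be exercised locally before deploying. A minimal sketch, assuming this repo (with the adapter weights) is the current working directory and a GPU is available; the prompt and parameters are illustrative:

# Hypothetical local smoke test for handler.py above.
from handler import EndpointHandler

handler = EndpointHandler(model_dir=".")
response = handler({
    "inputs": "Explain LoRA fine-tuning in one sentence.",
    "parameters": {"max_new_tokens": 64, "temperature": 0.2},
})
print(response["generated_text"])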
requirements.txt CHANGED
@@ -1,8 +1,9 @@
  unsloth
- transformers
  pandas
  datasets
  trl
- torch
- accelerate
  scipy
+ transformers>=4.30.0
+ peft>=0.4.0
+ accelerate>=0.20.0
+ torch>=2.0
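
Once deployed, the Endpoint accepts the same {"inputs": ..., "parameters": ...} payload over HTTPS. A sketch using the requests library; the endpoint URL below is a placeholder for your own deployment, and HF_TOKEN is assumed to be set in the environment:

import os
import requests

# Placeholder URL; substitute the URL shown on your Endpoint's overview page.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

payload = {
    "inputs": "Explain LoRA fine-tuning in one sentence.",
    "parameters": {"max_new_tokens": 64, "temperature": 0.2},
}
resp = requests.post(ENDPOINT_URL, headers=headers, json=payload)
resp.raise_for_status()
print(resp.json()["generated_text"])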