larry1129 committed
Commit 7d93b52 · verified · 1 Parent(s): 08721e4

Update app.py

Files changed (1): app.py (+22 -8)
app.py CHANGED
@@ -1,24 +1,38 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
 import torch
 import os
 
 # Get the Hugging Face access token
 hf_token = os.getenv("HF_API_TOKEN")
 
-# Define the model name (replace with the name of the model you uploaded)
-model_name = "larry1129/WooWoof_AI"  # replace with your model name
+# Define the base model name
+base_model_name = "unsloth/meta-llama-3.1-8b-bnb-4bit"  # replace with your base model name
+
+# Define the adapter model name (assuming the adapter is in the same repo)
+adapter_model_name = "larry1129/WooWoof_AI"  # replace with your adapter model name
 
 # Load the tokenizer
-tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
+tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_auth_token=hf_token)
 
-# Load the model
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
+# Load the base model
+base_model = AutoModelForCausalLM.from_pretrained(
+    base_model_name,
     device_map="auto",
     torch_dtype=torch.float16,
     use_auth_token=hf_token,
-    trust_remote_code=True  # keep this argument if your model uses custom code
+    trust_remote_code=True
+)
+
+# Load the adapter and apply it to the base model
+model = PeftModel.from_pretrained(
+    base_model,
+    adapter_model_name,
+    device_map="auto",
+    torch_dtype=torch.float16,
+    use_auth_token=hf_token,
+    trust_remote_code=True
 )
 
 # Set the pad_token
@@ -51,7 +65,7 @@ def generate_prompt(instruction, input_text=""):
 def generate_response(instruction, input_text):
     prompt = generate_prompt(instruction, input_text)
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
+
     with torch.no_grad():
         outputs = model.generate(
             input_ids=inputs["input_ids"],
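
With this change the Space loads a 4-bit base checkpoint and attaches the WooWoof_AI adapter on top via PEFT. The standalone smoke test below is a minimal sketch of that load path, not part of the commit: it reuses the repo names and HF_API_TOKEN from the diff, picks an arbitrary test prompt, and swaps in the token= argument that current transformers/peft releases accept in place of the deprecated use_auth_token=.

# Illustrative smoke test for the base-model-plus-adapter load path above.
# Assumes a GPU runtime with bitsandbytes installed (the base checkpoint is
# bnb-4bit) and HF_API_TOKEN set in the environment.
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_token = os.getenv("HF_API_TOKEN")
base_model_name = "unsloth/meta-llama-3.1-8b-bnb-4bit"
adapter_model_name = "larry1129/WooWoof_AI"

tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    torch_dtype=torch.float16,  # applies to the non-quantized modules
    token=hf_token,
)
model = PeftModel.from_pretrained(base_model, adapter_model_name, token=hf_token)
model.eval()

# One short greedy generation confirms the adapter weights resolved and
# attached; the prompt here is arbitrary, not the app's prompt template.
inputs = tokenizer("Hello, who are you?", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The pad_token setup from app.py is omitted here since a single unbatched generate call needs no padding.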