cheberle commited on
Commit
553c8c4
·
1 Parent(s): c60d44f
Files changed (2) hide show
  1. app.py +27 -39
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,55 +1,43 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
4
 
5
- MODEL_NAME = "cheberle/autotrain-35swc-b4r9z"
 
 
6
 
7
- # ---------------------------------------------------------------------------
8
- # 1) Load the tokenizer and model for sequence classification
9
- # ---------------------------------------------------------------------------
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
11
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, trust_remote_code=True)
12
 
13
- # Create a pipeline for text classification
14
- classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
 
 
 
 
15
 
16
- # ---------------------------------------------------------------------------
17
- # 2) Define inference function
18
- # ---------------------------------------------------------------------------
19
  def classify_text(text):
20
  """
21
- Return the classification results in the format:
22
- [
23
- {
24
- 'label': 'POSITIVE',
25
- 'score': 0.98
26
- }
27
- ]
28
  """
29
- results = classifier(text)
30
- return results
 
 
 
 
31
 
32
- # ---------------------------------------------------------------------------
33
- # 3) Build the Gradio UI
34
- # ---------------------------------------------------------------------------
35
  with gr.Blocks() as demo:
36
- gr.Markdown("<h3>Text Classification Demo</h3>")
37
-
38
- with gr.Row():
39
- input_text = gr.Textbox(
40
- lines=3,
41
- label="Enter text to classify",
42
- placeholder="Type something..."
43
- )
44
- output = gr.JSON(label="Classification Output")
45
 
46
  classify_btn = gr.Button("Classify")
 
47
 
48
- # Link the button to the function
49
- classify_btn.click(fn=classify_text, inputs=input_text, outputs=output)
50
-
51
- # ---------------------------------------------------------------------------
52
- # 4) Launch the demo
53
- # ---------------------------------------------------------------------------
54
  if __name__ == "__main__":
55
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from peft import PeftModel, PeftConfig
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
+ # Repos
7
+ BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
8
+ ADAPTER_REPO = "cheberle/autotrain-35swc-b4r9z"
9
 
10
+ # 1. Load the PEFT config to confirm the base model
11
+ peft_config = PeftConfig.from_pretrained(ADAPTER_REPO)
12
+ print("PEFT Base Model:", peft_config.base_model_name_or_path)
 
 
13
 
14
+ # 2. Load the tokenizer & base model
15
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
16
+ base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, trust_remote_code=True)
17
+
18
+ # 3. Load your LoRA adapter weights onto the base model
19
+ model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
20
 
 
 
 
21
  def classify_text(text):
22
  """
23
+ Simple prompting approach: we ask the model to return a single classification label
24
+ (e.g., 'positive', 'negative', etc.).
25
+ You can refine this prompt, add chain-of-thought, or multiple classes as needed.
 
 
 
 
26
  """
27
+ prompt = f"Below is some text.\nText: {text}\nPlease classify the sentiment (positive or negative):"
28
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
29
+ with torch.no_grad():
30
+ outputs = model.generate(**inputs, max_new_tokens=64)
31
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
32
+ return answer
33
 
 
 
 
34
  with gr.Blocks() as demo:
35
+ gr.Markdown("## Qwen + LoRA Adapter: Text Classification Demo")
36
+ input_box = gr.Textbox(lines=3, label="Enter text")
37
+ output_box = gr.Textbox(lines=3, label="Model's generated output (classification)")
 
 
 
 
 
 
38
 
39
  classify_btn = gr.Button("Classify")
40
+ classify_btn.click(fn=classify_text, inputs=input_box, outputs=output_box)
41
 
 
 
 
 
 
 
42
  if __name__ == "__main__":
43
  demo.launch()
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  huggingface_hub==0.25.2
2
  torch
3
  transformers
4
- gradio
 
 
 
1
  huggingface_hub==0.25.2
2
  torch
3
  transformers
4
+ peft
5
+ gradio
6
+ accelerate # if needed for device_map