ariG23498 HF staff committed on
Commit
b2b6307
·
verified ·
1 Parent(s): cf572fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -11
app.py CHANGED
@@ -14,24 +14,14 @@ from kvpress import (
14
  )
15
  import spaces
16
 
17
- # Initialize GPU Zero
18
- zero = torch.Tensor([0]).cuda()
19
- print(zero.device) # Ensure the tensor is on GPU
20
-
21
  @spaces.GPU
22
  def process_request(url, question, press_type, compression_ratio):
23
  try:
24
- print(zero.device) # Confirm the GPU usage
25
-
26
  # Fetch Wikipedia content
27
  content = requests.get(url).content
28
  soup = BeautifulSoup(content, "html.parser")
29
  context = "".join([p.text for p in soup.find_all("p")]) + "\n\n"
30
 
31
- # Calculate tokens
32
- tokens = pipe.tokenizer.encode(context, return_tensors="pt").to(device)
33
- num_tokens = tokens.size(1)
34
-
35
  # Initialize the press
36
  press_class = press_map.get(press_type)
37
  if not press_class:
@@ -49,7 +39,7 @@ def process_request(url, question, press_type, compression_ratio):
49
  # Load pipeline
50
  device = "cuda:0"
51
  ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
52
- attn_implementation = "flash_attention_2"
53
  pipe = pipeline(
54
  "kv-press-text-generation",
55
  model=ckpt,
 
14
  )
15
  import spaces
16
 
 
 
 
 
17
  @spaces.GPU
18
  def process_request(url, question, press_type, compression_ratio):
19
  try:
 
 
20
  # Fetch Wikipedia content
21
  content = requests.get(url).content
22
  soup = BeautifulSoup(content, "html.parser")
23
  context = "".join([p.text for p in soup.find_all("p")]) + "\n\n"
24
 
 
 
 
 
25
  # Initialize the press
26
  press_class = press_map.get(press_type)
27
  if not press_class:
 
39
  # Load pipeline
40
  device = "cuda:0"
41
  ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
42
+ attn_implementation = "sdpa"
43
  pipe = pipeline(
44
  "kv-press-text-generation",
45
  model=ckpt,