Tonic commited on
Commit
e3f8aa5
·
verified ·
1 Parent(s): 450c4d5

solve cude device selection bug

Browse files
Files changed (1) hide show
  1. app.py +3 -13
app.py CHANGED
@@ -34,25 +34,15 @@ def load_pipeline():
34
  clear_gpu_memory()
35
  pipeline = ChronosPipeline.from_pretrained(
36
  "amazon/chronos-t5-large",
37
- device_map="gpu", # Let the model decide the best device mapping
38
- torch_dtype=torch.float16,
39
  low_cpu_mem_usage=True
40
  )
41
  pipeline.model = pipeline.model.eval()
42
  return pipeline
43
  except Exception as e:
44
  print(f"Error loading pipeline: {str(e)}")
45
- # Fallback to CPU if GPU fails
46
- if "cuda" in str(e).lower():
47
- print("Falling back to CPU mode")
48
- pipeline = ChronosPipeline.from_pretrained(
49
- "amazon/chronos-t5-large",
50
- device_map="cpu",
51
- torch_dtype=torch.float32,
52
- low_cpu_mem_usage=True
53
- )
54
- pipeline.model = pipeline.model.eval()
55
- return pipeline
56
 
57
  def get_historical_data(symbol: str, timeframe: str = "1d", lookback_days: int = 365) -> pd.DataFrame:
58
  """
 
34
  clear_gpu_memory()
35
  pipeline = ChronosPipeline.from_pretrained(
36
  "amazon/chronos-t5-large",
37
+ device_map="auto", # Let the machine choose the best device
38
+ torch_dtype=torch.float16, # Use float16 for better memory efficiency
39
  low_cpu_mem_usage=True
40
  )
41
  pipeline.model = pipeline.model.eval()
42
  return pipeline
43
  except Exception as e:
44
  print(f"Error loading pipeline: {str(e)}")
45
+ raise RuntimeError(f"Failed to load model: {str(e)}")
 
 
 
 
 
 
 
 
 
 
46
 
47
  def get_historical_data(symbol: str, timeframe: str = "1d", lookback_days: int = 365) -> pd.DataFrame:
48
  """