radussad committed on
Commit
a3d72f2
·
verified ·
1 Parent(s): ba9106d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
import torch

from retriever import retrieve_documents

# Redirect the Hugging Face cache to a writable location. On hosted runtimes
# (e.g. HF Spaces) only /tmp is guaranteed writable; the same path is also
# passed explicitly via cache_dir below, which is why these stay commented out.
# os.environ["HF_HOME"] = "/tmp/huggingface"
# os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

# Load Mistral 7B model. The hub token is read from the environment so that
# access to the gated repo works without hard-coding credentials.
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
# NOTE: `use_auth_token=` is deprecated in transformers — `token=` is the
# supported keyword for authenticated downloads.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
    cache_dir="/tmp/huggingface",
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
    cache_dir="/tmp/huggingface",
)  # consider device_map="auto", torch_dtype=torch.float16 once a GPU is available

# Create inference pipeline reusing the already-loaded model/tokenizer so the
# pipeline does not re-download or re-instantiate them.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)