radussad committed
Commit 71e9b01 · verified · 1 Parent(s): d999ea2

Update app.py

Files changed (1):
  app.py +5 -4
app.py CHANGED
@@ -9,14 +9,15 @@ from retriever import retrieve_documents
 #os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
 
 # Load Mistral 7B model
-MODEL_NAME = "mistralai/Mistral-7B-v0.1"
+#MODEL_NAME = "mistralai/Mistral-7B-v0.1"
+MODEL_NAME = "microsoft/phi3-mini-4k-instruct"
+
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=os.getenv("HUGGING_FACE_HUB_TOKEN"), cache_dir="/tmp/huggingface")
 model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
     use_auth_token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
     cache_dir="/tmp/huggingface",
     device_map="auto",
-    torch_dtype=torch.float16,
-    load_in_4bit=True
+    torch_dtype=torch.float16
 )
 
 # Create inference pipeline
@@ -27,7 +28,7 @@ app = FastAPI()
 
 @app.get("/")
 def read_root():
-    return {"message": "Mistral 7B RAG API is running!"}
+    return {"message": "Phi3 Mini RAG API is running!"}
 
 @app.get("/generate/")
 def generate_response(query: str = Query(..., title="User Query")):