imdeadinside410 committed on
Commit
7af778e
·
verified ·
1 Parent(s): 2ee03f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -74,7 +74,7 @@ peft_model_id = "imdeadinside410/Llama2-Syllabus"
74
  config = PeftConfig.from_pretrained(peft_model_id)
75
 
76
  model = AutoModelForCausalLM.from_pretrained(
77
- config.base_model_name_or_path, return_dict=True)
78
 
79
  tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
80
 
@@ -84,9 +84,7 @@ model = PeftModel.from_pretrained(model, peft_model_id)
84
 
85
  pipe = pipeline(task="text-generation",
86
  model=model,
87
- tokenizer=tokenizer,
88
- device=-1, # -1 indicates CPU
89
- max_length=300)
90
 
91
 
92
 
@@ -158,14 +156,14 @@ def main():
158
  conversation = st.session_state.get("conversation", [])
159
 
160
  query = st.text_input("Please input your question here:", key="user_input")
161
- result = pipe(f"<s>[INST] {query} [/INST]")
162
  if st.button("Get Answer"):
163
  if query:
164
  # Display the processing message
165
  with st.spinner("Processing your question..."):
166
  conversation.append({"role": "user", "message": query})
167
  # Call your QA function
168
- answer = result
169
  conversation.append({"role": "bot", "message": answer})
170
  st.session_state.conversation = conversation
171
  else:
 
74
  config = PeftConfig.from_pretrained(peft_model_id)
75
 
76
  model = AutoModelForCausalLM.from_pretrained(
77
+ config.base_model_name_or_path, return_dict=True, load_in_4bit=True, device_map="auto")
78
 
79
  tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
80
 
 
84
 
85
  pipe = pipeline(task="text-generation",
86
  model=model,
87
+ tokenizer=tokenizer, max_length=300)
 
 
88
 
89
 
90
 
 
156
  conversation = st.session_state.get("conversation", [])
157
 
158
  query = st.text_input("Please input your question here:", key="user_input")
159
+ result = pipe(f"<s>[INST] {prompt} [/INST]")
160
  if st.button("Get Answer"):
161
  if query:
162
  # Display the processing message
163
  with st.spinner("Processing your question..."):
164
  conversation.append({"role": "user", "message": query})
165
  # Call your QA function
166
+ answer = result[0]['generated_text'].split("[/INST]")[1]
167
  conversation.append({"role": "bot", "message": answer})
168
  st.session_state.conversation = conversation
169
  else: