Spaces:
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -74,7 +74,7 @@ peft_model_id = "imdeadinside410/Llama2-Syllabus"
|
|
74 |
config = PeftConfig.from_pretrained(peft_model_id)
|
75 |
|
76 |
model = AutoModelForCausalLM.from_pretrained(
|
77 |
-
config.base_model_name_or_path, return_dict=True)
|
78 |
|
79 |
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
|
80 |
|
@@ -84,9 +84,7 @@ model = PeftModel.from_pretrained(model, peft_model_id)
|
|
84 |
|
85 |
pipe = pipeline(task="text-generation",
|
86 |
model=model,
|
87 |
-
tokenizer=tokenizer,
|
88 |
-
device=-1, # -1 indicates CPU
|
89 |
-
max_length=300)
|
90 |
|
91 |
|
92 |
|
@@ -158,14 +156,14 @@ def main():
|
|
158 |
conversation = st.session_state.get("conversation", [])
|
159 |
|
160 |
query = st.text_input("Please input your question here:", key="user_input")
|
161 |
-
result = pipe(f"<s>[INST] {
|
162 |
if st.button("Get Answer"):
|
163 |
if query:
|
164 |
# Display the processing message
|
165 |
with st.spinner("Processing your question..."):
|
166 |
conversation.append({"role": "user", "message": query})
|
167 |
# Call your QA function
|
168 |
-
answer = result
|
169 |
conversation.append({"role": "bot", "message": answer})
|
170 |
st.session_state.conversation = conversation
|
171 |
else:
|
|
|
74 |
config = PeftConfig.from_pretrained(peft_model_id)
|
75 |
|
76 |
model = AutoModelForCausalLM.from_pretrained(
|
77 |
+
config.base_model_name_or_path, return_dict=True, load_in_4bit=True, device_map="auto")
|
78 |
|
79 |
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
|
80 |
|
|
|
84 |
|
85 |
pipe = pipeline(task="text-generation",
|
86 |
model=model,
|
87 |
+
tokenizer=tokenizer, max_length=300)
|
|
|
|
|
88 |
|
89 |
|
90 |
|
|
|
156 |
conversation = st.session_state.get("conversation", [])
|
157 |
|
158 |
query = st.text_input("Please input your question here:", key="user_input")
|
159 |
+
result = pipe(f"<s>[INST] {prompt} [/INST]")
|
160 |
if st.button("Get Answer"):
|
161 |
if query:
|
162 |
# Display the processing message
|
163 |
with st.spinner("Processing your question..."):
|
164 |
conversation.append({"role": "user", "message": query})
|
165 |
# Call your QA function
|
166 |
+
answer = result[0]['generated_text'].split("[/INST]")[1]
|
167 |
conversation.append({"role": "bot", "message": answer})
|
168 |
st.session_state.conversation = conversation
|
169 |
else:
|