imdeadinside410 committed on
Commit
d5c2512
·
verified ·
1 Parent(s): 950e744

Upload 2 files

Files changed (2)
  1. model.py +177 -0
  2. requirements.txt +15 -0
model.py ADDED
@@ -0,0 +1,177 @@
+ import streamlit as st
+ import html  # used to escape user text before it is rendered as HTML
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.llms import CTransformers
+ from langchain.chains import RetrievalQA
+
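+ # NOTE: the LangChain imports above (loaders, embeddings, FAISS, CTransformers,
+ # RetrievalQA) are only referenced by the commented-out retrieval pipeline
+ # below; they can be removed if that code path stays disabled.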
+ # DB_FAISS_PATH = 'vectorstores/db_faiss/NE-Syllabus'
+
+ # custom_prompt_template = """Use the following pieces of information to answer the user's question.
+ # If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+ # Context: {context}
+ # Question: {question}
+
+ # Only return the helpful answer below and nothing else.
+ # Helpful answer:
+ # """
+
+
+ # def set_custom_prompt():
+ #     # The template placeholders must match input_variables below,
+ #     # i.e. {context} and {question}.
+ #     prompt = PromptTemplate(template=custom_prompt_template,
+ #                             input_variables=['context', 'question'])
+ #     return prompt
+
+
+ # def retrieval_qa_chain(llm, prompt, db):
+ #     qa_chain = RetrievalQA.from_chain_type(llm=llm,
+ #                                            chain_type='stuff',
+ #                                            retriever=db.as_retriever(
+ #                                                search_kwargs={'k': 2}),
+ #                                            return_source_documents=True,
+ #                                            chain_type_kwargs={'prompt': prompt})
+ #     return qa_chain
+
+
+ # def load_llm():
+ #     llm = CTransformers(
+ #         model="TheBloke/Llama-2-7B-Chat-GGML",
+ #         model_type="llama",
+ #         max_new_tokens=512,
+ #         temperature=0.5
+ #     )
+ #     return llm
+
+
+ # def qa_bot(query):
+ #     # alternative embedding model: sentence-transformers/all-MiniLM-L6-v2
+ #     embeddings = HuggingFaceEmbeddings(model_name="imdeadinside410/TestTrainedModel",
+ #                                        model_kwargs={'device': 'cpu'})
+ #     db = FAISS.load_local(DB_FAISS_PATH, embeddings)
+ #     llm = load_llm()
+ #     qa_prompt = set_custom_prompt()
+ #     qa = retrieval_qa_chain(llm, qa_prompt, db)
+
+ #     # run retrieval + generation for the query
+ #     response = qa({'query': query})
+ #     return response['result']
+
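+ # If this pipeline is ever re-enabled, it would be invoked along these lines
+ # (assuming a FAISS index has already been built at DB_FAISS_PATH):
+ #
+ #     answer = qa_bot("What is the mission of the School of Computer Science and Engineering?")
+ #     print(answer)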
+ from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+ from langchain.prompts import PromptTemplate
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ from torch import cuda
+ from peft import PeftModel, PeftConfig
+
+ peft_model_id = "imdeadinside410/Llama2-Syllabus"
+ device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
+ config = PeftConfig.from_pretrained(peft_model_id)
+ # load_in_8bit=True relies on bitsandbytes and needs a CUDA device;
+ # on CPU-only machines it has to be dropped (full-precision load).
+ model = AutoModelForCausalLM.from_pretrained(
+     config.base_model_name_or_path, return_dict=True, load_in_8bit=True,
+     device_map=device)
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+
+ # Load the LoRA adapter on top of the base model
+ model = PeftModel.from_pretrained(model, peft_model_id)
+
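+ # Suggestion (not in the original commit): the adapter weights could be
+ # merged into the base model for faster inference; merge_and_unload() is
+ # part of peft, though merging into an 8-bit-quantized model may not be
+ # supported by the pinned peft version, so it may require a full-precision load:
+ #
+ #     model = model.merge_and_unload()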
+ # prompt = "What is the mission of the School of Computer Science and Engineering?"
+
+ # max_length bounds prompt + generated tokens together; max_new_tokens
+ # would bound only the generated part, which is usually the intent here.
+ pipe = pipeline(task="text-generation",
+                 model=model,
+                 tokenizer=tokenizer,
+                 max_length=300)
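+ # Suggestion (not in the original commit): Streamlit re-runs this whole
+ # script on every widget interaction, so the module-level loading above is
+ # repeated each time. A cached loader along these lines would avoid that
+ # (st.cache_resource exists in the pinned streamlit 1.30.0):
+ #
+ #     @st.cache_resource
+ #     def load_pipe():
+ #         cfg = PeftConfig.from_pretrained(peft_model_id)
+ #         base = AutoModelForCausalLM.from_pretrained(
+ #             cfg.base_model_name_or_path, return_dict=True,
+ #             load_in_8bit=True, device_map=device)
+ #         tok = AutoTokenizer.from_pretrained(cfg.base_model_name_or_path)
+ #         lora = PeftModel.from_pretrained(base, peft_model_id)
+ #         return pipeline(task="text-generation", model=lora,
+ #                         tokenizer=tok, max_length=300)
+ #
+ #     pipe = load_pipe()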
+
+
+ # result = pipe(f"<s>[INST] {prompt} [/INST]")
+ # print(result[0]['generated_text'].split("[/INST]")[1])
+
+ # template = """Question: {question}
+
+ # Answer: Let's think step by step."""
+ # prompt = PromptTemplate.from_template(template)
+
+ # wrap the transformers pipeline so it can be used in an LCEL chain
+ # hf = HuggingFacePipeline(pipeline=pipe)
+ # chain = prompt | hf
+
+ # question = "What is IT?"
+
+ # print(chain.invoke({"question": question}))
+
+
+ def add_vertical_space(spaces=1):
+     # draws one horizontal divider per "space"
+     for _ in range(spaces):
+         st.markdown("---")
+
+
+ def main():
+     st.set_page_config(page_title="AIoTLab NE Syllabus")
+
+     with st.sidebar:
+         st.title('AIoTLab NE Syllabus')
+         st.markdown('''
+         Hi
+         ''')
+         add_vertical_space(1)  # adjust the number of dividers as needed
+         st.write('AIoT Lab')
+
+     st.title("AIoTLab NE Syllabus")
+     st.markdown(
+         """
+         <style>
+         .chat-container {
+             display: flex;
+             flex-direction: column;
+             height: 400px;
+             overflow-y: auto;
+             padding: 10px;
+             color: white; /* font color */
+         }
+         .user-bubble {
+             background-color: #007bff; /* blue for user messages */
+             align-self: flex-end;
+             border-radius: 10px;
+             padding: 8px;
+             margin: 5px;
+             max-width: 70%;
+             word-wrap: break-word;
+         }
+         .bot-bubble {
+             background-color: #363636; /* slightly lighter than the page background */
+             align-self: flex-start;
+             border-radius: 10px;
+             padding: 8px;
+             margin: 5px;
+             max-width: 70%;
+             word-wrap: break-word;
+         }
+         </style>
+         """, unsafe_allow_html=True)
+
+     conversation = st.session_state.get("conversation", [])
+
+     query = st.text_input("Please input your question here:", key="user_input")
+     if st.button("Get Answer"):
+         if query:
+             # Display the processing message while the model generates
+             with st.spinner("Processing your question..."):
+                 conversation.append({"role": "user", "message": query})
+                 # Run the model only for a submitted question; calling pipe()
+                 # unconditionally would re-generate on every script rerun.
+                 result = pipe(f"<s>[INST] {query} [/INST]")
+                 # pipeline() returns a list of dicts; keep only the text after
+                 # the [/INST] tag so the prompt is not echoed back.
+                 answer = result[0]['generated_text'].split("[/INST]")[-1].strip()
+                 conversation.append({"role": "bot", "message": answer})
+                 st.session_state.conversation = conversation
+         else:
+             st.warning("Please input a question.")
+
+     chat_container = st.empty()
+     # html.escape() keeps user-typed markup from being rendered verbatim,
+     # since the container is drawn with unsafe_allow_html=True.
+     chat_bubbles = ''.join(
+         [f'<div class="{c["role"]}-bubble">{html.escape(str(c["message"]))}</div>'
+          for c in conversation])
+     chat_container.markdown(
+         f'<div class="chat-container">{chat_bubbles}</div>', unsafe_allow_html=True)
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ pypdf==3.17.0
+ PyPDF2==3.0.0
+ scikit-learn==1.3.0
+ accelerate==0.22.0
+ bitsandbytes==0.41.1
+ ctransformers==0.2.26
+ huggingface-hub==0.22.0
+ langchain==0.0.329
+ sentence-transformers==2.2.2
+ torch==2.3.0
+ transformers==4.36.0
+ peft==0.8.0
+
+ chainlit==1.0.0
+ streamlit==1.30.0
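+
+ # Note: bitsandbytes (which backs load_in_8bit=True in model.py) requires
+ # an NVIDIA GPU with CUDA; on CPU-only hosts the 8-bit loading path must
+ # be dropped.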