krystv commited on
Commit
601c87a
·
verified ·
1 Parent(s): 8d28b9a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -215
app.py CHANGED
@@ -1,220 +1,152 @@
1
- import gradio as gr
2
- import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
- import shutil
5
- import streamlit as st
6
- openai_api = st.secrets["OPENAI_API_KEY"]
7
-
8
- doc_store_path = os.path.join(os.path.dirname(__file__), "doc_dir")
9
- if not os.path.isdir(doc_store_path):
10
- os.makedirs(doc_store_path)
11
-
12
- from llama_index.core import SimpleDirectoryReader, VectorStoreIndex,Settings
13
- from llama_index.core.node_parser import SentenceSplitter,SemanticSplitterNodeParser
14
- from llama_index.llms.openai import OpenAI
15
- from llama_index.llms.openai import OpenAI as OpenAIsum
16
- from llama_index.embeddings.openai import OpenAIEmbedding
17
- from llama_index.core.storage import StorageContext
18
- from llama_index.vector_stores.chroma import ChromaVectorStore
19
- from llama_index.core.storage.chat_store import SimpleChatStore
20
- from llama_index.core.memory import ChatMemoryBuffer,ChatSummaryMemoryBuffer
21
-
22
- import json
23
- import chromadb
24
- import tiktoken
25
-
26
-
27
- chat_store = SimpleChatStore()
28
- # chat_memory = ChatMemoryBuffer.from_defaults(
29
- # token_limit=3000,
30
- # chat_store=chat_store,
31
- # chat_store_key="user1",
32
- # )
33
-
34
-
35
- sum_llm = OpenAIsum(api_key=openai_api, model="gpt-3.5-turbo", max_tokens=256)
36
- chat_summary_memory = ChatSummaryMemoryBuffer.from_defaults(
37
- token_limit=256,
38
- chat_store=chat_store,
39
- chat_store_key="user1",
40
- llm = sum_llm,
41
- tokenizer_fn = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
42
- )
43
-
44
-
45
- chat_store = SimpleChatStore.from_persist_path(
46
- persist_path="chat_store.json"
47
- )
48
-
49
-
50
-
51
- # documents = SimpleDirectoryReader("./data").load_data()
52
- db = chromadb.PersistentClient(path="./chroma_db")
53
-
54
- chroma_collection = db.get_or_create_collection("quickstart")
55
-
56
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
57
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
58
-
59
- Settings.llm = OpenAI(model="gpt-3.5-turbo",api_key=openai_api,)
60
- Settings.embed_model = OpenAIEmbedding(model="text-embedding-ada-002")
61
-
62
- vector_index = VectorStoreIndex.from_vector_store(vector_store, storage_context=storage_context,)
63
- query_engine = vector_index.as_chat_engine(chat_memory=chat_summary_memory,storage_context=storage_context,use_async=True,similarity_top_k=2)
64
-
65
- current_refs = ""
66
-
67
- def metadata_from_doc(vec_index: VectorStoreIndex) -> dict:
68
- qe = vec_index.as_chat_engine()
69
- # f_prompt = """
70
- # Given the text excerpts, analyze and provide the document's title and creation date in a structured JSON format. Here are a few examples:
71
-
72
- # In this format:
73
-
74
- # {
75
- # "creation_date": "YYYY-MM-DD",
76
- # "title": "Title of the Document"
77
- # }
78
-
79
- # Text: 'An analysis of historical events. Written by Alex Johnson on 5 March 2019.'
80
- # Output: { "title": "An analysis of historical events", "creation_date": "2019-03-05" }
81
-
82
- # Text: 'Exploring the depths of the ocean. This comprehensive guide was authored by Dr. Emily White, published on 10-July 2021.'
83
- # Output: { "title": "Exploring the depths of the ocean", "creation_date": "2021-07-10" }
84
-
85
- # Text: 'The history of the Roman Empire.'
86
- # Output: { "title": "The history of the Roman Empire", "creation_date": "Unknown" }
87
-
88
-
89
- # Now, analyze the context from the provided document and generate json object.
90
- # """
91
- f_prompt ="""give me a only the data when this document was written and title of this document? in json format parameter (created_date,title),
92
- example context: 'An analysis of historical events. Written by Alex Johnson on 5 March 2019.'
93
- example output: { "title": "An analysis of historical events", "creation_date": "2019-03-05" }
94
- now analyse the context make sure to return output only in json format object only.
95
- """
96
- res = qe.query(f_prompt)
97
- parsed = json.loads(res.response)
98
- return parsed
99
-
100
- def filter_unsaved(file_paths:list):
101
- for i in file_paths:
102
- if os.path.isfile(os.path.join(doc_store_path,os.path.basename(i))):
103
- file_paths.remove(i)
104
- print("File already exist : {}".format(i))
105
- else:
106
- shutil.copy2(i,doc_store_path)
107
- return file_paths
108
-
109
- def add_doc(file_paths:list):
110
- print(file_paths)
111
- file_paths = filter_unsaved(file_paths)
112
- print(file_paths)
113
- if len(file_paths) == 0:
114
- return
115
- docs = SimpleDirectoryReader(input_files=file_paths).load_data()
116
- splitter = SemanticSplitterNodeParser(buffer_size=1, breakpoint_percentile_threshold=95, embed_model=Settings.embed_model,chunk_size=256)
117
- nodes = splitter.get_nodes_from_documents(docs)
118
- vector_index2 = VectorStoreIndex(nodes)
119
- for i in range (5):
120
- try:
121
- meta = metadata_from_doc(vector_index2)
122
- break
123
- except:
124
- meta = {
125
- "title": "Unknown",
126
- "creation_date": "Unknown"
127
- }
128
- continue
129
-
130
- print(meta)
131
- for i in range(len(nodes)):
132
- nodes[i].metadata.update(meta)
133
- vector_index.insert_nodes(nodes)
134
-
135
-
136
-
137
-
138
-
139
- CSS ="""
140
- .contain { display: flex; flex-direction: column; }
141
- .gradio-container { height: 100vh !important; }
142
- #component-0 { height: 100%; }
143
- #chatbot { flex-grow: 1; overflow: auto;}
144
  """
145
 
146
 
147
- def new_chat(chatbot:gr.Chatbot,textbox):
148
- query_engine.reset()
149
- return gr.update(value=""),[],"",gr.File(visible=False),gr.File(visible=False)
150
-
151
-
152
- def chat(history, input):
153
- response = query_engine.chat(str(input))
154
- global current_refs
155
- files = []
156
- current_refs = ""
157
- for node in response.source_nodes:
158
- try:
159
- current_refs += f"{str(node.metadata['title'])},"
160
- except:
161
- current_refs += ""
162
- try:
163
- current_refs += f"Pg - {str(node.metadata['page_label'])},"
164
- except:
165
- current_refs += "Pg - ,"
166
- try:
167
- current_refs += f"File - {str(node.metadata['file_name'])} \n,"
168
- except:
169
- current_refs += "File - ,\n"
170
-
171
- try:
172
- files.append({'path':node.metadata['file_path'],'show':True,})
173
- except:
174
- files.append({'path':None,'show':False,})
175
-
176
- if len(files) < 2:
177
- for _ in range(2-len(files)):
178
- files.append({'path':None,'show':False,})
179
-
180
- return gr.update(value=""),history + [(input, response.response)],current_refs,gr.update(visible=files[0]['show'],value=files[0]['path']),gr.update(visible=files[1]['show'],value=files[1]['path'])
181
-
182
- def file_upload(file,chatbot):
183
- print(file)
184
- add_doc(file)
185
- return gr.update(value="ChatDoc"),chatbot
186
-
187
- with gr.Blocks(fill_height=True, css=CSS) as demo:
 
 
 
 
 
 
 
188
  with gr.Row():
189
- with gr.Column(scale=1):
190
- title = gr.Label(value="chatdoc", label="ChatDoc")
191
- files = gr.UploadButton(
192
- "📁 Upload PDF or doc files", file_types=[
193
- '.pdf',
194
- '.doc'
195
- ],
196
- file_count="multiple")
197
- references = gr.Textbox(label="References",interactive=False)
198
- file_down1 = gr.File(visible=False)
199
- file_down2 = gr.File(visible=False)
200
-
201
-
202
-
203
- with gr.Column(scale=9,):
204
- chatbot = gr.Chatbot(
205
- elem_id="chatbot",
206
- bubble_full_width=False,
207
- label="ChatDoc",
208
- avatar_images=["https://www.freeiconspng.com/thumbs/person-icon-blue/person-icon-blue-25.png","https://cdn-icons-png.flaticon.com/512/8943/8943377.png"],
209
- )
210
- with gr.Row():
211
- textbox = gr.Textbox(label="Type your message", scale=10)
212
- clear = gr.Button(value="New Chat", size="sm", scale=1)
213
- clear.click(new_chat,[],[textbox, chatbot,references,file_down1,file_down2])
214
- textbox.submit(chat, [chatbot, textbox], [textbox, chatbot,references,file_down1,file_down2])
215
-
216
-
217
- files.upload(file_upload,[files,chatbot],[title,chatbot])
218
-
219
-
220
- demo.launch(share=True)
 
1
+ import subprocess
2
+ import sys
3
+
4
+ subprocess.check_call([sys.executable,"-m","pip","install",'causal-conv1d'])
5
+ subprocess.check_call([sys.executable, "-m", "pip", "install", 'torch', 'numpy', 'miditok','mamba-ssm','gradio'])
6
+ subprocess.check_call(["apt-get", "install", "timidity", "-y"])
7
+
8
+ # !pip install pretty_midi midi2audio
9
+ # !pip install miditok
10
+ # !apt-get install fluidsynth
11
+ # !apt install timidity -y
12
+ # !pip install causal-conv1d>=1.1.0
13
+ # !pip install mamba-ssm
14
+ # !pip install gradio
15
+
16
+
17
+
18
+ # !export LC_ALL="en_US.UTF-8"
19
+ # !export LD_LIBRARY_PATH="/usr/lib64-nvidia"
20
+ # !export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
21
+
22
+ # subprocess.check_call(['export', 'LC_ALL="en_US.UTF-8"'])
23
+ # subprocess.check_call(['export', 'LD_LIBRARY_PATH="/usr/lib64-nvidia"'])
24
+ # subprocess.check_call(['export', 'LIBRARY_PATH="/usr/local/cuda/lib64/stubs"'])
25
  import os
26
+
27
+ os.environ['LC_ALL'] = "en_US.UTF-8"
28
+ os.environ['LD_LIBRARY_PATH'] = "/usr/lib64-nvidia"
29
+ os.environ['LIBRARY_PATH'] = "/usr/local/cuda/lib64/stubs"
30
+
31
+
32
+
33
+ import gradio as gr
34
+ import torch
35
+ from mamba_ssm import Mamba
36
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
37
+ from mamba_ssm.models.config_mamba import MambaConfig
38
+ import numpy as np
39
+
40
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
41
+ if torch.cuda.is_available():
42
+ subprocess.check_call(['ldconfig', '/usr/lib64-nvidia'])
43
+ # !ldconfig /usr/lib64-nvidia
44
+
45
+ # !wget "https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/MIDI_Mamba-159M_1536VS.pt"
46
+ # !wget "https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json"
47
+ if os.path.isfile("MIDI_Mamba-159M_1536VS.pt") == False:
48
+ subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/MIDI_Mamba-159M_1536VS.pt'])
49
+
50
+ if os.path.isfile("tokenizer_1536mix_BPE.json") == False:
51
+ subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json'])
52
+
53
+
54
+
55
+ mc = MambaConfig()
56
+ mc.d_model = 768
57
+ mc.n_layer = 42
58
+ mc.vocab_size = 1536
59
+
60
+ from miditok import MIDILike,REMI,TokenizerConfig
61
+ from pathlib import Path
62
+ import torch
63
+
64
+ tokenizer = REMI(params='tokenizer_1536mix_BPE.json')
65
+
66
+
67
+
68
+ mf = MambaLMHeadModel(config=mc,device=device)
69
+ mf.load_state_dict(torch.load("/content/MIDI_Mamba-159M_1536VS.pt",map_location=device))
70
+
71
+
72
+
73
+ twitter_follow_link = "https://twitter.com/iamhemantindia"
74
+ instagram_follow_link = "https://instagram.com/iamhemantindia"
75
+
76
+ custom_html = f"""
77
+ <div style='text-align: center;'>
78
+ <a href="{twitter_follow_link}" target="_blank" style="margin-right: 5px;">
79
+ <img src="https://img.icons8.com/fluent/24/000000/twitter.png" alt="Follow on Twitter"/>
80
+ </a>
81
+ <a href="{instagram_follow_link}" target="_blank">
82
+ <img src="https://img.icons8.com/fluent/24/000000/instagram-new.png" alt="Follow on Instagram"/>
83
+ </a>
84
+ </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  """
86
 
87
 
88
+ @spaces.GPU(duration=120)
89
+ def generate(number,top_k_selector,top_p_selector, temperature_selector):
90
+ input_ids = torch.tensor([[1,]]).to(device)
91
+ out = mf.generate(
92
+ input_ids=input_ids,
93
+ max_length=int(number),
94
+ temperature=temperature_selector,
95
+ top_p=top_p_selector,
96
+ top_k=top_k_selector,
97
+
98
+ eos_token_id=2,)
99
+ m = tokenizer.decode(np.array(out[0].to('cpu')))
100
+ np.array(out.to('cpu')).shape
101
+ m.dump_midi('output.mid')
102
+ # !timidity output.mid -Ow -o - | ffmpeg -y -f wav -i - output.mp3
103
+ timidity_cmd = ['timidity', 'output.mid', '-Ow', '-o', 'output.wav']
104
+ subprocess.check_call(timidity_cmd)
105
+
106
+ # Then convert the WAV to MP3 using ffmpeg
107
+ ffmpeg_cmd = ['ffmpeg', '-y', '-f', 'wav', '-i', 'output.wav', 'output.mp3']
108
+ subprocess.check_call(ffmpeg_cmd)
109
+
110
+ return "output.mp3"
111
+
112
+
113
+ # text_box = gr.Textbox(label="Enter Text")
114
+
115
+
116
+ def generate_and_save(number,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid):
117
+ output_audio = generate(number,top_k_selector,top_p_selector, temperature_selector)
118
+ return gr.Audio(output_audio,autoplay=True),gr.File(label="Download MIDI",value="output.mid"),generate_button
119
+
120
+
121
+
122
+
123
+
124
+
125
+ # iface = gr.Interface(fn=generate_and_save,
126
+ # inputs=[number_selector,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid],
127
+ # outputs=[output_box,download_midi_button],
128
+ # title="MIDI Mamba-159M",submit_btn=False,
129
+ # clear_btn=False,
130
+ # description="MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model.",
131
+ # allow_flagging=False,)
132
+
133
+ with gr.Blocks() as b1:
134
+ gr.Markdown("<h1 style='text-align: center;'>MIDI Mamba-159M <h1/> ")
135
+ gr.Markdown("<h3 style='text-align: center;'>MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model. <br> by Hemant Kumar<h3/>")
136
  with gr.Row():
137
+ with gr.Column():
138
+ number_selector = gr.Number(label="Select Length of output",value=512)
139
+ top_p_selector = gr.Slider(label="Select Top P", minimum=0, maximum=1.0, step=0.05, value=0.9)
140
+ temperature_selector = gr.Slider(label="Select Temperature", minimum=0, maximum=1.0, step=0.1, value=0.9)
141
+ top_k_selector = gr.Slider(label="Select Top K", minimum=1, maximum=1536, step=1, value=30)
142
+ generate_button = gr.Button(value="Generate",variant="primary")
143
+ custom_html_wid = gr.HTML(custom_html)
144
+ with gr.Column():
145
+ output_box = gr.Audio("output.mp3",autoplay=True,)
146
+ download_midi_button = gr.File(label="Download MIDI")
147
+ generate_button.click(generate_and_save,inputs=[number_selector,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid],outputs=[output_box,download_midi_button,generate_button])
148
+
149
+
150
+
151
+
152
+ b1.launch(share=True)