Spaces: Running on T4
| import gradio as gr | |
| import nltk | |
| import numpy as np | |
| import pandas as pd | |
| from librosa import load, resample | |
| from sentence_transformers import SentenceTransformer, util | |
| from transformers import pipeline | |
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Zipped CSV of the 2020 annual filings (one row per filing), read by pandas.
filename = "df10k_SP500_2020.csv.zip"
# Sentence encoder used to embed queries; the pre-computed corpus embeddings
# are presumably produced by the same checkpoint (their file is named after it).
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512  # assigned to the encoder's max_seq_length
# Pre-computed corpus embeddings, named after the encoder checkpoint.
embeddings_filename = f"df10k_embeddings_{model_name.rsplit('/', 1)[-1]}.npz"
# Speech model that turns spoken queries (21 languages) into English text.
asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"
# ---------------------------------------------------------------------------
# Load corpus
# ---------------------------------------------------------------------------
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f"Number of documents: {len(df)}")

# Sentence-tokenizer data. Newer NLTK releases look up the "punkt_tab"
# resource instead of the pickled "punkt" model, so fetch both to keep
# sent_tokenize working across NLTK versions.
nltk.download("punkt")
nltk.download("punkt_tab")

# One corpus entry per sentence of the 'mdna' column ('Management discussion
# and analysis'), in row order. sentence_count[i] is the number of sentences
# contributed by row i of df, which lets a corpus index be mapped back to its
# filing. NOTE(review): this ordering presumably has to match the one used to
# build the pre-computed embeddings, so tokenization must not change.
corpus = []
sentence_count = []
for mdna in df["mdna"]:
    # str() turns missing values (NaN) into the literal "nan" sentence; kept
    # so sentence indices stay aligned with the stored embeddings.
    sentences = nltk.tokenize.sent_tokenize(str(mdna), language="english")
    sentence_count.append(len(sentences))
    corpus.extend(sentences)
print(f"Number of sentences: {len(corpus)}")
# Load pre-embedded corpus. np.load returns an NpzFile that keeps the archive
# open; the context manager closes it once the array has been read.
with np.load(embeddings_filename) as npz:
    corpus_embeddings = npz["arr_0"]
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")

# Search hits are mapped to sentences in `corpus` by position, so the two
# must have the same length; fail fast rather than return wrong documents.
if corpus_embeddings.shape[0] != len(corpus):
    raise ValueError(
        f"Embedding count ({corpus_embeddings.shape[0]}) in {embeddings_filename} "
        f"does not match the tokenized sentence count ({len(corpus)})."
    )
# Query encoder, truncated to max_sequence_length tokens per input.
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Speech-to-text pipeline; the same checkpoint also supplies the feature
# extractor.
asr = pipeline(
    task="automatic-speech-recognition",
    model=asr_model,
    feature_extractor=asr_model,
)
def find_sentences(query, hits):
    """Return the corpus sentences most similar to *query*.

    The query is embedded with ``model`` and compared against the
    pre-computed ``corpus_embeddings`` via ``util.semantic_search``. Each hit
    is mapped back to the filing it came from using ``sentence_count`` (the
    number of sentences each row of ``df`` contributed, in row order).

    Args:
        query: Free-text query string.
        hits: Number of results to return. Coerced to ``int`` in case the UI
            slider delivers a float.

    Returns:
        A DataFrame with columns Ticker, Form type, Filing date, Text (first
        80 characters of the sentence) and Score (formatted with two
        decimals), in the order returned by ``util.semantic_search``.
        Hits whose index falls outside the tokenized corpus are skipped.
    """
    columns = ["Ticker", "Form type", "Filing date", "Text", "Score"]

    query_embedding = model.encode(query)
    # semantic_search returns one hit list per query; there is a single query.
    top_hits = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(hits)
    )[0]

    # doc_ends[i] is one past the last corpus index owned by document i, so
    # document i owns indices [doc_ends[i - 1], doc_ends[i]). Computing this
    # once avoids re-scanning every document for every hit.
    doc_ends = np.cumsum(sentence_count)

    rows = []
    for hit in top_hits:
        corpus_id = hit["corpus_id"]
        # The owning document is the first one whose end exceeds corpus_id
        # (side="right" also skips documents that contributed 0 sentences).
        doc_idx = int(np.searchsorted(doc_ends, corpus_id, side="right"))
        if doc_idx >= len(doc_ends):
            # Index beyond the tokenized corpus: no source document to report.
            continue
        doc = df.iloc[doc_idx]
        rows.append(
            {
                "Ticker": doc["ticker"],
                "Form type": doc["form_type"],
                "Filing date": doc["filing_date"],
                "Text": corpus[corpus_id][:80],
                "Score": "{:.2f}".format(hit["score"]),
            }
        )
    # Build the frame once instead of concatenating inside the loop.
    return pd.DataFrame(rows, columns=columns)
def process(input_selection, query, filepath, hits):
    """Run a search from either the typed query or a recorded audio query.

    Args:
        input_selection: ``"speech"`` to transcribe the recording at
            *filepath*; any other value searches *query* as typed.
        query: Text query, used when *input_selection* is not ``"speech"``.
        filepath: Path to the recorded audio file. Presumably ``None`` (or
            empty) when nothing was recorded, since the audio input is
            optional.
        hits: Number of results to return.

    Returns:
        A ``(text, results)`` tuple: the query string actually searched (the
        transcription for speech input) and the DataFrame produced by
        ``find_sentences``.

    Raises:
        ValueError: If speech input is selected but no audio file was given.
    """
    if input_selection == "speech":
        if not filepath:
            raise ValueError(
                "Speech input selected but no audio was recorded; record a "
                "query or switch the input selection to 'text'."
            )
        target_sr = 16000  # rate the audio is converted to before ASR
        # sr=None keeps the file's native rate so audio is resampled at most
        # once. librosa's default (sr=22050) would first resample every file
        # to 22.05 kHz and then again to 16 kHz here.
        speech, sampling_rate = load(filepath, sr=None)
        if sampling_rate != target_sr:
            speech = resample(speech, orig_sr=sampling_rate, target_sr=target_sr)
        text = asr(speech)["text"]
    else:
        text = query
    return text, find_sentences(text, hits)
# ---------------------------------------------------------------------------
# Gradio inputs (passed to `process` in this order)
# ---------------------------------------------------------------------------
# Chooses whether `process` searches the typed text or the transcribed audio.
buttons = gr.Radio(
    ["text", "speech"],
    type="value",
    value="speech",
    label="Input selection",
)
text_query = gr.Textbox(
    lines=1,
    label="Text input",
    value="The company is under investigation by tax authorities for potential fraud.",
)
mic = gr.Audio(
    source="microphone",
    type="filepath",
    label="Speech input",
    optional=True,
)
slider = gr.Slider(
    minimum=1,
    maximum=10,
    step=1,
    value=3,
    label="Number of hits",
)

# ---------------------------------------------------------------------------
# Gradio outputs: the query string that was searched, and the hit table
# ---------------------------------------------------------------------------
speech_query = gr.Textbox(type="text", label="Query string")
results = gr.Dataframe(
    type="pandas",
    headers=["Ticker", "Form type", "Filing date", "Text", "Score"],
    label="Query results",
)

# Every example is a speech query: (text shown in the text box, audio file).
# The text is presumably the recording's transcript; it is not searched,
# because `process` only uses the text box when "text" is selected.
iface = gr.Interface(
    theme="huggingface",
    description=(
        "This Spaces lets you query a text corpus containing 2020 annual filings "
        "for all S&P500 companies. You can type a text query in English, or "
        "record an audio query in 21 languages. You can find a technical deep "
        "dive at https://www.youtube.com/watch?v=YPme-gR0f80"
    ),
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=[
        ["speech", transcript, audio_file, 3]
        for transcript, audio_file in (
            (
                "Nos ventes internationales ont significativement augmenté.",
                "sales_16k_fr.wav",
            ),
            (
                "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.",
                "energy_16k_fr.wav",
            ),
            (
                "El precio de la energía podría tener un impacto negativo en el futuro.",
                "energy_24k_es.wav",
            ),
            (
                "Mehrere Steuerbehörden untersuchen unser Unternehmen.",
                "tax_24k_de.wav",
            ),
        )
    ],
)
iface.launch()