Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import requests | |
| import re | |
| import emoji | |
| import nltk | |
| import lxml | |
| import os | |
| from bs4 import BeautifulSoup | |
| from markdown import markdown | |
| from nltk.corpus import stopwords | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer, util | |
| from retry import retry | |
| from transformers import pipeline | |
# NOTE(review): `pipe` (the opus-mt-en-es translation pipeline) is not
# referenced anywhere else in the visible file — confirm it is still needed.
pipe = pipeline(task="translation", model="Helsinki-NLP/opus-mt-en-es")

# Make sure the NLTK stop-word corpus has been downloaded.
nltk.download("stopwords")

# Read hf_token from the environment (None when HF_TOKEN is unset).
hf_token = os.environ.get("HF_TOKEN")

# Embedding model, run locally through a feature-extraction pipeline.
model_id = "BAAI/bge-large-en-v1.5"
feature_extraction_pipeline = pipeline(task="feature-extraction", model=model_id)
# Previous implementation, kept for reference: fetch embeddings from the hosted
# Hugging Face Inference API instead of running the model locally.
# model_id = "BAAI/bge-large-en-v1.5"
# api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
# headers = {"Authorization": f"Bearer {hf_token}"}
# @retry(tries=3, delay=10)
# def query(texts):
#     response = requests.post(api_url, headers=headers, json={"inputs": texts})
#     if response.status_code == 200:
#         result = response.json()
#         if isinstance(result, list):
#             return result
#         elif 'error' in result:
#             raise RuntimeError("Error from Hugging Face API: " + result['error'])
#     else:
#         raise RuntimeError("Failed to get response from Hugging Face API, status code: " + str(response.status_code))
# Load the precomputed embedding dataset.
faqs_embeddings_dataset = load_dataset("chenglu/hf-blogs-baai-embeddings")
df = faqs_embeddings_dataset["train"].to_pandas()
# Transpose so that row i of the tensor corresponds to df.columns[i]; the
# search code maps a hit's corpus_id back to a column label this way.
embeddings_array = df.transpose().to_numpy()
dataset_embeddings = torch.from_numpy(embeddings_array).float()

# Load the original dataset (read for its 'local' and 'tags' fields).
original_dataset = load_dataset("chenglu/hf-blogs")["train"]

# Build the set of English stop words.
stop_words = {*stopwords.words("english")}
def remove_stopwords(text):
    """Return ``text`` without its stop words.

    The text is split on whitespace; a token is dropped when its lowercase
    form is in the module-level ``stop_words`` set. Surviving tokens keep
    their original casing and are re-joined with single spaces.
    """
    kept_tokens = []
    for token in text.split():
        if token.lower() in stop_words:
            continue
        kept_tokens.append(token)
    return ' '.join(kept_tokens)
def clean_content(content):
    """Reduce markdown/HTML ``content`` to plain text of letters and spaces.

    The steps run in this order:

    1. remove fenced code blocks and inline code spans;
    2. strip HTML tags;
    3. remove emoji;
    4. delete every character that is not an ASCII letter or whitespace;
    5. delete tokens starting with ``http``/``www``;
    6. render the remainder as markdown and extract its text;
    7. collapse whitespace runs to a single space.

    Args:
        content: The raw text to clean.

    Returns:
        The cleaned string.
    """
    # DOTALL lets a ``` fence span several lines.
    content = re.sub(r"(```.*?```|`.*?`)", "", content, flags=re.DOTALL)
    content = BeautifulSoup(content, "html.parser").get_text()
    content = emoji.replace_emoji(content, replace='')
    # Characters are deleted, not replaced by a space, so "a-b" becomes "ab".
    content = re.sub(r"[^a-zA-Z\s]", "", content)
    # Runs after the punctuation strip above, so it matches the squashed
    # remains of a URL (e.g. "httpsexamplecom") rather than the raw URL.
    content = re.sub(r"http\S+|www\S+|https\S+", '', content, flags=re.MULTILINE)
    content = markdown(content)
    # find_all(string=True) is the current spelling of the deprecated
    # findAll(text=True); both return every text node of the document.
    content = ''.join(BeautifulSoup(content, 'lxml').find_all(string=True))
    content = re.sub(r'\s+', ' ', content)
    return content
def get_tags_for_local(dataset, local_value):
    """Look up the tags of the first record whose 'local' field matches.

    Args:
        dataset: Iterable of records supporting ``record['local']`` and
            ``record['tags']``.
        local_value: Value to compare against each record's 'local' field.

    Returns:
        The 'tags' value of the first matching record, or None when no
        record matches.
    """
    match = None
    for record in dataset:
        if record['local'] == local_value:
            match = record
            break  # only the first match is used

    if not match:
        return None
    return match['tags']
def gradio_query_interface(input_text):
    """Recommend category tags for ``input_text``.

    The text is cleaned, stripped of stop words and embedded with the local
    feature-extraction pipeline. The embedding is compared against
    ``dataset_embeddings``; the tags of the best-scoring entry are returned
    when its score reaches 0.6.

    Args:
        input_text: Raw text entered by the user.

    Returns:
        "Recommended category tags: <tags>" for a sufficiently similar
        entry, otherwise "Content Not related".
    """
    cleaned_text = clean_content(input_text)
    no_stopwords_text = remove_stopwords(cleaned_text)
    if not no_stopwords_text.strip():
        # Nothing left to embed after cleaning.
        return "Content Not related"

    # Embed the cleaned text (the original computed it but then embedded the
    # raw input). truncation=True keeps long inputs within the model's
    # maximum sequence length.
    features = torch.FloatTensor(
        feature_extraction_pipeline(no_stopwords_text, truncation=True)
    )
    if features.dim() == 3:
        # The pipeline returns per-token hidden states nested as
        # [batch][token][hidden]; semantic_search needs one vector per query.
        # BGE models use the first ([CLS]) token as the sentence embedding
        # (per the model card — confirm it matches how the corpus embeddings
        # were produced).
        query_embeddings = features[:, 0]
    else:
        query_embeddings = features

    hits = util.semantic_search(query_embeddings, dataset_embeddings, top_k=5)
    best_hit = max(hits[0], key=lambda hit: hit['score'], default=None)
    # "Every hit is below 0.6" is the same as "the best hit is below 0.6".
    if best_hit is None or best_hit['score'] < 0.6:
        return "Content Not related"

    # Row i of dataset_embeddings corresponds to df.columns[i].
    local = df.columns[best_hit['corpus_id']]
    recommended_tags = get_tags_for_local(original_dataset, local)
    return f"Recommended category tags: {recommended_tags}"
# One text input, one label output; each submission is handled by
# gradio_query_interface. The app is launched immediately.
iface = gr.Interface(gradio_query_interface, "text", "label")
iface.launch()