YanshekWoo committed on
Commit
b1f1fd7
·
verified ·
1 Parent(s): 633bc73

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. app.py +128 -0
  2. requirements.txt +3 -0
  3. resources/head.html +36 -0
  4. resources/styles.css +158 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ from pathlib import Path
4
+ from typing import List
5
+
6
+ import gradio as gr
7
+ import faiss
8
+ import numpy as np
9
+ import torch
10
+ from sentence_transformers import SentenceTransformer
11
+
12
+
13
# Help text rendered above the upload box. The embedded sample must itself be
# valid JSON (in particular: no trailing comma after the last object) so that
# users can copy it verbatim into a file that json.load() accepts.
file_example = """Please upload a JSON file with a "text" field (with optional "title" field). For example
```JSON
[
    {"title": "", "text": "This is an example text without the title"},
    {"title": "Title A", "text": "This is an example text with the title"},
    {"title": "Title B", "text": "This is an example text with the title"}
]
```"""
21
+
22
+
23
def create_index(embeddings, use_gpu):
    """Build a flat inner-product FAISS index containing ``embeddings``.

    Args:
        embeddings: Sequence of equal-length vectors. The length of the first
            vector is used as the index dimensionality.
        use_gpu: When true, the CPU index is cloned to all GPUs with the
            ``shard`` and ``useFloat16`` cloner options enabled before the
            vectors are added.

    Returns:
        The FAISS index with every vector added.
    """
    dim = len(embeddings[0])
    vectors = np.asarray(embeddings, dtype=np.float32)
    flat_index = faiss.IndexFlatIP(dim)

    if not use_gpu:
        flat_index.add(vectors)
        return flat_index

    cloner_options = faiss.GpuMultipleClonerOptions()
    cloner_options.shard = True
    cloner_options.useFloat16 = True
    gpu_index = faiss.index_cpu_to_all_gpus(flat_index, co=cloner_options)
    gpu_index.add(vectors)
    return gpu_index
33
+
34
+
35
def upload_file_fn(
        file_path: str,
        progress: gr.Progress = gr.Progress(track_tqdm=True)
):
    """Read an uploaded JSON document file, embed it and build a search index.

    Args:
        file_path: Path of the uploaded JSON file. The file must contain a
            list of objects with a "text" field and an optional "title" field.
        progress: Gradio progress tracker. It is not used directly; it is
            declared with ``track_tqdm=True`` as a default argument.

    Returns:
        A ``(document_index, document_data)`` pair, or ``(None, None)`` (after
        showing a warning) when the file cannot be read, does not have the
        expected structure, or contains no documents.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(file_path, encoding="utf-8") as f:
            document_data = json.load(f)
        # Prepend the title (when present and non-empty) to the text.
        documents = [
            obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
            for obj in document_data
        ]
    except Exception as e:
        # UI boundary: report any read/format problem to the user instead of
        # failing the event handler.
        print(e)
        gr.Warning("Read the file failed. Please check the data format.")
        return None, None

    if not documents:
        # Building an index needs at least one embedding to infer its
        # dimensionality, so an empty upload is rejected up front.
        gr.Warning("The file contains no documents. Please check the data format.")
        return None, None

    documents_embeddings = model.encode(documents)

    document_index = create_index(documents_embeddings, use_gpu=False)

    # Release cached GPU memory held after encoding.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    return document_index, document_data
60
+
61
+
62
def clear_file_fn():
    """Return the empty ``(document_index, document_data)`` state pair."""
    cleared_index = None
    cleared_data = None
    return cleared_index, cleared_data
64
+
65
+
66
def retrieve_document_fn(question, document_data, document_index):
    """Return the texts of the documents most similar to ``question``.

    Args:
        question: Query string to embed and search for.
        document_data: The uploaded documents (list of objects with a "text"
            field), or ``None`` when nothing has been uploaded.
        document_index: FAISS index built over ``document_data``, or ``None``.

    Returns:
        A tuple of exactly three entries, one per output box. Entries are
        ``None`` when no documents are loaded or when fewer than three
        documents are available.
    """
    num_retrieval_doc = 3
    if document_index is None or document_data is None:
        gr.Warning("Please upload documents first!")
        return (None,) * num_retrieval_doc

    question_embedding = model.encode([question])
    _, batch_inxs = document_index.search(question_embedding, k=num_retrieval_doc)

    # FAISS fills missing results with the label -1 when the index holds fewer
    # than k vectors. Used as a Python list index, -1 would silently return
    # the *last* document as a bogus hit, so map those slots to None instead.
    answers = [
        document_data[int(i)]["text"] if i >= 0 else None
        for i in batch_inxs[0][:num_retrieval_doc]
    ]
    return tuple(answers)
77
+
78
+
79
def main(args):
    """Load the embedding model, build the Gradio demo and launch it.

    Args:
        args: Parsed command-line namespace; only ``model_name_or_path`` is
            read.
    """
    # The event handlers read the model through this module-level global.
    global model

    model = SentenceTransformer(args.model_name_or_path)

    # NOTE(review): these State components are created before the Blocks
    # context and are passed to gr.Interface as additional inputs; presumably
    # the Interface attaches them to the layout - confirm before moving them.
    document_index = gr.State()
    document_data = gr.State()

    resources_dir = Path(__file__).parent / "resources"
    with open(resources_dir / "head.html", encoding="utf-8") as html_file:
        head = html_file.read().strip()

    theme = gr.themes.Soft(font="sans-serif").set(
        background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
        background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
    )
    with gr.Blocks(theme=theme,
                   head=head,
                   css=resources_dir / "styles.css",
                   title="KaLM-Embedding",
                   fill_height=True,
                   analytics_enabled=False) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
        # additional_inputs order must match the parameter order of
        # retrieve_document_fn(question, document_data, document_index).
        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=["text"],
            outputs=["text", "text", "text"],
            additional_inputs=[document_data, document_index],
            concurrency_limit=1,
        )

        doc_files_box.upload(
            upload_file_fn,
            [doc_files_box],
            [document_index, document_data],
            queue=True,
            trigger_mode="once"
        )
        # Clearing the file must reset both states. The handler takes no
        # inputs, so it has to be clear_file_fn: upload_file_fn requires a
        # file path and would raise a TypeError when called without one.
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_index, document_data],
            queue=True,
            trigger_mode="once"
        )
    demo.launch()
121
+
122
+
123
if __name__ == "__main__":
    # Command-line entry point: the single option selects the embedding model.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name_or_path",
        type=str,
        default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
    )

    args = cli.parse_args()
    main(args)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers==4.39.1
2
+ sentence-transformers==2.5.1
3
+ faiss-cpu==1.8.0
resources/head.html ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<style>
    /* Highlight colours for the light and dark colour schemes. */
    :root {
        --highlight-background-color-light: #F1EAFF;
        --highlight-background-color-dark: #3E00FF;
    }

    /* Default to the light highlight colour. */
    body {
        --highlight-background-color: var(--highlight-background-color-light);
    }

    /* Switch to the dark highlight colour when the user prefers dark mode. */
    @media (prefers-color-scheme: dark) {
        body {
            --highlight-background-color: var(--highlight-background-color-dark);
        }
    }
</style>

<script>
    // Highlight the <article> targeted by an in-page ("#id") link and remove
    // the highlight from every other article.
    document.addEventListener('click', function(event) {
        // closest() also catches clicks on elements nested inside the <a>
        // (e.g. <a><span>1</span></a>), which a tagName check on the target
        // itself would miss.
        var link = event.target && event.target.closest ? event.target.closest('a') : null;
        if (!link) {
            return;
        }

        var href = link.getAttribute('href');
        if (!href || !href.startsWith('#')) {
            return;
        }

        var targetId = href.substring(1);
        var targetArticle = document.getElementById(targetId);
        var articles = document.getElementsByTagName('article');

        for (var i = 0; i < articles.length; i++) {
            articles[i].style.backgroundColor = '';
        }

        // A bare "#" or a fragment without a matching element yields null;
        // skip the highlight instead of throwing a TypeError.
        if (targetArticle) {
            targetArticle.style.backgroundColor = 'var(--highlight-background-color)';
        }
    });
</script>
resources/styles.css ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
a {
    text-decoration: none!important;
}

article {
    padding: 12px;
}


/* progress bars
   NOTE: the radius was written as "8rpx". "rpx" is not a CSS unit, so
   browsers discarded the declaration; 4px is used as the pixel equivalent
   (assumption: 1rpx = 0.5px on a 375px-wide reference screen - adjust if a
   different radius was intended). */
progress::-webkit-progress-bar {
    border-radius: 4px !important;
    background-color: #f0f0f0;
}
progress::-webkit-progress-value {
    border-radius: 4px !important;
}
/* The -webkit- and -moz- rules are kept separate on purpose: a selector list
   containing an unknown vendor pseudo-element invalidates the whole rule. */
.progress-1 {
    color: #FF004D;
}
.progress-1::-webkit-progress-value {
    background-color:#FF004D;
}
.progress-1::-moz-progress-bar {
    background-color:#FF004D;
}
.progress-2 {
    color: #FF8400;
}
.progress-2::-webkit-progress-value {
    background-color:#FF8400;
}
.progress-2::-moz-progress-bar {
    background-color:#FF8400;
}
.progress-3 {
    color: #0079FF;
}
.progress-3::-webkit-progress-value {
    background-color:#0079FF;
}
.progress-3::-moz-progress-bar {
    background-color:#0079FF;
}

.factual-score {
    width: 20px;
    height: 4px;
}


.hide {
    display: none;
}

/* fixed-height, vertically scrollable text panel */
.html-text {
    white-space: pre-line;
    word-break: normal;
    text-align: justify;
    overflow-y: auto;
    overflow-x: hidden;
    height: 450px;
    padding: 2px;
}

.tab {
    border: none;
    outline: none;
    padding: 0;
    margin: 0;
}


/* html text slider */
::-webkit-scrollbar {
    width: 6px;
    height: 6px;
}
::-webkit-scrollbar-track {
    border-radius: 3px;
    background: rgba(0,0,0,0.06);
    -webkit-box-shadow: inset 0 0 5px rgba(0,0,0,0.08);
}
::-webkit-scrollbar-thumb {
    border-radius: 3px;
    background: rgba(0,0,0,0.12);
    -webkit-box-shadow: inset 0 0 10px rgba(0,0,0,0.2);
}


/* split line style */
.hr-edge-weak {
    border: 0;
    padding-top: 1px;
    background: linear-gradient(to right, transparent, #d0d0d5, transparent);
}

.hr-double-arrow {
    color: #d0d0d5;
    border: double;
    border-width: 3px 5px;
    border-color: #d0d0d5 transparent;
    height: 1px;
    overflow: visible;
    margin-left: 20px;
    margin-right: 20px;
    position: relative;
}
/* arrow heads drawn with rotated, double-bordered pseudo-elements */
.hr-double-arrow:before,
.hr-double-arrow:after {
    content: '';
    position: absolute;
    width: 5px; height: 5px;
    border-width: 0 3px 3px 0;
    border-style: double;
    top: -3px;
    background: radial-gradient(2px at 1px 1px, currentColor 2px, transparent 0) no-repeat;
}
.hr-double-arrow:before {
    transform: rotate(-45deg);
    left: -20px;
}
.hr-double-arrow:after {
    transform: rotate(135deg);
    right: -20px;
}


/* small pill-shaped citation marker */
.citation-button {
    -webkit-tap-highlight-color: rgba(0,0,0,0);
    -webkit-text-size-adjust: 100%;
    tab-size: 4;
    color-scheme: light;
    word-break: break-word;
    white-space: pre-wrap;
    /* multi-word family names are quoted so they are never misparsed */
    font-family: "Open Sans",sans-serif!important;
    box-sizing: border-box;
    border: 0 solid #e5e7eb;
    touch-action: manipulation;
    margin-right: 1px;
    display: inline-flex;
    height: .75rem;
    min-width: .75rem;
    align-items: center;
    justify-content: center;
    border-radius: 9999px;
    background-color: var(--block-label-background-fill);
    padding-left: .25rem;
    padding-right: .25rem;
    vertical-align: top;
    font-size: 8px;
    font-weight: 600;
    line-height: 0;
    color: var(--block-label-text-color);
    outline: none;
    cursor: pointer;
    transition: color 0.3s;
}