YanshekWoo commited on
Commit
c9d8253
·
verified ·
1 Parent(s): 2b145d6

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -19,7 +19,7 @@ file_example = """Please upload a JSON file with a "text" field (with optional "
19
  {"title": "Title B", "text": "This an example text with the title"},
20
  ]
21
  ```
22
- Due to the computation resources, please test with small scale data.
23
  """
24
 
25
 
@@ -42,6 +42,12 @@ def upload_file_fn(
42
  try:
43
  with open(file_path) as f:
44
  document_data = json.load(f)
 
 
 
 
 
 
45
  documents = []
46
  for obj in document_data:
47
  text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
@@ -55,18 +61,13 @@ def upload_file_fn(
55
  gr.Error(str(e))
56
  return None, gr.update(interactive=False)
57
 
58
- if len(documents) < 3:
59
- gr.Error("Please upload more than 3 documents.")
60
  return None, gr.update(interactive=False)
61
-
62
- gr.Info(f"Upload {len(documents)} documents.")
63
- if len(documents) > 1000:
64
- gr.Info(f"Cut uploaded documents to 1000 due to the computation resource.")
65
- documents = documents[: 1000]
66
 
67
  # documents_embeddings = model.encode(documents, show_progress_bar=True)
68
  documents_embeddings = []
69
- batch_size = 8
70
  for i in tqdm(range(0, len(documents), batch_size)):
71
  batch_documents = documents[i: i+batch_size]
72
  batch_embeddings = model.encode(batch_documents, show_progress_bar=True)
@@ -87,7 +88,7 @@ def clear_file_fn():
87
 
88
 
89
  def retrieve_document_fn(question, document_states, instruct):
90
- num_retrieval_doc = 3
91
 
92
  if document_states is None:
93
  gr.Warning("Please upload documents first!")
@@ -95,11 +96,16 @@ def retrieve_document_fn(question, document_states, instruct):
95
 
96
  document_data, document_index = document_states["document_data"], document_states["document_index"]
97
 
98
- question_embedding = model.encode([str(instruct) + str(question)])
 
 
 
 
 
99
  batch_scores, batch_inxs = document_index.search(question_embedding, k=min(len(document_data), 150))
100
 
101
  answers = [document_data[i]["text"] for i in batch_inxs[0][:num_retrieval_doc]]
102
- return answers[0], answers[1], answers[2], document_states
103
 
104
 
105
  def main(args):
@@ -126,9 +132,10 @@ def main(args):
126
  retrieval_interface = gr.Interface(
127
  fn=retrieve_document_fn,
128
  inputs=[gr.Textbox(label="Query"), document_state],
129
- outputs=[gr.Text(label="Recall-1"), gr.Text(label="Recall-2"), gr.Text(label="Recall-3"), gr.State()],
130
  additional_inputs=[gr.Textbox("Instruct: Given a query, retrieve documents that answer the query. \n Query: ", label="Instruct of Query", lines=2)],
131
  concurrency_limit=1,
 
132
  )
133
  # retrieval_interface.input_components[0] = gr.update(interactive=False)
134
 
@@ -153,7 +160,7 @@ def main(args):
153
  if __name__ == "__main__":
154
  parser = argparse.ArgumentParser()
155
  parser.add_argument("--model_name_or_path", type=str, default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5")
156
- parser.add_argument("--revision", type=str, default="refs/pr/2")
157
 
158
  args = parser.parse_args()
159
  main(args)
 
19
  {"title": "Title B", "text": "This an example text with the title"},
20
  ]
21
  ```
22
+ Due to the computation resources, please test with small scale data (<1000).
23
  """
24
 
25
 
 
42
  try:
43
  with open(file_path) as f:
44
  document_data = json.load(f)
45
+
46
+ gr.Info(f"Upload {len(document_data)} documents.")
47
+ if len(document_data) > 1000:
48
+ gr.Info(f"Cut uploaded documents to 1000 due to the computation resource.")
49
+ document_data = document_data[: 1000]
50
+
51
  documents = []
52
  for obj in document_data:
53
  text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
 
61
  gr.Error(str(e))
62
  return None, gr.update(interactive=False)
63
 
64
+ if len(documents) < 5:
65
+ gr.Error("Please upload more than 53 documents.")
66
  return None, gr.update(interactive=False)
 
 
 
 
 
67
 
68
  # documents_embeddings = model.encode(documents, show_progress_bar=True)
69
  documents_embeddings = []
70
+ batch_size = 16
71
  for i in tqdm(range(0, len(documents), batch_size)):
72
  batch_documents = documents[i: i+batch_size]
73
  batch_embeddings = model.encode(batch_documents, show_progress_bar=True)
 
88
 
89
 
90
  def retrieve_document_fn(question, document_states, instruct):
91
+ num_retrieval_doc = 5
92
 
93
  if document_states is None:
94
  gr.Warning("Please upload documents first!")
 
96
 
97
  document_data, document_index = document_states["document_data"], document_states["document_index"]
98
 
99
+ question_with_inst = str(instruct) + str(question)
100
+ if len(question_with_inst.strip()) == 0:
101
+ gr.Warning("Please enter a non-empty query.")
102
+ return None, None, None, None, None, document_states
103
+
104
+ question_embedding = model.encode([question_with_inst])
105
  batch_scores, batch_inxs = document_index.search(question_embedding, k=min(len(document_data), 150))
106
 
107
  answers = [document_data[i]["text"] for i in batch_inxs[0][:num_retrieval_doc]]
108
+ return answers[0], answers[1], answers[2], answers[3], answers[4],document_states
109
 
110
 
111
  def main(args):
 
132
  retrieval_interface = gr.Interface(
133
  fn=retrieve_document_fn,
134
  inputs=[gr.Textbox(label="Query"), document_state],
135
+ outputs=[gr.Text(label="Recall-1"), gr.Text(label="Recall-2"), gr.Text(label="Recall-3"), gr.Text(label="Recall-4"), gr.Text(label="Recall-5"), gr.State()],
136
  additional_inputs=[gr.Textbox("Instruct: Given a query, retrieve documents that answer the query. \n Query: ", label="Instruct of Query", lines=2)],
137
  concurrency_limit=1,
138
+ allow_flagging="never",
139
  )
140
  # retrieval_interface.input_components[0] = gr.update(interactive=False)
141
 
 
160
  if __name__ == "__main__":
161
  parser = argparse.ArgumentParser()
162
  parser.add_argument("--model_name_or_path", type=str, default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5")
163
+ parser.add_argument("--revision", type=str, default=None)
164
 
165
  args = parser.parse_args()
166
  main(args)