zhichyu commited on
Commit
94806ac
·
1 Parent(s): c3bad71

Added TRACE_MALLOC_DELTA and TRACE_MALLOC_FULL (#3555)

Browse files

### What problem does this PR solve?

Added TRACE_MALLOC_DELTA and TRACE_MALLOC_FULL to debug task_executor.py
heap. Relates to #3518

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Files changed (2) hide show
  1. rag/svr/task_executor.py +32 -1
  2. rag/utils/es_conn.py +2 -0
rag/svr/task_executor.py CHANGED
@@ -22,7 +22,8 @@ import sys
22
  from api.utils.log_utils import initRootLogger
23
 
24
  CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
25
- initRootLogger(f"task_executor_{CONSUMER_NO}")
 
26
  for module in ["pdfminer"]:
27
  module_logger = logging.getLogger(module)
28
  module_logger.setLevel(logging.WARNING)
@@ -44,6 +45,7 @@ from functools import partial
44
  from io import BytesIO
45
  from multiprocessing.context import TimeoutError
46
  from timeit import default_timer as timer
 
47
 
48
  import numpy as np
49
 
@@ -490,14 +492,43 @@ def report_status():
490
  logging.exception("report_status got exception")
491
  time.sleep(30)
492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  def main():
494
  settings.init_settings()
495
  background_thread = threading.Thread(target=report_status)
496
  background_thread.daemon = True
497
  background_thread.start()
498
 
 
 
 
 
 
 
 
499
  while True:
500
  handle_task()
 
 
 
 
 
 
501
 
502
  if __name__ == "__main__":
503
  main()
 
22
  from api.utils.log_utils import initRootLogger
23
 
24
  CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
25
+ CONSUMER_NAME = "task_executor_" + CONSUMER_NO
26
+ initRootLogger(CONSUMER_NAME)
27
  for module in ["pdfminer"]:
28
  module_logger = logging.getLogger(module)
29
  module_logger.setLevel(logging.WARNING)
 
45
  from io import BytesIO
46
  from multiprocessing.context import TimeoutError
47
  from timeit import default_timer as timer
48
+ import tracemalloc
49
 
50
  import numpy as np
51
 
 
492
  logging.exception("report_status got exception")
493
  time.sleep(30)
494
 
495
+ def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
496
+ msg = ""
497
+ if dump_full:
498
+ stats2 = snapshot2.statistics('lineno')
499
+ msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
500
+ for stat in stats2[:10]:
501
+ msg += f"{stat}\n"
502
+ stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
503
+ msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id-1} to snapshot {snapshot_id}:\n"
504
+ for stat in stats1_vs_2[:10]:
505
+ msg += f"{stat}\n"
506
+ msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
507
+ for stat in stats1_vs_2[:3]:
508
+ msg += '\n'.join(stat.traceback.format())
509
+ logging.info(msg)
510
+
511
  def main():
512
  settings.init_settings()
513
  background_thread = threading.Thread(target=report_status)
514
  background_thread.daemon = True
515
  background_thread.start()
516
 
517
+ TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
518
+ TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
519
+ if TRACE_MALLOC_DELTA > 0:
520
+ if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
521
+ TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
522
+ tracemalloc.start()
523
+ snapshot1 = tracemalloc.take_snapshot()
524
  while True:
525
  handle_task()
526
+ num_tasks = DONE_TASKS + FAILED_TASKS
527
+ if TRACE_MALLOC_DELTA> 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
528
+ snapshot2 = tracemalloc.take_snapshot()
529
+ analyze_heap(snapshot1, snapshot2, int(num_tasks/TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
530
+ snapshot1 = snapshot2
531
+ snapshot2 = None
532
 
533
  if __name__ == "__main__":
534
  main()
rag/utils/es_conn.py CHANGED
@@ -237,6 +237,7 @@ class ESConnection(DocStoreConnection):
237
  res = []
238
  for _ in range(ATTEMPT_TIME):
239
  try:
 
240
  r = self.es.bulk(index=(indexName), operations=operations,
241
  refresh=False, timeout="60s")
242
  if re.search(r"False", str(r["errors"]), re.IGNORECASE):
@@ -248,6 +249,7 @@ class ESConnection(DocStoreConnection):
248
  res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
249
  return res
250
  except Exception as e:
 
251
  logging.warning("ESConnection.insert got exception: " + str(e))
252
  res = []
253
  if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
 
237
  res = []
238
  for _ in range(ATTEMPT_TIME):
239
  try:
240
+ res = []
241
  r = self.es.bulk(index=(indexName), operations=operations,
242
  refresh=False, timeout="60s")
243
  if re.search(r"False", str(r["errors"]), re.IGNORECASE):
 
249
  res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
250
  return res
251
  except Exception as e:
252
+ res.append(str(e))
253
  logging.warning("ESConnection.insert got exception: " + str(e))
254
  res = []
255
  if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):