| import pyrootutils | |
| root = pyrootutils.setup_root( | |
| search_from=__file__, | |
| indicator=[".project-root"], | |
| pythonpath=True, | |
| dotenv=True, | |
| ) | |
| import argparse | |
| import logging | |
| from demo.model_utils import ( | |
| retrieve_all_relevant_spans, | |
| retrieve_all_relevant_spans_for_all_documents, | |
| retrieve_relevant_spans, | |
| ) | |
| from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations | |
| logger = logging.getLogger(__name__) | |
if __name__ == "__main__":
    # Command-line entry point: load a retriever dump, query it for relevant
    # spans (for a single span, a single document, or all documents) and write
    # the result to a JSON file.

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Validate the output target before doing any expensive work (building the
    # retriever, loading the dump, running the retrieval).
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    # The same search parameters are forwarded to whichever retrieval helper
    # is selected below.
    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    logger.info(f"use search_kwargs: {search_kwargs}")

    # Query scope, most specific first: single span > single document > all
    # documents.
    if args.query_span_id is not None:
        if args.query_doc_id is not None:
            # The span id takes precedence; say so instead of silently
            # dropping the other option.
            logger.warning(
                "both --query_span_id and --query_doc_id were provided; "
                "--query_doc_id is ignored"
            )
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        # Nothing to write: this is a regular, successful outcome.
        logger.warning("no relevant spans found in any document")
        # Raise SystemExit directly rather than calling the `exit()` helper,
        # which is injected by the `site` module for interactive use and is
        # not guaranteed to exist (e.g. `python -S`, frozen interpreters).
        raise SystemExit(0)

    logger.info(f"dumping results to {args.output_path}...")
    # NOTE(review): the result object is expected to expose `to_json(path)`;
    # its concrete type is defined in demo.model_utils — confirm there.
    all_spans_for_all_documents.to_json(args.output_path)

    logger.info("done")