Upload DatasetTransformer
Browse files- README.md +369 -369
- config.json +7 -1
- misc.py +514 -0
- model.py +607 -0
- model.safetensors +2 -2
README.md
CHANGED
|
@@ -4,12 +4,14 @@ tags:
|
|
| 4 |
model-index:
|
| 5 |
- name: cde-small-v1
|
| 6 |
results:
|
| 7 |
-
-
|
| 8 |
-
|
|
|
|
| 9 |
name: MTEB AmazonCounterfactualClassification (en)
|
| 10 |
-
revision: e8379541af4e31359cca9fbcf4b00f2671dba205
|
| 11 |
-
split: test
|
| 12 |
type: mteb/amazon_counterfactual
|
|
|
|
|
|
|
|
|
|
| 13 |
metrics:
|
| 14 |
- type: accuracy
|
| 15 |
value: 87.01492537313433
|
|
@@ -23,14 +25,14 @@ model-index:
|
|
| 23 |
value: 87.74802754480477
|
| 24 |
- type: main_score
|
| 25 |
value: 87.01492537313433
|
| 26 |
-
|
| 27 |
type: Classification
|
| 28 |
-
|
| 29 |
-
config: default
|
| 30 |
name: MTEB AmazonPolarityClassification (default)
|
| 31 |
-
revision: e2d317d38cd51312af73b3d32a06d1a08b442046
|
| 32 |
-
split: test
|
| 33 |
type: mteb/amazon_polarity
|
|
|
|
|
|
|
|
|
|
| 34 |
metrics:
|
| 35 |
- type: accuracy
|
| 36 |
value: 94.652275
|
|
@@ -44,14 +46,14 @@ model-index:
|
|
| 44 |
value: 94.64655930708355
|
| 45 |
- type: main_score
|
| 46 |
value: 94.652275
|
| 47 |
-
|
| 48 |
type: Classification
|
| 49 |
-
|
| 50 |
-
config: en
|
| 51 |
name: MTEB AmazonReviewsClassification (en)
|
| 52 |
-
revision: 1399c76144fd37290681b995c656ef9b2e06e26d
|
| 53 |
-
split: test
|
| 54 |
type: mteb/amazon_reviews_multi
|
|
|
|
|
|
|
|
|
|
| 55 |
metrics:
|
| 56 |
- type: accuracy
|
| 57 |
value: 55.75599999999999
|
|
@@ -61,14 +63,14 @@ model-index:
|
|
| 61 |
value: 55.07058630829347
|
| 62 |
- type: main_score
|
| 63 |
value: 55.75599999999999
|
| 64 |
-
|
| 65 |
-
type:
|
| 66 |
-
|
| 67 |
-
config: default
|
| 68 |
name: MTEB ArguAna (default)
|
| 69 |
-
revision: c22ab2a51041ffd869aaddef7af8d8215647e41a
|
| 70 |
-
split: test
|
| 71 |
type: mteb/arguana
|
|
|
|
|
|
|
|
|
|
| 72 |
metrics:
|
| 73 |
- type: main_score
|
| 74 |
value: 69.959
|
|
@@ -352,14 +354,14 @@ model-index:
|
|
| 352 |
value: 74.182
|
| 353 |
- type: recall_at_5
|
| 354 |
value: 84.495
|
| 355 |
-
|
| 356 |
-
type:
|
| 357 |
-
|
| 358 |
-
config: default
|
| 359 |
name: MTEB ArxivClusteringP2P (default)
|
| 360 |
-
revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
|
| 361 |
-
split: test
|
| 362 |
type: mteb/arxiv-clustering-p2p
|
|
|
|
|
|
|
|
|
|
| 363 |
metrics:
|
| 364 |
- type: main_score
|
| 365 |
value: 48.54672141116669
|
|
@@ -367,14 +369,14 @@ model-index:
|
|
| 367 |
value: 48.54672141116669
|
| 368 |
- type: v_measure_std
|
| 369 |
value: 14.037498386768362
|
| 370 |
-
|
| 371 |
type: Clustering
|
| 372 |
-
|
| 373 |
-
config: default
|
| 374 |
name: MTEB ArxivClusteringS2S (default)
|
| 375 |
-
revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
|
| 376 |
-
split: test
|
| 377 |
type: mteb/arxiv-clustering-s2s
|
|
|
|
|
|
|
|
|
|
| 378 |
metrics:
|
| 379 |
- type: main_score
|
| 380 |
value: 40.5914039166466
|
|
@@ -382,14 +384,14 @@ model-index:
|
|
| 382 |
value: 40.5914039166466
|
| 383 |
- type: v_measure_std
|
| 384 |
value: 14.385069818910331
|
| 385 |
-
|
| 386 |
-
type:
|
| 387 |
-
|
| 388 |
-
config: default
|
| 389 |
name: MTEB AskUbuntuDupQuestions (default)
|
| 390 |
-
revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
|
| 391 |
-
split: test
|
| 392 |
type: mteb/askubuntudupquestions-reranking
|
|
|
|
|
|
|
|
|
|
| 393 |
metrics:
|
| 394 |
- type: main_score
|
| 395 |
value: 61.13621260261507
|
|
@@ -409,14 +411,14 @@ model-index:
|
|
| 409 |
value: 31.484257486448364
|
| 410 |
- type: nAUC_mrr_std
|
| 411 |
value: 21.252659250011632
|
| 412 |
-
|
| 413 |
-
type:
|
| 414 |
-
|
| 415 |
-
config: default
|
| 416 |
name: MTEB BIOSSES (default)
|
| 417 |
-
revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
|
| 418 |
-
split: test
|
| 419 |
type: mteb/biosses-sts
|
|
|
|
|
|
|
|
|
|
| 420 |
metrics:
|
| 421 |
- type: cosine_pearson
|
| 422 |
value: 89.07028016646942
|
|
@@ -436,14 +438,14 @@ model-index:
|
|
| 436 |
value: 89.07028016646942
|
| 437 |
- type: spearman
|
| 438 |
value: 86.69595132967805
|
| 439 |
-
|
| 440 |
-
type:
|
| 441 |
-
|
| 442 |
-
config: default
|
| 443 |
name: MTEB Banking77Classification (default)
|
| 444 |
-
revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
|
| 445 |
-
split: test
|
| 446 |
type: mteb/banking77
|
|
|
|
|
|
|
|
|
|
| 447 |
metrics:
|
| 448 |
- type: accuracy
|
| 449 |
value: 88.6038961038961
|
|
@@ -453,14 +455,14 @@ model-index:
|
|
| 453 |
value: 88.56824205739822
|
| 454 |
- type: main_score
|
| 455 |
value: 88.6038961038961
|
| 456 |
-
|
| 457 |
-
type:
|
| 458 |
-
|
| 459 |
-
config: default
|
| 460 |
name: MTEB BiorxivClusteringP2P (default)
|
| 461 |
-
revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
|
| 462 |
-
split: test
|
| 463 |
type: mteb/biorxiv-clustering-p2p
|
|
|
|
|
|
|
|
|
|
| 464 |
metrics:
|
| 465 |
- type: main_score
|
| 466 |
value: 44.77800814327256
|
|
@@ -468,14 +470,14 @@ model-index:
|
|
| 468 |
value: 44.77800814327256
|
| 469 |
- type: v_measure_std
|
| 470 |
value: 0.6462535527471919
|
| 471 |
-
|
| 472 |
type: Clustering
|
| 473 |
-
|
| 474 |
-
config: default
|
| 475 |
name: MTEB BiorxivClusteringS2S (default)
|
| 476 |
-
revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
|
| 477 |
-
split: test
|
| 478 |
type: mteb/biorxiv-clustering-s2s
|
|
|
|
|
|
|
|
|
|
| 479 |
metrics:
|
| 480 |
- type: main_score
|
| 481 |
value: 38.16110272459102
|
|
@@ -483,14 +485,14 @@ model-index:
|
|
| 483 |
value: 38.16110272459102
|
| 484 |
- type: v_measure_std
|
| 485 |
value: 0.7456916212435019
|
| 486 |
-
|
| 487 |
-
type:
|
| 488 |
-
|
| 489 |
-
config: default
|
| 490 |
name: MTEB CQADupstackAndroidRetrieval (default)
|
| 491 |
-
revision: f46a197baaae43b4f621051089b82a364682dfeb
|
| 492 |
-
split: test
|
| 493 |
type: mteb/cqadupstack-android
|
|
|
|
|
|
|
|
|
|
| 494 |
metrics:
|
| 495 |
- type: main_score
|
| 496 |
value: 49.376
|
|
@@ -774,14 +776,14 @@ model-index:
|
|
| 774 |
value: 47.591
|
| 775 |
- type: recall_at_5
|
| 776 |
value: 54.245
|
| 777 |
-
|
| 778 |
type: Retrieval
|
| 779 |
-
|
| 780 |
-
config: default
|
| 781 |
name: MTEB CQADupstackEnglishRetrieval (default)
|
| 782 |
-
revision: ad9991cb51e31e31e430383c75ffb2885547b5f0
|
| 783 |
-
split: test
|
| 784 |
type: mteb/cqadupstack-english
|
|
|
|
|
|
|
|
|
|
| 785 |
metrics:
|
| 786 |
- type: main_score
|
| 787 |
value: 44.727
|
|
@@ -1065,14 +1067,14 @@ model-index:
|
|
| 1065 |
value: 42.085
|
| 1066 |
- type: recall_at_5
|
| 1067 |
value: 47.5
|
| 1068 |
-
|
| 1069 |
type: Retrieval
|
| 1070 |
-
|
| 1071 |
-
config: default
|
| 1072 |
name: MTEB CQADupstackGamingRetrieval (default)
|
| 1073 |
-
revision: 4885aa143210c98657558c04aaf3dc47cfb54340
|
| 1074 |
-
split: test
|
| 1075 |
type: mteb/cqadupstack-gaming
|
|
|
|
|
|
|
|
|
|
| 1076 |
metrics:
|
| 1077 |
- type: main_score
|
| 1078 |
value: 59.001999999999995
|
|
@@ -1356,14 +1358,14 @@ model-index:
|
|
| 1356 |
value: 57.916000000000004
|
| 1357 |
- type: recall_at_5
|
| 1358 |
value: 65.44
|
| 1359 |
-
|
| 1360 |
type: Retrieval
|
| 1361 |
-
|
| 1362 |
-
config: default
|
| 1363 |
name: MTEB CQADupstackGisRetrieval (default)
|
| 1364 |
-
revision: 5003b3064772da1887988e05400cf3806fe491f2
|
| 1365 |
-
split: test
|
| 1366 |
type: mteb/cqadupstack-gis
|
|
|
|
|
|
|
|
|
|
| 1367 |
metrics:
|
| 1368 |
- type: main_score
|
| 1369 |
value: 37.501
|
|
@@ -1647,14 +1649,14 @@ model-index:
|
|
| 1647 |
value: 37.218
|
| 1648 |
- type: recall_at_5
|
| 1649 |
value: 42.559000000000005
|
| 1650 |
-
|
| 1651 |
type: Retrieval
|
| 1652 |
-
|
| 1653 |
-
config: default
|
| 1654 |
name: MTEB CQADupstackMathematicaRetrieval (default)
|
| 1655 |
-
revision: 90fceea13679c63fe563ded68f3b6f06e50061de
|
| 1656 |
-
split: test
|
| 1657 |
type: mteb/cqadupstack-mathematica
|
|
|
|
|
|
|
|
|
|
| 1658 |
metrics:
|
| 1659 |
- type: main_score
|
| 1660 |
value: 27.653
|
|
@@ -1938,14 +1940,14 @@ model-index:
|
|
| 1938 |
value: 25.469
|
| 1939 |
- type: recall_at_5
|
| 1940 |
value: 31.316
|
| 1941 |
-
|
| 1942 |
type: Retrieval
|
| 1943 |
-
|
| 1944 |
-
config: default
|
| 1945 |
name: MTEB CQADupstackPhysicsRetrieval (default)
|
| 1946 |
-
revision: 79531abbd1fb92d06c6d6315a0cbbbf5bb247ea4
|
| 1947 |
-
split: test
|
| 1948 |
type: mteb/cqadupstack-physics
|
|
|
|
|
|
|
|
|
|
| 1949 |
metrics:
|
| 1950 |
- type: main_score
|
| 1951 |
value: 45.314
|
|
@@ -2229,14 +2231,14 @@ model-index:
|
|
| 2229 |
value: 43.679
|
| 2230 |
- type: recall_at_5
|
| 2231 |
value: 49.735
|
| 2232 |
-
|
| 2233 |
type: Retrieval
|
| 2234 |
-
|
| 2235 |
-
config: default
|
| 2236 |
name: MTEB CQADupstackProgrammersRetrieval (default)
|
| 2237 |
-
revision: 6184bc1440d2dbc7612be22b50686b8826d22b32
|
| 2238 |
-
split: test
|
| 2239 |
type: mteb/cqadupstack-programmers
|
|
|
|
|
|
|
|
|
|
| 2240 |
metrics:
|
| 2241 |
- type: main_score
|
| 2242 |
value: 41.972
|
|
@@ -2520,27 +2522,27 @@ model-index:
|
|
| 2520 |
value: 39.363
|
| 2521 |
- type: recall_at_5
|
| 2522 |
value: 44.665
|
| 2523 |
-
|
| 2524 |
type: Retrieval
|
| 2525 |
-
|
| 2526 |
-
config: default
|
| 2527 |
name: MTEB CQADupstackRetrieval (default)
|
| 2528 |
-
revision: CQADupstackRetrieval_is_a_combined_dataset
|
| 2529 |
-
split: test
|
| 2530 |
type: CQADupstackRetrieval_is_a_combined_dataset
|
|
|
|
|
|
|
|
|
|
| 2531 |
metrics:
|
| 2532 |
- type: main_score
|
| 2533 |
value: 39.823499999999996
|
| 2534 |
- type: ndcg_at_10
|
| 2535 |
value: 39.823499999999996
|
| 2536 |
-
|
| 2537 |
type: Retrieval
|
| 2538 |
-
|
| 2539 |
-
config: default
|
| 2540 |
name: MTEB CQADupstackStatsRetrieval (default)
|
| 2541 |
-
revision: 65ac3a16b8e91f9cee4c9828cc7c335575432a2a
|
| 2542 |
-
split: test
|
| 2543 |
type: mteb/cqadupstack-stats
|
|
|
|
|
|
|
|
|
|
| 2544 |
metrics:
|
| 2545 |
- type: main_score
|
| 2546 |
value: 34.943000000000005
|
|
@@ -2824,14 +2826,14 @@ model-index:
|
|
| 2824 |
value: 33.427
|
| 2825 |
- type: recall_at_5
|
| 2826 |
value: 37.643
|
| 2827 |
-
|
| 2828 |
type: Retrieval
|
| 2829 |
-
|
| 2830 |
-
config: default
|
| 2831 |
name: MTEB CQADupstackTexRetrieval (default)
|
| 2832 |
-
revision: 46989137a86843e03a6195de44b09deda022eec7
|
| 2833 |
-
split: test
|
| 2834 |
type: mteb/cqadupstack-tex
|
|
|
|
|
|
|
|
|
|
| 2835 |
metrics:
|
| 2836 |
- type: main_score
|
| 2837 |
value: 27.271
|
|
@@ -3115,14 +3117,14 @@ model-index:
|
|
| 3115 |
value: 25.592
|
| 3116 |
- type: recall_at_5
|
| 3117 |
value: 30.279
|
| 3118 |
-
|
| 3119 |
type: Retrieval
|
| 3120 |
-
|
| 3121 |
-
config: default
|
| 3122 |
name: MTEB CQADupstackUnixRetrieval (default)
|
| 3123 |
-
revision: 6c6430d3a6d36f8d2a829195bc5dc94d7e063e53
|
| 3124 |
-
split: test
|
| 3125 |
type: mteb/cqadupstack-unix
|
|
|
|
|
|
|
|
|
|
| 3126 |
metrics:
|
| 3127 |
- type: main_score
|
| 3128 |
value: 38.237
|
|
@@ -3406,14 +3408,14 @@ model-index:
|
|
| 3406 |
value: 36.275
|
| 3407 |
- type: recall_at_5
|
| 3408 |
value: 42.199
|
| 3409 |
-
|
| 3410 |
type: Retrieval
|
| 3411 |
-
|
| 3412 |
-
config: default
|
| 3413 |
name: MTEB CQADupstackWebmastersRetrieval (default)
|
| 3414 |
-
revision: 160c094312a0e1facb97e55eeddb698c0abe3571
|
| 3415 |
-
split: test
|
| 3416 |
type: mteb/cqadupstack-webmasters
|
|
|
|
|
|
|
|
|
|
| 3417 |
metrics:
|
| 3418 |
- type: main_score
|
| 3419 |
value: 38.702
|
|
@@ -3697,14 +3699,14 @@ model-index:
|
|
| 3697 |
value: 37.634
|
| 3698 |
- type: recall_at_5
|
| 3699 |
value: 42.021
|
| 3700 |
-
|
| 3701 |
type: Retrieval
|
| 3702 |
-
|
| 3703 |
-
config: default
|
| 3704 |
name: MTEB CQADupstackWordpressRetrieval (default)
|
| 3705 |
-
revision: 4ffe81d471b1924886b33c7567bfb200e9eec5c4
|
| 3706 |
-
split: test
|
| 3707 |
type: mteb/cqadupstack-wordpress
|
|
|
|
|
|
|
|
|
|
| 3708 |
metrics:
|
| 3709 |
- type: main_score
|
| 3710 |
value: 33.184000000000005
|
|
@@ -3988,14 +3990,14 @@ model-index:
|
|
| 3988 |
value: 32.683
|
| 3989 |
- type: recall_at_5
|
| 3990 |
value: 36.756
|
| 3991 |
-
|
| 3992 |
type: Retrieval
|
| 3993 |
-
|
| 3994 |
-
config: default
|
| 3995 |
name: MTEB ClimateFEVER (default)
|
| 3996 |
-
revision: 47f2ac6acb640fc46020b02a5b59fdda04d39380
|
| 3997 |
-
split: test
|
| 3998 |
type: mteb/climate-fever
|
|
|
|
|
|
|
|
|
|
| 3999 |
metrics:
|
| 4000 |
- type: main_score
|
| 4001 |
value: 25.068
|
|
@@ -4279,14 +4281,14 @@ model-index:
|
|
| 4279 |
value: 18.312
|
| 4280 |
- type: recall_at_5
|
| 4281 |
value: 22.776
|
| 4282 |
-
|
| 4283 |
type: Retrieval
|
| 4284 |
-
|
| 4285 |
-
config: default
|
| 4286 |
name: MTEB DBPedia (default)
|
| 4287 |
-
revision: c0f706b76e590d620bd6618b3ca8efdd34e2d659
|
| 4288 |
-
split: test
|
| 4289 |
type: mteb/dbpedia
|
|
|
|
|
|
|
|
|
|
| 4290 |
metrics:
|
| 4291 |
- type: main_score
|
| 4292 |
value: 40.128
|
|
@@ -4570,14 +4572,14 @@ model-index:
|
|
| 4570 |
value: 14.562
|
| 4571 |
- type: recall_at_5
|
| 4572 |
value: 18.779
|
| 4573 |
-
|
| 4574 |
-
type:
|
| 4575 |
-
|
| 4576 |
-
config: default
|
| 4577 |
name: MTEB EmotionClassification (default)
|
| 4578 |
-
revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
|
| 4579 |
-
split: test
|
| 4580 |
type: mteb/emotion
|
|
|
|
|
|
|
|
|
|
| 4581 |
metrics:
|
| 4582 |
- type: accuracy
|
| 4583 |
value: 74.86
|
|
@@ -4587,14 +4589,14 @@ model-index:
|
|
| 4587 |
value: 75.96499621761998
|
| 4588 |
- type: main_score
|
| 4589 |
value: 74.86
|
| 4590 |
-
|
| 4591 |
-
type:
|
| 4592 |
-
|
| 4593 |
-
config: default
|
| 4594 |
name: MTEB FEVER (default)
|
| 4595 |
-
revision: bea83ef9e8fb933d90a2f1d5515737465d613e12
|
| 4596 |
-
split: test
|
| 4597 |
type: mteb/fever
|
|
|
|
|
|
|
|
|
|
| 4598 |
metrics:
|
| 4599 |
- type: main_score
|
| 4600 |
value: 86.029
|
|
@@ -4878,14 +4880,14 @@ model-index:
|
|
| 4878 |
value: 88.382
|
| 4879 |
- type: recall_at_5
|
| 4880 |
value: 90.908
|
| 4881 |
-
|
| 4882 |
type: Retrieval
|
| 4883 |
-
|
| 4884 |
-
config: default
|
| 4885 |
name: MTEB FiQA2018 (default)
|
| 4886 |
-
revision: 27a168819829fe9bcd655c2df245fb19452e8e06
|
| 4887 |
-
split: test
|
| 4888 |
type: mteb/fiqa
|
|
|
|
|
|
|
|
|
|
| 4889 |
metrics:
|
| 4890 |
- type: main_score
|
| 4891 |
value: 45.238
|
|
@@ -5169,14 +5171,14 @@ model-index:
|
|
| 5169 |
value: 37.656
|
| 5170 |
- type: recall_at_5
|
| 5171 |
value: 44.766
|
| 5172 |
-
|
| 5173 |
type: Retrieval
|
| 5174 |
-
|
| 5175 |
-
config: default
|
| 5176 |
name: MTEB HotpotQA (default)
|
| 5177 |
-
revision: ab518f4d6fcca38d87c25209f94beba119d02014
|
| 5178 |
-
split: test
|
| 5179 |
type: mteb/hotpotqa
|
|
|
|
|
|
|
|
|
|
| 5180 |
metrics:
|
| 5181 |
- type: main_score
|
| 5182 |
value: 66.672
|
|
@@ -5460,14 +5462,14 @@ model-index:
|
|
| 5460 |
value: 57.522
|
| 5461 |
- type: recall_at_5
|
| 5462 |
value: 62.134
|
| 5463 |
-
|
| 5464 |
-
type:
|
| 5465 |
-
|
| 5466 |
-
config: default
|
| 5467 |
name: MTEB ImdbClassification (default)
|
| 5468 |
-
revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
|
| 5469 |
-
split: test
|
| 5470 |
type: mteb/imdb
|
|
|
|
|
|
|
|
|
|
| 5471 |
metrics:
|
| 5472 |
- type: accuracy
|
| 5473 |
value: 93.5944
|
|
@@ -5481,14 +5483,14 @@ model-index:
|
|
| 5481 |
value: 93.58945949328377
|
| 5482 |
- type: main_score
|
| 5483 |
value: 93.5944
|
| 5484 |
-
|
| 5485 |
-
type:
|
| 5486 |
-
|
| 5487 |
-
config: default
|
| 5488 |
name: MTEB MSMARCO (default)
|
| 5489 |
-
revision: c5a29a104738b98a9e76336939199e264163d4a0
|
| 5490 |
-
split: dev
|
| 5491 |
type: mteb/msmarco
|
|
|
|
|
|
|
|
|
|
| 5492 |
metrics:
|
| 5493 |
- type: main_score
|
| 5494 |
value: 41.448
|
|
@@ -5772,14 +5774,14 @@ model-index:
|
|
| 5772 |
value: 41.304
|
| 5773 |
- type: recall_at_5
|
| 5774 |
value: 51.076
|
| 5775 |
-
|
| 5776 |
-
type:
|
| 5777 |
-
|
| 5778 |
-
config: en
|
| 5779 |
name: MTEB MTOPDomainClassification (en)
|
| 5780 |
-
revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
|
| 5781 |
-
split: test
|
| 5782 |
type: mteb/mtop_domain
|
|
|
|
|
|
|
|
|
|
| 5783 |
metrics:
|
| 5784 |
- type: accuracy
|
| 5785 |
value: 96.03967168262655
|
|
@@ -5789,14 +5791,14 @@ model-index:
|
|
| 5789 |
value: 96.06623245823347
|
| 5790 |
- type: main_score
|
| 5791 |
value: 96.03967168262655
|
| 5792 |
-
|
| 5793 |
type: Classification
|
| 5794 |
-
|
| 5795 |
-
config: en
|
| 5796 |
name: MTEB MTOPIntentClassification (en)
|
| 5797 |
-
revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
|
| 5798 |
-
split: test
|
| 5799 |
type: mteb/mtop_intent
|
|
|
|
|
|
|
|
|
|
| 5800 |
metrics:
|
| 5801 |
- type: accuracy
|
| 5802 |
value: 89.12904696762428
|
|
@@ -5806,14 +5808,14 @@ model-index:
|
|
| 5806 |
value: 90.41290566743324
|
| 5807 |
- type: main_score
|
| 5808 |
value: 89.12904696762428
|
| 5809 |
-
|
| 5810 |
type: Classification
|
| 5811 |
-
|
| 5812 |
-
config: en
|
| 5813 |
name: MTEB MassiveIntentClassification (en)
|
| 5814 |
-
revision: 4672e20407010da34463acc759c162ca9734bca6
|
| 5815 |
-
split: test
|
| 5816 |
type: mteb/amazon_massive_intent
|
|
|
|
|
|
|
|
|
|
| 5817 |
metrics:
|
| 5818 |
- type: accuracy
|
| 5819 |
value: 76.49630127774041
|
|
@@ -5823,14 +5825,14 @@ model-index:
|
|
| 5823 |
value: 76.42436195016484
|
| 5824 |
- type: main_score
|
| 5825 |
value: 76.49630127774041
|
| 5826 |
-
|
| 5827 |
type: Classification
|
| 5828 |
-
|
| 5829 |
-
config: en
|
| 5830 |
name: MTEB MassiveScenarioClassification (en)
|
| 5831 |
-
revision: fad2c6e8459f9e1c45d9315f4953d921437d70f8
|
| 5832 |
-
split: test
|
| 5833 |
type: mteb/amazon_massive_scenario
|
|
|
|
|
|
|
|
|
|
| 5834 |
metrics:
|
| 5835 |
- type: accuracy
|
| 5836 |
value: 78.9340954942838
|
|
@@ -5840,14 +5842,14 @@ model-index:
|
|
| 5840 |
value: 78.87787647838971
|
| 5841 |
- type: main_score
|
| 5842 |
value: 78.9340954942838
|
| 5843 |
-
|
| 5844 |
-
type:
|
| 5845 |
-
|
| 5846 |
-
config: default
|
| 5847 |
name: MTEB MedrxivClusteringP2P (default)
|
| 5848 |
-
revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
|
| 5849 |
-
split: test
|
| 5850 |
type: mteb/medrxiv-clustering-p2p
|
|
|
|
|
|
|
|
|
|
| 5851 |
metrics:
|
| 5852 |
- type: main_score
|
| 5853 |
value: 37.50182848656019
|
|
@@ -5855,14 +5857,14 @@ model-index:
|
|
| 5855 |
value: 37.50182848656019
|
| 5856 |
- type: v_measure_std
|
| 5857 |
value: 1.1708518023877268
|
| 5858 |
-
|
| 5859 |
type: Clustering
|
| 5860 |
-
|
| 5861 |
-
config: default
|
| 5862 |
name: MTEB MedrxivClusteringS2S (default)
|
| 5863 |
-
revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
|
| 5864 |
-
split: test
|
| 5865 |
type: mteb/medrxiv-clustering-s2s
|
|
|
|
|
|
|
|
|
|
| 5866 |
metrics:
|
| 5867 |
- type: main_score
|
| 5868 |
value: 35.72762609825363
|
|
@@ -5870,14 +5872,14 @@ model-index:
|
|
| 5870 |
value: 35.72762609825363
|
| 5871 |
- type: v_measure_std
|
| 5872 |
value: 1.4555014772914985
|
| 5873 |
-
|
| 5874 |
-
type:
|
| 5875 |
-
|
| 5876 |
-
config: default
|
| 5877 |
name: MTEB MindSmallReranking (default)
|
| 5878 |
-
revision: 59042f120c80e8afa9cdbb224f67076cec0fc9a7
|
| 5879 |
-
split: test
|
| 5880 |
type: mteb/mind_small
|
|
|
|
|
|
|
|
|
|
| 5881 |
metrics:
|
| 5882 |
- type: main_score
|
| 5883 |
value: 30.47716416454022
|
|
@@ -5897,14 +5899,14 @@ model-index:
|
|
| 5897 |
value: -15.78941850629242
|
| 5898 |
- type: nAUC_mrr_std
|
| 5899 |
value: -1.1330442292510805
|
| 5900 |
-
|
| 5901 |
-
type:
|
| 5902 |
-
|
| 5903 |
-
config: default
|
| 5904 |
name: MTEB NFCorpus (default)
|
| 5905 |
-
revision: ec0fa4fe99da2ff19ca1214b7966684033a58814
|
| 5906 |
-
split: test
|
| 5907 |
type: mteb/nfcorpus
|
|
|
|
|
|
|
|
|
|
| 5908 |
metrics:
|
| 5909 |
- type: main_score
|
| 5910 |
value: 34.648
|
|
@@ -6188,14 +6190,14 @@ model-index:
|
|
| 6188 |
value: 10.037
|
| 6189 |
- type: recall_at_5
|
| 6190 |
value: 12.717999999999998
|
| 6191 |
-
|
| 6192 |
type: Retrieval
|
| 6193 |
-
|
| 6194 |
-
config: default
|
| 6195 |
name: MTEB NQ (default)
|
| 6196 |
-
revision: b774495ed302d8c44a3a7ea25c90dbce03968f31
|
| 6197 |
-
split: test
|
| 6198 |
type: mteb/nq
|
|
|
|
|
|
|
|
|
|
| 6199 |
metrics:
|
| 6200 |
- type: main_score
|
| 6201 |
value: 60.06
|
|
@@ -6479,14 +6481,14 @@ model-index:
|
|
| 6479 |
value: 61.114000000000004
|
| 6480 |
- type: recall_at_5
|
| 6481 |
value: 69.812
|
| 6482 |
-
|
| 6483 |
type: Retrieval
|
| 6484 |
-
|
| 6485 |
-
config: default
|
| 6486 |
name: MTEB QuoraRetrieval (default)
|
| 6487 |
-
revision: e4e08e0b7dbe3c8700f0daef558ff32256715259
|
| 6488 |
-
split: test
|
| 6489 |
type: mteb/quora
|
|
|
|
|
|
|
|
|
|
| 6490 |
metrics:
|
| 6491 |
- type: main_score
|
| 6492 |
value: 89.821
|
|
@@ -6770,14 +6772,14 @@ model-index:
|
|
| 6770 |
value: 88.714
|
| 6771 |
- type: recall_at_5
|
| 6772 |
value: 92.96799999999999
|
| 6773 |
-
|
| 6774 |
-
type:
|
| 6775 |
-
|
| 6776 |
-
config: default
|
| 6777 |
name: MTEB RedditClustering (default)
|
| 6778 |
-
revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
|
| 6779 |
-
split: test
|
| 6780 |
type: mteb/reddit-clustering
|
|
|
|
|
|
|
|
|
|
| 6781 |
metrics:
|
| 6782 |
- type: main_score
|
| 6783 |
value: 59.36038828851887
|
|
@@ -6785,14 +6787,14 @@ model-index:
|
|
| 6785 |
value: 59.36038828851887
|
| 6786 |
- type: v_measure_std
|
| 6787 |
value: 4.1958765965154425
|
| 6788 |
-
|
| 6789 |
type: Clustering
|
| 6790 |
-
|
| 6791 |
-
config: default
|
| 6792 |
name: MTEB RedditClusteringP2P (default)
|
| 6793 |
-
revision: 385e3cb46b4cfa89021f56c4380204149d0efe33
|
| 6794 |
-
split: test
|
| 6795 |
type: mteb/reddit-clustering-p2p
|
|
|
|
|
|
|
|
|
|
| 6796 |
metrics:
|
| 6797 |
- type: main_score
|
| 6798 |
value: 64.67522832408089
|
|
@@ -6800,14 +6802,14 @@ model-index:
|
|
| 6800 |
value: 64.67522832408089
|
| 6801 |
- type: v_measure_std
|
| 6802 |
value: 12.473765016158698
|
| 6803 |
-
|
| 6804 |
-
type:
|
| 6805 |
-
|
| 6806 |
-
config: default
|
| 6807 |
name: MTEB SCIDOCS (default)
|
| 6808 |
-
revision: f8c2fcf00f625baaa80f62ec5bd9e1fff3b8ae88
|
| 6809 |
-
split: test
|
| 6810 |
type: mteb/scidocs
|
|
|
|
|
|
|
|
|
|
| 6811 |
metrics:
|
| 6812 |
- type: main_score
|
| 6813 |
value: 21.751
|
|
@@ -7091,14 +7093,14 @@ model-index:
|
|
| 7091 |
value: 11.648
|
| 7092 |
- type: recall_at_5
|
| 7093 |
value: 15.883
|
| 7094 |
-
|
| 7095 |
-
type:
|
| 7096 |
-
|
| 7097 |
-
config: default
|
| 7098 |
name: MTEB SICK-R (default)
|
| 7099 |
-
revision: 20a6d6f312dd54037fe07a32d58e5e168867909d
|
| 7100 |
-
split: test
|
| 7101 |
type: mteb/sickr-sts
|
|
|
|
|
|
|
|
|
|
| 7102 |
metrics:
|
| 7103 |
- type: cosine_pearson
|
| 7104 |
value: 84.0161170579997
|
|
@@ -7118,14 +7120,14 @@ model-index:
|
|
| 7118 |
value: 84.0161170579997
|
| 7119 |
- type: spearman
|
| 7120 |
value: 77.52025923874551
|
| 7121 |
-
|
| 7122 |
type: STS
|
| 7123 |
-
|
| 7124 |
-
config: default
|
| 7125 |
name: MTEB STS12 (default)
|
| 7126 |
-
revision: a0d554a64d88156834ff5ae9920b964011b16384
|
| 7127 |
-
split: test
|
| 7128 |
type: mteb/sts12-sts
|
|
|
|
|
|
|
|
|
|
| 7129 |
metrics:
|
| 7130 |
- type: cosine_pearson
|
| 7131 |
value: 81.32328780209225
|
|
@@ -7145,14 +7147,14 @@ model-index:
|
|
| 7145 |
value: 81.32328780209225
|
| 7146 |
- type: spearman
|
| 7147 |
value: 74.17570679745272
|
| 7148 |
-
|
| 7149 |
type: STS
|
| 7150 |
-
|
| 7151 |
-
config: default
|
| 7152 |
name: MTEB STS13 (default)
|
| 7153 |
-
revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
|
| 7154 |
-
split: test
|
| 7155 |
type: mteb/sts13-sts
|
|
|
|
|
|
|
|
|
|
| 7156 |
metrics:
|
| 7157 |
- type: cosine_pearson
|
| 7158 |
value: 85.53224141249392
|
|
@@ -7172,14 +7174,14 @@ model-index:
|
|
| 7172 |
value: 85.53224141249392
|
| 7173 |
- type: spearman
|
| 7174 |
value: 86.16981525069227
|
| 7175 |
-
|
| 7176 |
type: STS
|
| 7177 |
-
|
| 7178 |
-
config: default
|
| 7179 |
name: MTEB STS14 (default)
|
| 7180 |
-
revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
|
| 7181 |
-
split: test
|
| 7182 |
type: mteb/sts14-sts
|
|
|
|
|
|
|
|
|
|
| 7183 |
metrics:
|
| 7184 |
- type: cosine_pearson
|
| 7185 |
value: 82.234064045301
|
|
@@ -7199,14 +7201,14 @@ model-index:
|
|
| 7199 |
value: 82.234064045301
|
| 7200 |
- type: spearman
|
| 7201 |
value: 78.86920830792957
|
| 7202 |
-
|
| 7203 |
type: STS
|
| 7204 |
-
|
| 7205 |
-
config: default
|
| 7206 |
name: MTEB STS15 (default)
|
| 7207 |
-
revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
|
| 7208 |
-
split: test
|
| 7209 |
type: mteb/sts15-sts
|
|
|
|
|
|
|
|
|
|
| 7210 |
metrics:
|
| 7211 |
- type: cosine_pearson
|
| 7212 |
value: 86.23114543080261
|
|
@@ -7226,14 +7228,14 @@ model-index:
|
|
| 7226 |
value: 86.23114543080261
|
| 7227 |
- type: spearman
|
| 7228 |
value: 87.481042891123
|
| 7229 |
-
|
| 7230 |
type: STS
|
| 7231 |
-
|
| 7232 |
-
config: default
|
| 7233 |
name: MTEB STS16 (default)
|
| 7234 |
-
revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
|
| 7235 |
-
split: test
|
| 7236 |
type: mteb/sts16-sts
|
|
|
|
|
|
|
|
|
|
| 7237 |
metrics:
|
| 7238 |
- type: cosine_pearson
|
| 7239 |
value: 82.9156629047782
|
|
@@ -7253,14 +7255,14 @@ model-index:
|
|
| 7253 |
value: 82.9156629047782
|
| 7254 |
- type: spearman
|
| 7255 |
value: 84.28381329207937
|
| 7256 |
-
|
| 7257 |
type: STS
|
| 7258 |
-
|
| 7259 |
-
config: en-en
|
| 7260 |
name: MTEB STS17 (en-en)
|
| 7261 |
-
revision: faeb762787bd10488a50c8b5be4a3b82e411949c
|
| 7262 |
-
split: test
|
| 7263 |
type: mteb/sts17-crosslingual-sts
|
|
|
|
|
|
|
|
|
|
| 7264 |
metrics:
|
| 7265 |
- type: cosine_pearson
|
| 7266 |
value: 88.91985349746744
|
|
@@ -7280,14 +7282,14 @@ model-index:
|
|
| 7280 |
value: 88.91985349746744
|
| 7281 |
- type: spearman
|
| 7282 |
value: 89.69151633966257
|
| 7283 |
-
|
| 7284 |
type: STS
|
| 7285 |
-
|
| 7286 |
-
config: en
|
| 7287 |
name: MTEB STS22 (en)
|
| 7288 |
-
revision: de9d86b3b84231dc21f76c7b7af1f28e2f57f6e3
|
| 7289 |
-
split: test
|
| 7290 |
type: mteb/sts22-crosslingual-sts
|
|
|
|
|
|
|
|
|
|
| 7291 |
metrics:
|
| 7292 |
- type: cosine_pearson
|
| 7293 |
value: 65.0979772547511
|
|
@@ -7307,14 +7309,14 @@ model-index:
|
|
| 7307 |
value: 65.0979772547511
|
| 7308 |
- type: spearman
|
| 7309 |
value: 65.78126527764236
|
| 7310 |
-
|
| 7311 |
type: STS
|
| 7312 |
-
|
| 7313 |
-
config: default
|
| 7314 |
name: MTEB STSBenchmark (default)
|
| 7315 |
-
revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
|
| 7316 |
-
split: test
|
| 7317 |
type: mteb/stsbenchmark-sts
|
|
|
|
|
|
|
|
|
|
| 7318 |
metrics:
|
| 7319 |
- type: cosine_pearson
|
| 7320 |
value: 85.6426635049971
|
|
@@ -7334,14 +7336,14 @@ model-index:
|
|
| 7334 |
value: 85.6426635049971
|
| 7335 |
- type: spearman
|
| 7336 |
value: 85.609856578385
|
| 7337 |
-
|
| 7338 |
-
type:
|
| 7339 |
-
|
| 7340 |
-
config: default
|
| 7341 |
name: MTEB SciDocsRR (default)
|
| 7342 |
-
revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
|
| 7343 |
-
split: test
|
| 7344 |
type: mteb/scidocs-reranking
|
|
|
|
|
|
|
|
|
|
| 7345 |
metrics:
|
| 7346 |
- type: main_score
|
| 7347 |
value: 82.85163332499799
|
|
@@ -7361,14 +7363,14 @@ model-index:
|
|
| 7361 |
value: 89.47202967481866
|
| 7362 |
- type: nAUC_mrr_std
|
| 7363 |
value: 85.40446996933892
|
| 7364 |
-
|
| 7365 |
-
type:
|
| 7366 |
-
|
| 7367 |
-
config: default
|
| 7368 |
name: MTEB SciFact (default)
|
| 7369 |
-
revision: 0228b52cf27578f30900b9e5271d331663a030d7
|
| 7370 |
-
split: test
|
| 7371 |
type: mteb/scifact
|
|
|
|
|
|
|
|
|
|
| 7372 |
metrics:
|
| 7373 |
- type: main_score
|
| 7374 |
value: 71.655
|
|
@@ -7652,14 +7654,14 @@ model-index:
|
|
| 7652 |
value: 71.61699999999999
|
| 7653 |
- type: recall_at_5
|
| 7654 |
value: 78.361
|
| 7655 |
-
|
| 7656 |
-
type:
|
| 7657 |
-
|
| 7658 |
-
config: default
|
| 7659 |
name: MTEB SprintDuplicateQuestions (default)
|
| 7660 |
-
revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
|
| 7661 |
-
split: test
|
| 7662 |
type: mteb/sprintduplicatequestions-pairclassification
|
|
|
|
|
|
|
|
|
|
| 7663 |
metrics:
|
| 7664 |
- type: cosine_accuracy
|
| 7665 |
value: 99.8019801980198
|
|
@@ -7743,14 +7745,14 @@ model-index:
|
|
| 7743 |
value: 90.79754601226993
|
| 7744 |
- type: similarity_recall
|
| 7745 |
value: 88.8
|
| 7746 |
-
|
| 7747 |
-
type:
|
| 7748 |
-
|
| 7749 |
-
config: default
|
| 7750 |
name: MTEB StackExchangeClustering (default)
|
| 7751 |
-
revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
|
| 7752 |
-
split: test
|
| 7753 |
type: mteb/stackexchange-clustering
|
|
|
|
|
|
|
|
|
|
| 7754 |
metrics:
|
| 7755 |
- type: main_score
|
| 7756 |
value: 66.63931197758824
|
|
@@ -7758,14 +7760,14 @@ model-index:
|
|
| 7758 |
value: 66.63931197758824
|
| 7759 |
- type: v_measure_std
|
| 7760 |
value: 3.896206781511776
|
| 7761 |
-
|
| 7762 |
type: Clustering
|
| 7763 |
-
|
| 7764 |
-
config: default
|
| 7765 |
name: MTEB StackExchangeClusteringP2P (default)
|
| 7766 |
-
revision: 815ca46b2622cec33ccafc3735d572c266efdb44
|
| 7767 |
-
split: test
|
| 7768 |
type: mteb/stackexchange-clustering-p2p
|
|
|
|
|
|
|
|
|
|
| 7769 |
metrics:
|
| 7770 |
- type: main_score
|
| 7771 |
value: 38.984892653301884
|
|
@@ -7773,14 +7775,14 @@ model-index:
|
|
| 7773 |
value: 38.984892653301884
|
| 7774 |
- type: v_measure_std
|
| 7775 |
value: 1.3308552162270453
|
| 7776 |
-
|
| 7777 |
-
type:
|
| 7778 |
-
|
| 7779 |
-
config: default
|
| 7780 |
name: MTEB StackOverflowDupQuestions (default)
|
| 7781 |
-
revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
|
| 7782 |
-
split: test
|
| 7783 |
type: mteb/stackoverflowdupquestions-reranking
|
|
|
|
|
|
|
|
|
|
| 7784 |
metrics:
|
| 7785 |
- type: main_score
|
| 7786 |
value: 52.71499643455044
|
|
@@ -7800,14 +7802,14 @@ model-index:
|
|
| 7800 |
value: 13.931448578334379
|
| 7801 |
- type: nAUC_mrr_std
|
| 7802 |
value: 10.441860004959661
|
| 7803 |
-
|
| 7804 |
-
type:
|
| 7805 |
-
|
| 7806 |
-
config: default
|
| 7807 |
name: MTEB SummEval (default)
|
| 7808 |
-
revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
|
| 7809 |
-
split: test
|
| 7810 |
type: mteb/summeval
|
|
|
|
|
|
|
|
|
|
| 7811 |
metrics:
|
| 7812 |
- type: cosine_pearson
|
| 7813 |
value: 31.5167525286909
|
|
@@ -7823,14 +7825,14 @@ model-index:
|
|
| 7823 |
value: 31.5167525286909
|
| 7824 |
- type: spearman
|
| 7825 |
value: 31.218862970706496
|
| 7826 |
-
|
| 7827 |
-
type:
|
| 7828 |
-
|
| 7829 |
-
config: default
|
| 7830 |
name: MTEB TRECCOVID (default)
|
| 7831 |
-
revision: bb9466bac8153a0349341eb1b22e06409e78ef4e
|
| 7832 |
-
split: test
|
| 7833 |
type: mteb/trec-covid
|
|
|
|
|
|
|
|
|
|
| 7834 |
metrics:
|
| 7835 |
- type: main_score
|
| 7836 |
value: 78.996
|
|
@@ -8114,14 +8116,14 @@ model-index:
|
|
| 8114 |
value: 0.705
|
| 8115 |
- type: recall_at_5
|
| 8116 |
value: 1.162
|
| 8117 |
-
|
| 8118 |
type: Retrieval
|
| 8119 |
-
|
| 8120 |
-
config: default
|
| 8121 |
name: MTEB Touche2020 (default)
|
| 8122 |
-
revision: a34f9a33db75fa0cbb21bb5cfc3dae8dc8bec93f
|
| 8123 |
-
split: test
|
| 8124 |
type: mteb/touche2020
|
|
|
|
|
|
|
|
|
|
| 8125 |
metrics:
|
| 8126 |
- type: main_score
|
| 8127 |
value: 24.234
|
|
@@ -8405,14 +8407,14 @@ model-index:
|
|
| 8405 |
value: 6.625
|
| 8406 |
- type: recall_at_5
|
| 8407 |
value: 9.094
|
| 8408 |
-
|
| 8409 |
-
type:
|
| 8410 |
-
|
| 8411 |
-
config: default
|
| 8412 |
name: MTEB ToxicConversationsClassification (default)
|
| 8413 |
-
revision: edfaf9da55d3dd50d43143d90c1ac476895ae6de
|
| 8414 |
-
split: test
|
| 8415 |
type: mteb/toxic_conversations_50k
|
|
|
|
|
|
|
|
|
|
| 8416 |
metrics:
|
| 8417 |
- type: accuracy
|
| 8418 |
value: 72.822265625
|
|
@@ -8426,14 +8428,14 @@ model-index:
|
|
| 8426 |
value: 78.7454393727821
|
| 8427 |
- type: main_score
|
| 8428 |
value: 72.822265625
|
| 8429 |
-
|
| 8430 |
type: Classification
|
| 8431 |
-
|
| 8432 |
-
config: default
|
| 8433 |
name: MTEB TweetSentimentExtractionClassification (default)
|
| 8434 |
-
revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
|
| 8435 |
-
split: test
|
| 8436 |
type: mteb/tweet_sentiment_extraction
|
|
|
|
|
|
|
|
|
|
| 8437 |
metrics:
|
| 8438 |
- type: accuracy
|
| 8439 |
value: 72.54385964912281
|
|
@@ -8443,14 +8445,14 @@ model-index:
|
|
| 8443 |
value: 72.18022450339639
|
| 8444 |
- type: main_score
|
| 8445 |
value: 72.54385964912281
|
| 8446 |
-
|
| 8447 |
-
type:
|
| 8448 |
-
|
| 8449 |
-
config: default
|
| 8450 |
name: MTEB TwentyNewsgroupsClustering (default)
|
| 8451 |
-
revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
|
| 8452 |
-
split: test
|
| 8453 |
type: mteb/twentynewsgroups-clustering
|
|
|
|
|
|
|
|
|
|
| 8454 |
metrics:
|
| 8455 |
- type: main_score
|
| 8456 |
value: 57.41861450414374
|
|
@@ -8458,14 +8460,14 @@ model-index:
|
|
| 8458 |
value: 57.41861450414374
|
| 8459 |
- type: v_measure_std
|
| 8460 |
value: 1.1732394227153524
|
| 8461 |
-
|
| 8462 |
-
type:
|
| 8463 |
-
|
| 8464 |
-
config: default
|
| 8465 |
name: MTEB TwitterSemEval2015 (default)
|
| 8466 |
-
revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
|
| 8467 |
-
split: test
|
| 8468 |
type: mteb/twittersemeval2015-pairclassification
|
|
|
|
|
|
|
|
|
|
| 8469 |
metrics:
|
| 8470 |
- type: cosine_accuracy
|
| 8471 |
value: 85.65893783155511
|
|
@@ -8549,14 +8551,14 @@ model-index:
|
|
| 8549 |
value: 64.0855106888361
|
| 8550 |
- type: similarity_recall
|
| 8551 |
value: 71.18733509234828
|
| 8552 |
-
|
| 8553 |
type: PairClassification
|
| 8554 |
-
|
| 8555 |
-
config: default
|
| 8556 |
name: MTEB TwitterURLCorpus (default)
|
| 8557 |
-
revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
|
| 8558 |
-
split: test
|
| 8559 |
type: mteb/twitterurlcorpus-pairclassification
|
|
|
|
|
|
|
|
|
|
| 8560 |
metrics:
|
| 8561 |
- type: cosine_accuracy
|
| 8562 |
value: 88.86754375751931
|
|
@@ -8640,8 +8642,6 @@ model-index:
|
|
| 8640 |
value: 74.19310344827586
|
| 8641 |
- type: similarity_recall
|
| 8642 |
value: 82.83030489682784
|
| 8643 |
-
task:
|
| 8644 |
-
type: PairClassification
|
| 8645 |
---
|
| 8646 |
# Contextual Document Embeddings (CDE)
|
| 8647 |
|
|
|
|
| 4 |
model-index:
|
| 5 |
- name: cde-small-v1
|
| 6 |
results:
|
| 7 |
+
- task:
|
| 8 |
+
type: Classification
|
| 9 |
+
dataset:
|
| 10 |
name: MTEB AmazonCounterfactualClassification (en)
|
|
|
|
|
|
|
| 11 |
type: mteb/amazon_counterfactual
|
| 12 |
+
config: en
|
| 13 |
+
split: test
|
| 14 |
+
revision: e8379541af4e31359cca9fbcf4b00f2671dba205
|
| 15 |
metrics:
|
| 16 |
- type: accuracy
|
| 17 |
value: 87.01492537313433
|
|
|
|
| 25 |
value: 87.74802754480477
|
| 26 |
- type: main_score
|
| 27 |
value: 87.01492537313433
|
| 28 |
+
- task:
|
| 29 |
type: Classification
|
| 30 |
+
dataset:
|
|
|
|
| 31 |
name: MTEB AmazonPolarityClassification (default)
|
|
|
|
|
|
|
| 32 |
type: mteb/amazon_polarity
|
| 33 |
+
config: default
|
| 34 |
+
split: test
|
| 35 |
+
revision: e2d317d38cd51312af73b3d32a06d1a08b442046
|
| 36 |
metrics:
|
| 37 |
- type: accuracy
|
| 38 |
value: 94.652275
|
|
|
|
| 46 |
value: 94.64655930708355
|
| 47 |
- type: main_score
|
| 48 |
value: 94.652275
|
| 49 |
+
- task:
|
| 50 |
type: Classification
|
| 51 |
+
dataset:
|
|
|
|
| 52 |
name: MTEB AmazonReviewsClassification (en)
|
|
|
|
|
|
|
| 53 |
type: mteb/amazon_reviews_multi
|
| 54 |
+
config: en
|
| 55 |
+
split: test
|
| 56 |
+
revision: 1399c76144fd37290681b995c656ef9b2e06e26d
|
| 57 |
metrics:
|
| 58 |
- type: accuracy
|
| 59 |
value: 55.75599999999999
|
|
|
|
| 63 |
value: 55.07058630829347
|
| 64 |
- type: main_score
|
| 65 |
value: 55.75599999999999
|
| 66 |
+
- task:
|
| 67 |
+
type: Retrieval
|
| 68 |
+
dataset:
|
|
|
|
| 69 |
name: MTEB ArguAna (default)
|
|
|
|
|
|
|
| 70 |
type: mteb/arguana
|
| 71 |
+
config: default
|
| 72 |
+
split: test
|
| 73 |
+
revision: c22ab2a51041ffd869aaddef7af8d8215647e41a
|
| 74 |
metrics:
|
| 75 |
- type: main_score
|
| 76 |
value: 69.959
|
|
|
|
| 354 |
value: 74.182
|
| 355 |
- type: recall_at_5
|
| 356 |
value: 84.495
|
| 357 |
+
- task:
|
| 358 |
+
type: Clustering
|
| 359 |
+
dataset:
|
|
|
|
| 360 |
name: MTEB ArxivClusteringP2P (default)
|
|
|
|
|
|
|
| 361 |
type: mteb/arxiv-clustering-p2p
|
| 362 |
+
config: default
|
| 363 |
+
split: test
|
| 364 |
+
revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
|
| 365 |
metrics:
|
| 366 |
- type: main_score
|
| 367 |
value: 48.54672141116669
|
|
|
|
| 369 |
value: 48.54672141116669
|
| 370 |
- type: v_measure_std
|
| 371 |
value: 14.037498386768362
|
| 372 |
+
- task:
|
| 373 |
type: Clustering
|
| 374 |
+
dataset:
|
|
|
|
| 375 |
name: MTEB ArxivClusteringS2S (default)
|
|
|
|
|
|
|
| 376 |
type: mteb/arxiv-clustering-s2s
|
| 377 |
+
config: default
|
| 378 |
+
split: test
|
| 379 |
+
revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
|
| 380 |
metrics:
|
| 381 |
- type: main_score
|
| 382 |
value: 40.5914039166466
|
|
|
|
| 384 |
value: 40.5914039166466
|
| 385 |
- type: v_measure_std
|
| 386 |
value: 14.385069818910331
|
| 387 |
+
- task:
|
| 388 |
+
type: Reranking
|
| 389 |
+
dataset:
|
|
|
|
| 390 |
name: MTEB AskUbuntuDupQuestions (default)
|
|
|
|
|
|
|
| 391 |
type: mteb/askubuntudupquestions-reranking
|
| 392 |
+
config: default
|
| 393 |
+
split: test
|
| 394 |
+
revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
|
| 395 |
metrics:
|
| 396 |
- type: main_score
|
| 397 |
value: 61.13621260261507
|
|
|
|
| 411 |
value: 31.484257486448364
|
| 412 |
- type: nAUC_mrr_std
|
| 413 |
value: 21.252659250011632
|
| 414 |
+
- task:
|
| 415 |
+
type: STS
|
| 416 |
+
dataset:
|
|
|
|
| 417 |
name: MTEB BIOSSES (default)
|
|
|
|
|
|
|
| 418 |
type: mteb/biosses-sts
|
| 419 |
+
config: default
|
| 420 |
+
split: test
|
| 421 |
+
revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
|
| 422 |
metrics:
|
| 423 |
- type: cosine_pearson
|
| 424 |
value: 89.07028016646942
|
|
|
|
| 438 |
value: 89.07028016646942
|
| 439 |
- type: spearman
|
| 440 |
value: 86.69595132967805
|
| 441 |
+
- task:
|
| 442 |
+
type: Classification
|
| 443 |
+
dataset:
|
|
|
|
| 444 |
name: MTEB Banking77Classification (default)
|
|
|
|
|
|
|
| 445 |
type: mteb/banking77
|
| 446 |
+
config: default
|
| 447 |
+
split: test
|
| 448 |
+
revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
|
| 449 |
metrics:
|
| 450 |
- type: accuracy
|
| 451 |
value: 88.6038961038961
|
|
|
|
| 455 |
value: 88.56824205739822
|
| 456 |
- type: main_score
|
| 457 |
value: 88.6038961038961
|
| 458 |
+
- task:
|
| 459 |
+
type: Clustering
|
| 460 |
+
dataset:
|
|
|
|
| 461 |
name: MTEB BiorxivClusteringP2P (default)
|
|
|
|
|
|
|
| 462 |
type: mteb/biorxiv-clustering-p2p
|
| 463 |
+
config: default
|
| 464 |
+
split: test
|
| 465 |
+
revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
|
| 466 |
metrics:
|
| 467 |
- type: main_score
|
| 468 |
value: 44.77800814327256
|
|
|
|
| 470 |
value: 44.77800814327256
|
| 471 |
- type: v_measure_std
|
| 472 |
value: 0.6462535527471919
|
| 473 |
+
- task:
|
| 474 |
type: Clustering
|
| 475 |
+
dataset:
|
|
|
|
| 476 |
name: MTEB BiorxivClusteringS2S (default)
|
|
|
|
|
|
|
| 477 |
type: mteb/biorxiv-clustering-s2s
|
| 478 |
+
config: default
|
| 479 |
+
split: test
|
| 480 |
+
revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
|
| 481 |
metrics:
|
| 482 |
- type: main_score
|
| 483 |
value: 38.16110272459102
|
|
|
|
| 485 |
value: 38.16110272459102
|
| 486 |
- type: v_measure_std
|
| 487 |
value: 0.7456916212435019
|
| 488 |
+
- task:
|
| 489 |
+
type: Retrieval
|
| 490 |
+
dataset:
|
|
|
|
| 491 |
name: MTEB CQADupstackAndroidRetrieval (default)
|
|
|
|
|
|
|
| 492 |
type: mteb/cqadupstack-android
|
| 493 |
+
config: default
|
| 494 |
+
split: test
|
| 495 |
+
revision: f46a197baaae43b4f621051089b82a364682dfeb
|
| 496 |
metrics:
|
| 497 |
- type: main_score
|
| 498 |
value: 49.376
|
|
|
|
| 776 |
value: 47.591
|
| 777 |
- type: recall_at_5
|
| 778 |
value: 54.245
|
| 779 |
+
- task:
|
| 780 |
type: Retrieval
|
| 781 |
+
dataset:
|
|
|
|
| 782 |
name: MTEB CQADupstackEnglishRetrieval (default)
|
|
|
|
|
|
|
| 783 |
type: mteb/cqadupstack-english
|
| 784 |
+
config: default
|
| 785 |
+
split: test
|
| 786 |
+
revision: ad9991cb51e31e31e430383c75ffb2885547b5f0
|
| 787 |
metrics:
|
| 788 |
- type: main_score
|
| 789 |
value: 44.727
|
|
|
|
| 1067 |
value: 42.085
|
| 1068 |
- type: recall_at_5
|
| 1069 |
value: 47.5
|
| 1070 |
+
- task:
|
| 1071 |
type: Retrieval
|
| 1072 |
+
dataset:
|
|
|
|
| 1073 |
name: MTEB CQADupstackGamingRetrieval (default)
|
|
|
|
|
|
|
| 1074 |
type: mteb/cqadupstack-gaming
|
| 1075 |
+
config: default
|
| 1076 |
+
split: test
|
| 1077 |
+
revision: 4885aa143210c98657558c04aaf3dc47cfb54340
|
| 1078 |
metrics:
|
| 1079 |
- type: main_score
|
| 1080 |
value: 59.001999999999995
|
|
|
|
| 1358 |
value: 57.916000000000004
|
| 1359 |
- type: recall_at_5
|
| 1360 |
value: 65.44
|
| 1361 |
+
- task:
|
| 1362 |
type: Retrieval
|
| 1363 |
+
dataset:
|
|
|
|
| 1364 |
name: MTEB CQADupstackGisRetrieval (default)
|
|
|
|
|
|
|
| 1365 |
type: mteb/cqadupstack-gis
|
| 1366 |
+
config: default
|
| 1367 |
+
split: test
|
| 1368 |
+
revision: 5003b3064772da1887988e05400cf3806fe491f2
|
| 1369 |
metrics:
|
| 1370 |
- type: main_score
|
| 1371 |
value: 37.501
|
|
|
|
| 1649 |
value: 37.218
|
| 1650 |
- type: recall_at_5
|
| 1651 |
value: 42.559000000000005
|
| 1652 |
+
- task:
|
| 1653 |
type: Retrieval
|
| 1654 |
+
dataset:
|
|
|
|
| 1655 |
name: MTEB CQADupstackMathematicaRetrieval (default)
|
|
|
|
|
|
|
| 1656 |
type: mteb/cqadupstack-mathematica
|
| 1657 |
+
config: default
|
| 1658 |
+
split: test
|
| 1659 |
+
revision: 90fceea13679c63fe563ded68f3b6f06e50061de
|
| 1660 |
metrics:
|
| 1661 |
- type: main_score
|
| 1662 |
value: 27.653
|
|
|
|
| 1940 |
value: 25.469
|
| 1941 |
- type: recall_at_5
|
| 1942 |
value: 31.316
|
| 1943 |
+
- task:
|
| 1944 |
type: Retrieval
|
| 1945 |
+
dataset:
|
|
|
|
| 1946 |
name: MTEB CQADupstackPhysicsRetrieval (default)
|
|
|
|
|
|
|
| 1947 |
type: mteb/cqadupstack-physics
|
| 1948 |
+
config: default
|
| 1949 |
+
split: test
|
| 1950 |
+
revision: 79531abbd1fb92d06c6d6315a0cbbbf5bb247ea4
|
| 1951 |
metrics:
|
| 1952 |
- type: main_score
|
| 1953 |
value: 45.314
|
|
|
|
| 2231 |
value: 43.679
|
| 2232 |
- type: recall_at_5
|
| 2233 |
value: 49.735
|
| 2234 |
+
- task:
|
| 2235 |
type: Retrieval
|
| 2236 |
+
dataset:
|
|
|
|
| 2237 |
name: MTEB CQADupstackProgrammersRetrieval (default)
|
|
|
|
|
|
|
| 2238 |
type: mteb/cqadupstack-programmers
|
| 2239 |
+
config: default
|
| 2240 |
+
split: test
|
| 2241 |
+
revision: 6184bc1440d2dbc7612be22b50686b8826d22b32
|
| 2242 |
metrics:
|
| 2243 |
- type: main_score
|
| 2244 |
value: 41.972
|
|
|
|
| 2522 |
value: 39.363
|
| 2523 |
- type: recall_at_5
|
| 2524 |
value: 44.665
|
| 2525 |
+
- task:
|
| 2526 |
type: Retrieval
|
| 2527 |
+
dataset:
|
|
|
|
| 2528 |
name: MTEB CQADupstackRetrieval (default)
|
|
|
|
|
|
|
| 2529 |
type: CQADupstackRetrieval_is_a_combined_dataset
|
| 2530 |
+
config: default
|
| 2531 |
+
split: test
|
| 2532 |
+
revision: CQADupstackRetrieval_is_a_combined_dataset
|
| 2533 |
metrics:
|
| 2534 |
- type: main_score
|
| 2535 |
value: 39.823499999999996
|
| 2536 |
- type: ndcg_at_10
|
| 2537 |
value: 39.823499999999996
|
| 2538 |
+
- task:
|
| 2539 |
type: Retrieval
|
| 2540 |
+
dataset:
|
|
|
|
| 2541 |
name: MTEB CQADupstackStatsRetrieval (default)
|
|
|
|
|
|
|
| 2542 |
type: mteb/cqadupstack-stats
|
| 2543 |
+
config: default
|
| 2544 |
+
split: test
|
| 2545 |
+
revision: 65ac3a16b8e91f9cee4c9828cc7c335575432a2a
|
| 2546 |
metrics:
|
| 2547 |
- type: main_score
|
| 2548 |
value: 34.943000000000005
|
|
|
|
| 2826 |
value: 33.427
|
| 2827 |
- type: recall_at_5
|
| 2828 |
value: 37.643
|
| 2829 |
+
- task:
|
| 2830 |
type: Retrieval
|
| 2831 |
+
dataset:
|
|
|
|
| 2832 |
name: MTEB CQADupstackTexRetrieval (default)
|
|
|
|
|
|
|
| 2833 |
type: mteb/cqadupstack-tex
|
| 2834 |
+
config: default
|
| 2835 |
+
split: test
|
| 2836 |
+
revision: 46989137a86843e03a6195de44b09deda022eec7
|
| 2837 |
metrics:
|
| 2838 |
- type: main_score
|
| 2839 |
value: 27.271
|
|
|
|
| 3117 |
value: 25.592
|
| 3118 |
- type: recall_at_5
|
| 3119 |
value: 30.279
|
| 3120 |
+
- task:
|
| 3121 |
type: Retrieval
|
| 3122 |
+
dataset:
|
|
|
|
| 3123 |
name: MTEB CQADupstackUnixRetrieval (default)
|
|
|
|
|
|
|
| 3124 |
type: mteb/cqadupstack-unix
|
| 3125 |
+
config: default
|
| 3126 |
+
split: test
|
| 3127 |
+
revision: 6c6430d3a6d36f8d2a829195bc5dc94d7e063e53
|
| 3128 |
metrics:
|
| 3129 |
- type: main_score
|
| 3130 |
value: 38.237
|
|
|
|
| 3408 |
value: 36.275
|
| 3409 |
- type: recall_at_5
|
| 3410 |
value: 42.199
|
| 3411 |
+
- task:
|
| 3412 |
type: Retrieval
|
| 3413 |
+
dataset:
|
|
|
|
| 3414 |
name: MTEB CQADupstackWebmastersRetrieval (default)
|
|
|
|
|
|
|
| 3415 |
type: mteb/cqadupstack-webmasters
|
| 3416 |
+
config: default
|
| 3417 |
+
split: test
|
| 3418 |
+
revision: 160c094312a0e1facb97e55eeddb698c0abe3571
|
| 3419 |
metrics:
|
| 3420 |
- type: main_score
|
| 3421 |
value: 38.702
|
|
|
|
| 3699 |
value: 37.634
|
| 3700 |
- type: recall_at_5
|
| 3701 |
value: 42.021
|
| 3702 |
+
- task:
|
| 3703 |
type: Retrieval
|
| 3704 |
+
dataset:
|
|
|
|
| 3705 |
name: MTEB CQADupstackWordpressRetrieval (default)
|
|
|
|
|
|
|
| 3706 |
type: mteb/cqadupstack-wordpress
|
| 3707 |
+
config: default
|
| 3708 |
+
split: test
|
| 3709 |
+
revision: 4ffe81d471b1924886b33c7567bfb200e9eec5c4
|
| 3710 |
metrics:
|
| 3711 |
- type: main_score
|
| 3712 |
value: 33.184000000000005
|
|
|
|
| 3990 |
value: 32.683
|
| 3991 |
- type: recall_at_5
|
| 3992 |
value: 36.756
|
| 3993 |
+
- task:
|
| 3994 |
type: Retrieval
|
| 3995 |
+
dataset:
|
|
|
|
| 3996 |
name: MTEB ClimateFEVER (default)
|
|
|
|
|
|
|
| 3997 |
type: mteb/climate-fever
|
| 3998 |
+
config: default
|
| 3999 |
+
split: test
|
| 4000 |
+
revision: 47f2ac6acb640fc46020b02a5b59fdda04d39380
|
| 4001 |
metrics:
|
| 4002 |
- type: main_score
|
| 4003 |
value: 25.068
|
|
|
|
| 4281 |
value: 18.312
|
| 4282 |
- type: recall_at_5
|
| 4283 |
value: 22.776
|
| 4284 |
+
- task:
|
| 4285 |
type: Retrieval
|
| 4286 |
+
dataset:
|
|
|
|
| 4287 |
name: MTEB DBPedia (default)
|
|
|
|
|
|
|
| 4288 |
type: mteb/dbpedia
|
| 4289 |
+
config: default
|
| 4290 |
+
split: test
|
| 4291 |
+
revision: c0f706b76e590d620bd6618b3ca8efdd34e2d659
|
| 4292 |
metrics:
|
| 4293 |
- type: main_score
|
| 4294 |
value: 40.128
|
|
|
|
| 4572 |
value: 14.562
|
| 4573 |
- type: recall_at_5
|
| 4574 |
value: 18.779
|
| 4575 |
+
- task:
|
| 4576 |
+
type: Classification
|
| 4577 |
+
dataset:
|
|
|
|
| 4578 |
name: MTEB EmotionClassification (default)
|
|
|
|
|
|
|
| 4579 |
type: mteb/emotion
|
| 4580 |
+
config: default
|
| 4581 |
+
split: test
|
| 4582 |
+
revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
|
| 4583 |
metrics:
|
| 4584 |
- type: accuracy
|
| 4585 |
value: 74.86
|
|
|
|
| 4589 |
value: 75.96499621761998
|
| 4590 |
- type: main_score
|
| 4591 |
value: 74.86
|
| 4592 |
+
- task:
|
| 4593 |
+
type: Retrieval
|
| 4594 |
+
dataset:
|
|
|
|
| 4595 |
name: MTEB FEVER (default)
|
|
|
|
|
|
|
| 4596 |
type: mteb/fever
|
| 4597 |
+
config: default
|
| 4598 |
+
split: test
|
| 4599 |
+
revision: bea83ef9e8fb933d90a2f1d5515737465d613e12
|
| 4600 |
metrics:
|
| 4601 |
- type: main_score
|
| 4602 |
value: 86.029
|
|
|
|
| 4880 |
value: 88.382
|
| 4881 |
- type: recall_at_5
|
| 4882 |
value: 90.908
|
| 4883 |
+
- task:
|
| 4884 |
type: Retrieval
|
| 4885 |
+
dataset:
|
|
|
|
| 4886 |
name: MTEB FiQA2018 (default)
|
|
|
|
|
|
|
| 4887 |
type: mteb/fiqa
|
| 4888 |
+
config: default
|
| 4889 |
+
split: test
|
| 4890 |
+
revision: 27a168819829fe9bcd655c2df245fb19452e8e06
|
| 4891 |
metrics:
|
| 4892 |
- type: main_score
|
| 4893 |
value: 45.238
|
|
|
|
| 5171 |
value: 37.656
|
| 5172 |
- type: recall_at_5
|
| 5173 |
value: 44.766
|
| 5174 |
+
- task:
|
| 5175 |
type: Retrieval
|
| 5176 |
+
dataset:
|
|
|
|
| 5177 |
name: MTEB HotpotQA (default)
|
|
|
|
|
|
|
| 5178 |
type: mteb/hotpotqa
|
| 5179 |
+
config: default
|
| 5180 |
+
split: test
|
| 5181 |
+
revision: ab518f4d6fcca38d87c25209f94beba119d02014
|
| 5182 |
metrics:
|
| 5183 |
- type: main_score
|
| 5184 |
value: 66.672
|
|
|
|
| 5462 |
value: 57.522
|
| 5463 |
- type: recall_at_5
|
| 5464 |
value: 62.134
|
| 5465 |
+
- task:
|
| 5466 |
+
type: Classification
|
| 5467 |
+
dataset:
|
|
|
|
| 5468 |
name: MTEB ImdbClassification (default)
|
|
|
|
|
|
|
| 5469 |
type: mteb/imdb
|
| 5470 |
+
config: default
|
| 5471 |
+
split: test
|
| 5472 |
+
revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
|
| 5473 |
metrics:
|
| 5474 |
- type: accuracy
|
| 5475 |
value: 93.5944
|
|
|
|
| 5483 |
value: 93.58945949328377
|
| 5484 |
- type: main_score
|
| 5485 |
value: 93.5944
|
| 5486 |
+
- task:
|
| 5487 |
+
type: Retrieval
|
| 5488 |
+
dataset:
|
|
|
|
| 5489 |
name: MTEB MSMARCO (default)
|
|
|
|
|
|
|
| 5490 |
type: mteb/msmarco
|
| 5491 |
+
config: default
|
| 5492 |
+
split: dev
|
| 5493 |
+
revision: c5a29a104738b98a9e76336939199e264163d4a0
|
| 5494 |
metrics:
|
| 5495 |
- type: main_score
|
| 5496 |
value: 41.448
|
|
|
|
| 5774 |
value: 41.304
|
| 5775 |
- type: recall_at_5
|
| 5776 |
value: 51.076
|
| 5777 |
+
- task:
|
| 5778 |
+
type: Classification
|
| 5779 |
+
dataset:
|
|
|
|
| 5780 |
name: MTEB MTOPDomainClassification (en)
|
|
|
|
|
|
|
| 5781 |
type: mteb/mtop_domain
|
| 5782 |
+
config: en
|
| 5783 |
+
split: test
|
| 5784 |
+
revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
|
| 5785 |
metrics:
|
| 5786 |
- type: accuracy
|
| 5787 |
value: 96.03967168262655
|
|
|
|
| 5791 |
value: 96.06623245823347
|
| 5792 |
- type: main_score
|
| 5793 |
value: 96.03967168262655
|
| 5794 |
+
- task:
|
| 5795 |
type: Classification
|
| 5796 |
+
dataset:
|
|
|
|
| 5797 |
name: MTEB MTOPIntentClassification (en)
|
|
|
|
|
|
|
| 5798 |
type: mteb/mtop_intent
|
| 5799 |
+
config: en
|
| 5800 |
+
split: test
|
| 5801 |
+
revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
|
| 5802 |
metrics:
|
| 5803 |
- type: accuracy
|
| 5804 |
value: 89.12904696762428
|
|
|
|
| 5808 |
value: 90.41290566743324
|
| 5809 |
- type: main_score
|
| 5810 |
value: 89.12904696762428
|
| 5811 |
+
- task:
|
| 5812 |
type: Classification
|
| 5813 |
+
dataset:
|
|
|
|
| 5814 |
name: MTEB MassiveIntentClassification (en)
|
|
|
|
|
|
|
| 5815 |
type: mteb/amazon_massive_intent
|
| 5816 |
+
config: en
|
| 5817 |
+
split: test
|
| 5818 |
+
revision: 4672e20407010da34463acc759c162ca9734bca6
|
| 5819 |
metrics:
|
| 5820 |
- type: accuracy
|
| 5821 |
value: 76.49630127774041
|
|
|
|
| 5825 |
value: 76.42436195016484
|
| 5826 |
- type: main_score
|
| 5827 |
value: 76.49630127774041
|
| 5828 |
+
- task:
|
| 5829 |
type: Classification
|
| 5830 |
+
dataset:
|
|
|
|
| 5831 |
name: MTEB MassiveScenarioClassification (en)
|
|
|
|
|
|
|
| 5832 |
type: mteb/amazon_massive_scenario
|
| 5833 |
+
config: en
|
| 5834 |
+
split: test
|
| 5835 |
+
revision: fad2c6e8459f9e1c45d9315f4953d921437d70f8
|
| 5836 |
metrics:
|
| 5837 |
- type: accuracy
|
| 5838 |
value: 78.9340954942838
|
|
|
|
| 5842 |
value: 78.87787647838971
|
| 5843 |
- type: main_score
|
| 5844 |
value: 78.9340954942838
|
| 5845 |
+
- task:
|
| 5846 |
+
type: Clustering
|
| 5847 |
+
dataset:
|
|
|
|
| 5848 |
name: MTEB MedrxivClusteringP2P (default)
|
|
|
|
|
|
|
| 5849 |
type: mteb/medrxiv-clustering-p2p
|
| 5850 |
+
config: default
|
| 5851 |
+
split: test
|
| 5852 |
+
revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
|
| 5853 |
metrics:
|
| 5854 |
- type: main_score
|
| 5855 |
value: 37.50182848656019
|
|
|
|
| 5857 |
value: 37.50182848656019
|
| 5858 |
- type: v_measure_std
|
| 5859 |
value: 1.1708518023877268
|
| 5860 |
+
- task:
|
| 5861 |
type: Clustering
|
| 5862 |
+
dataset:
|
|
|
|
| 5863 |
name: MTEB MedrxivClusteringS2S (default)
|
|
|
|
|
|
|
| 5864 |
type: mteb/medrxiv-clustering-s2s
|
| 5865 |
+
config: default
|
| 5866 |
+
split: test
|
| 5867 |
+
revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
|
| 5868 |
metrics:
|
| 5869 |
- type: main_score
|
| 5870 |
value: 35.72762609825363
|
|
|
|
| 5872 |
value: 35.72762609825363
|
| 5873 |
- type: v_measure_std
|
| 5874 |
value: 1.4555014772914985
|
| 5875 |
+
- task:
|
| 5876 |
+
type: Reranking
|
| 5877 |
+
dataset:
|
|
|
|
| 5878 |
name: MTEB MindSmallReranking (default)
|
|
|
|
|
|
|
| 5879 |
type: mteb/mind_small
|
| 5880 |
+
config: default
|
| 5881 |
+
split: test
|
| 5882 |
+
revision: 59042f120c80e8afa9cdbb224f67076cec0fc9a7
|
| 5883 |
metrics:
|
| 5884 |
- type: main_score
|
| 5885 |
value: 30.47716416454022
|
|
|
|
| 5899 |
value: -15.78941850629242
|
| 5900 |
- type: nAUC_mrr_std
|
| 5901 |
value: -1.1330442292510805
|
| 5902 |
+
- task:
|
| 5903 |
+
type: Retrieval
|
| 5904 |
+
dataset:
|
|
|
|
| 5905 |
name: MTEB NFCorpus (default)
|
|
|
|
|
|
|
| 5906 |
type: mteb/nfcorpus
|
| 5907 |
+
config: default
|
| 5908 |
+
split: test
|
| 5909 |
+
revision: ec0fa4fe99da2ff19ca1214b7966684033a58814
|
| 5910 |
metrics:
|
| 5911 |
- type: main_score
|
| 5912 |
value: 34.648
|
|
|
|
| 6190 |
value: 10.037
|
| 6191 |
- type: recall_at_5
|
| 6192 |
value: 12.717999999999998
|
| 6193 |
+
- task:
|
| 6194 |
type: Retrieval
|
| 6195 |
+
dataset:
|
|
|
|
| 6196 |
name: MTEB NQ (default)
|
|
|
|
|
|
|
| 6197 |
type: mteb/nq
|
| 6198 |
+
config: default
|
| 6199 |
+
split: test
|
| 6200 |
+
revision: b774495ed302d8c44a3a7ea25c90dbce03968f31
|
| 6201 |
metrics:
|
| 6202 |
- type: main_score
|
| 6203 |
value: 60.06
|
|
|
|
| 6481 |
value: 61.114000000000004
|
| 6482 |
- type: recall_at_5
|
| 6483 |
value: 69.812
|
| 6484 |
+
- task:
|
| 6485 |
type: Retrieval
|
| 6486 |
+
dataset:
|
|
|
|
| 6487 |
name: MTEB QuoraRetrieval (default)
|
|
|
|
|
|
|
| 6488 |
type: mteb/quora
|
| 6489 |
+
config: default
|
| 6490 |
+
split: test
|
| 6491 |
+
revision: e4e08e0b7dbe3c8700f0daef558ff32256715259
|
| 6492 |
metrics:
|
| 6493 |
- type: main_score
|
| 6494 |
value: 89.821
|
|
|
|
| 6772 |
value: 88.714
|
| 6773 |
- type: recall_at_5
|
| 6774 |
value: 92.96799999999999
|
| 6775 |
+
- task:
|
| 6776 |
+
type: Clustering
|
| 6777 |
+
dataset:
|
|
|
|
| 6778 |
name: MTEB RedditClustering (default)
|
|
|
|
|
|
|
| 6779 |
type: mteb/reddit-clustering
|
| 6780 |
+
config: default
|
| 6781 |
+
split: test
|
| 6782 |
+
revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
|
| 6783 |
metrics:
|
| 6784 |
- type: main_score
|
| 6785 |
value: 59.36038828851887
|
|
|
|
| 6787 |
value: 59.36038828851887
|
| 6788 |
- type: v_measure_std
|
| 6789 |
value: 4.1958765965154425
|
| 6790 |
+
- task:
|
| 6791 |
type: Clustering
|
| 6792 |
+
dataset:
|
|
|
|
| 6793 |
name: MTEB RedditClusteringP2P (default)
|
|
|
|
|
|
|
| 6794 |
type: mteb/reddit-clustering-p2p
|
| 6795 |
+
config: default
|
| 6796 |
+
split: test
|
| 6797 |
+
revision: 385e3cb46b4cfa89021f56c4380204149d0efe33
|
| 6798 |
metrics:
|
| 6799 |
- type: main_score
|
| 6800 |
value: 64.67522832408089
|
|
|
|
| 6802 |
value: 64.67522832408089
|
| 6803 |
- type: v_measure_std
|
| 6804 |
value: 12.473765016158698
|
| 6805 |
+
- task:
|
| 6806 |
+
type: Retrieval
|
| 6807 |
+
dataset:
|
|
|
|
| 6808 |
name: MTEB SCIDOCS (default)
|
|
|
|
|
|
|
| 6809 |
type: mteb/scidocs
|
| 6810 |
+
config: default
|
| 6811 |
+
split: test
|
| 6812 |
+
revision: f8c2fcf00f625baaa80f62ec5bd9e1fff3b8ae88
|
| 6813 |
metrics:
|
| 6814 |
- type: main_score
|
| 6815 |
value: 21.751
|
|
|
|
| 7093 |
value: 11.648
|
| 7094 |
- type: recall_at_5
|
| 7095 |
value: 15.883
|
| 7096 |
+
- task:
|
| 7097 |
+
type: STS
|
| 7098 |
+
dataset:
|
|
|
|
| 7099 |
name: MTEB SICK-R (default)
|
|
|
|
|
|
|
| 7100 |
type: mteb/sickr-sts
|
| 7101 |
+
config: default
|
| 7102 |
+
split: test
|
| 7103 |
+
revision: 20a6d6f312dd54037fe07a32d58e5e168867909d
|
| 7104 |
metrics:
|
| 7105 |
- type: cosine_pearson
|
| 7106 |
value: 84.0161170579997
|
|
|
|
| 7120 |
value: 84.0161170579997
|
| 7121 |
- type: spearman
|
| 7122 |
value: 77.52025923874551
|
| 7123 |
+
- task:
|
| 7124 |
type: STS
|
| 7125 |
+
dataset:
|
|
|
|
| 7126 |
name: MTEB STS12 (default)
|
|
|
|
|
|
|
| 7127 |
type: mteb/sts12-sts
|
| 7128 |
+
config: default
|
| 7129 |
+
split: test
|
| 7130 |
+
revision: a0d554a64d88156834ff5ae9920b964011b16384
|
| 7131 |
metrics:
|
| 7132 |
- type: cosine_pearson
|
| 7133 |
value: 81.32328780209225
|
|
|
|
| 7147 |
value: 81.32328780209225
|
| 7148 |
- type: spearman
|
| 7149 |
value: 74.17570679745272
|
| 7150 |
+
- task:
|
| 7151 |
type: STS
|
| 7152 |
+
dataset:
|
|
|
|
| 7153 |
name: MTEB STS13 (default)
|
|
|
|
|
|
|
| 7154 |
type: mteb/sts13-sts
|
| 7155 |
+
config: default
|
| 7156 |
+
split: test
|
| 7157 |
+
revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
|
| 7158 |
metrics:
|
| 7159 |
- type: cosine_pearson
|
| 7160 |
value: 85.53224141249392
|
|
|
|
| 7174 |
value: 85.53224141249392
|
| 7175 |
- type: spearman
|
| 7176 |
value: 86.16981525069227
|
| 7177 |
+
- task:
|
| 7178 |
type: STS
|
| 7179 |
+
dataset:
|
|
|
|
| 7180 |
name: MTEB STS14 (default)
|
|
|
|
|
|
|
| 7181 |
type: mteb/sts14-sts
|
| 7182 |
+
config: default
|
| 7183 |
+
split: test
|
| 7184 |
+
revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
|
| 7185 |
metrics:
|
| 7186 |
- type: cosine_pearson
|
| 7187 |
value: 82.234064045301
|
|
|
|
| 7201 |
value: 82.234064045301
|
| 7202 |
- type: spearman
|
| 7203 |
value: 78.86920830792957
|
| 7204 |
+
- task:
|
| 7205 |
type: STS
|
| 7206 |
+
dataset:
|
|
|
|
| 7207 |
name: MTEB STS15 (default)
|
|
|
|
|
|
|
| 7208 |
type: mteb/sts15-sts
|
| 7209 |
+
config: default
|
| 7210 |
+
split: test
|
| 7211 |
+
revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
|
| 7212 |
metrics:
|
| 7213 |
- type: cosine_pearson
|
| 7214 |
value: 86.23114543080261
|
|
|
|
| 7228 |
value: 86.23114543080261
|
| 7229 |
- type: spearman
|
| 7230 |
value: 87.481042891123
|
| 7231 |
+
- task:
|
| 7232 |
type: STS
|
| 7233 |
+
dataset:
|
|
|
|
| 7234 |
name: MTEB STS16 (default)
|
|
|
|
|
|
|
| 7235 |
type: mteb/sts16-sts
|
| 7236 |
+
config: default
|
| 7237 |
+
split: test
|
| 7238 |
+
revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
|
| 7239 |
metrics:
|
| 7240 |
- type: cosine_pearson
|
| 7241 |
value: 82.9156629047782
|
|
|
|
| 7255 |
value: 82.9156629047782
|
| 7256 |
- type: spearman
|
| 7257 |
value: 84.28381329207937
|
| 7258 |
+
- task:
|
| 7259 |
type: STS
|
| 7260 |
+
dataset:
|
|
|
|
| 7261 |
name: MTEB STS17 (en-en)
|
|
|
|
|
|
|
| 7262 |
type: mteb/sts17-crosslingual-sts
|
| 7263 |
+
config: en-en
|
| 7264 |
+
split: test
|
| 7265 |
+
revision: faeb762787bd10488a50c8b5be4a3b82e411949c
|
| 7266 |
metrics:
|
| 7267 |
- type: cosine_pearson
|
| 7268 |
value: 88.91985349746744
|
|
|
|
| 7282 |
value: 88.91985349746744
|
| 7283 |
- type: spearman
|
| 7284 |
value: 89.69151633966257
|
| 7285 |
+
- task:
|
| 7286 |
type: STS
|
| 7287 |
+
dataset:
|
|
|
|
| 7288 |
name: MTEB STS22 (en)
|
|
|
|
|
|
|
| 7289 |
type: mteb/sts22-crosslingual-sts
|
| 7290 |
+
config: en
|
| 7291 |
+
split: test
|
| 7292 |
+
revision: de9d86b3b84231dc21f76c7b7af1f28e2f57f6e3
|
| 7293 |
metrics:
|
| 7294 |
- type: cosine_pearson
|
| 7295 |
value: 65.0979772547511
|
|
|
|
| 7309 |
value: 65.0979772547511
|
| 7310 |
- type: spearman
|
| 7311 |
value: 65.78126527764236
|
| 7312 |
+
- task:
|
| 7313 |
type: STS
|
| 7314 |
+
dataset:
|
|
|
|
| 7315 |
name: MTEB STSBenchmark (default)
|
|
|
|
|
|
|
| 7316 |
type: mteb/stsbenchmark-sts
|
| 7317 |
+
config: default
|
| 7318 |
+
split: test
|
| 7319 |
+
revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
|
| 7320 |
metrics:
|
| 7321 |
- type: cosine_pearson
|
| 7322 |
value: 85.6426635049971
|
|
|
|
| 7336 |
value: 85.6426635049971
|
| 7337 |
- type: spearman
|
| 7338 |
value: 85.609856578385
|
| 7339 |
+
- task:
|
| 7340 |
+
type: Reranking
|
| 7341 |
+
dataset:
|
|
|
|
| 7342 |
name: MTEB SciDocsRR (default)
|
|
|
|
|
|
|
| 7343 |
type: mteb/scidocs-reranking
|
| 7344 |
+
config: default
|
| 7345 |
+
split: test
|
| 7346 |
+
revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
|
| 7347 |
metrics:
|
| 7348 |
- type: main_score
|
| 7349 |
value: 82.85163332499799
|
|
|
|
| 7363 |
value: 89.47202967481866
|
| 7364 |
- type: nAUC_mrr_std
|
| 7365 |
value: 85.40446996933892
|
| 7366 |
+
- task:
|
| 7367 |
+
type: Retrieval
|
| 7368 |
+
dataset:
|
|
|
|
| 7369 |
name: MTEB SciFact (default)
|
|
|
|
|
|
|
| 7370 |
type: mteb/scifact
|
| 7371 |
+
config: default
|
| 7372 |
+
split: test
|
| 7373 |
+
revision: 0228b52cf27578f30900b9e5271d331663a030d7
|
| 7374 |
metrics:
|
| 7375 |
- type: main_score
|
| 7376 |
value: 71.655
|
|
|
|
| 7654 |
value: 71.61699999999999
|
| 7655 |
- type: recall_at_5
|
| 7656 |
value: 78.361
|
| 7657 |
+
- task:
|
| 7658 |
+
type: PairClassification
|
| 7659 |
+
dataset:
|
|
|
|
| 7660 |
name: MTEB SprintDuplicateQuestions (default)
|
|
|
|
|
|
|
| 7661 |
type: mteb/sprintduplicatequestions-pairclassification
|
| 7662 |
+
config: default
|
| 7663 |
+
split: test
|
| 7664 |
+
revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
|
| 7665 |
metrics:
|
| 7666 |
- type: cosine_accuracy
|
| 7667 |
value: 99.8019801980198
|
|
|
|
| 7745 |
value: 90.79754601226993
|
| 7746 |
- type: similarity_recall
|
| 7747 |
value: 88.8
|
| 7748 |
+
- task:
|
| 7749 |
+
type: Clustering
|
| 7750 |
+
dataset:
|
|
|
|
| 7751 |
name: MTEB StackExchangeClustering (default)
|
|
|
|
|
|
|
| 7752 |
type: mteb/stackexchange-clustering
|
| 7753 |
+
config: default
|
| 7754 |
+
split: test
|
| 7755 |
+
revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
|
| 7756 |
metrics:
|
| 7757 |
- type: main_score
|
| 7758 |
value: 66.63931197758824
|
|
|
|
| 7760 |
value: 66.63931197758824
|
| 7761 |
- type: v_measure_std
|
| 7762 |
value: 3.896206781511776
|
| 7763 |
+
- task:
|
| 7764 |
type: Clustering
|
| 7765 |
+
dataset:
|
|
|
|
| 7766 |
name: MTEB StackExchangeClusteringP2P (default)
|
|
|
|
|
|
|
| 7767 |
type: mteb/stackexchange-clustering-p2p
|
| 7768 |
+
config: default
|
| 7769 |
+
split: test
|
| 7770 |
+
revision: 815ca46b2622cec33ccafc3735d572c266efdb44
|
| 7771 |
metrics:
|
| 7772 |
- type: main_score
|
| 7773 |
value: 38.984892653301884
|
|
|
|
| 7775 |
value: 38.984892653301884
|
| 7776 |
- type: v_measure_std
|
| 7777 |
value: 1.3308552162270453
|
| 7778 |
+
- task:
|
| 7779 |
+
type: Reranking
|
| 7780 |
+
dataset:
|
|
|
|
| 7781 |
name: MTEB StackOverflowDupQuestions (default)
|
|
|
|
|
|
|
| 7782 |
type: mteb/stackoverflowdupquestions-reranking
|
| 7783 |
+
config: default
|
| 7784 |
+
split: test
|
| 7785 |
+
revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
|
| 7786 |
metrics:
|
| 7787 |
- type: main_score
|
| 7788 |
value: 52.71499643455044
|
|
|
|
| 7802 |
value: 13.931448578334379
|
| 7803 |
- type: nAUC_mrr_std
|
| 7804 |
value: 10.441860004959661
|
| 7805 |
+
- task:
|
| 7806 |
+
type: Summarization
|
| 7807 |
+
dataset:
|
|
|
|
| 7808 |
name: MTEB SummEval (default)
|
|
|
|
|
|
|
| 7809 |
type: mteb/summeval
|
| 7810 |
+
config: default
|
| 7811 |
+
split: test
|
| 7812 |
+
revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
|
| 7813 |
metrics:
|
| 7814 |
- type: cosine_pearson
|
| 7815 |
value: 31.5167525286909
|
|
|
|
| 7825 |
value: 31.5167525286909
|
| 7826 |
- type: spearman
|
| 7827 |
value: 31.218862970706496
|
| 7828 |
+
- task:
|
| 7829 |
+
type: Retrieval
|
| 7830 |
+
dataset:
|
|
|
|
| 7831 |
name: MTEB TRECCOVID (default)
|
|
|
|
|
|
|
| 7832 |
type: mteb/trec-covid
|
| 7833 |
+
config: default
|
| 7834 |
+
split: test
|
| 7835 |
+
revision: bb9466bac8153a0349341eb1b22e06409e78ef4e
|
| 7836 |
metrics:
|
| 7837 |
- type: main_score
|
| 7838 |
value: 78.996
|
|
|
|
| 8116 |
value: 0.705
|
| 8117 |
- type: recall_at_5
|
| 8118 |
value: 1.162
|
| 8119 |
+
- task:
|
| 8120 |
type: Retrieval
|
| 8121 |
+
dataset:
|
|
|
|
| 8122 |
name: MTEB Touche2020 (default)
|
|
|
|
|
|
|
| 8123 |
type: mteb/touche2020
|
| 8124 |
+
config: default
|
| 8125 |
+
split: test
|
| 8126 |
+
revision: a34f9a33db75fa0cbb21bb5cfc3dae8dc8bec93f
|
| 8127 |
metrics:
|
| 8128 |
- type: main_score
|
| 8129 |
value: 24.234
|
|
|
|
| 8407 |
value: 6.625
|
| 8408 |
- type: recall_at_5
|
| 8409 |
value: 9.094
|
| 8410 |
+
- task:
|
| 8411 |
+
type: Classification
|
| 8412 |
+
dataset:
|
|
|
|
| 8413 |
name: MTEB ToxicConversationsClassification (default)
|
|
|
|
|
|
|
| 8414 |
type: mteb/toxic_conversations_50k
|
| 8415 |
+
config: default
|
| 8416 |
+
split: test
|
| 8417 |
+
revision: edfaf9da55d3dd50d43143d90c1ac476895ae6de
|
| 8418 |
metrics:
|
| 8419 |
- type: accuracy
|
| 8420 |
value: 72.822265625
|
|
|
|
| 8428 |
value: 78.7454393727821
|
| 8429 |
- type: main_score
|
| 8430 |
value: 72.822265625
|
| 8431 |
+
- task:
|
| 8432 |
type: Classification
|
| 8433 |
+
dataset:
|
|
|
|
| 8434 |
name: MTEB TweetSentimentExtractionClassification (default)
|
|
|
|
|
|
|
| 8435 |
type: mteb/tweet_sentiment_extraction
|
| 8436 |
+
config: default
|
| 8437 |
+
split: test
|
| 8438 |
+
revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
|
| 8439 |
metrics:
|
| 8440 |
- type: accuracy
|
| 8441 |
value: 72.54385964912281
|
|
|
|
| 8445 |
value: 72.18022450339639
|
| 8446 |
- type: main_score
|
| 8447 |
value: 72.54385964912281
|
| 8448 |
+
- task:
|
| 8449 |
+
type: Clustering
|
| 8450 |
+
dataset:
|
|
|
|
| 8451 |
name: MTEB TwentyNewsgroupsClustering (default)
|
|
|
|
|
|
|
| 8452 |
type: mteb/twentynewsgroups-clustering
|
| 8453 |
+
config: default
|
| 8454 |
+
split: test
|
| 8455 |
+
revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
|
| 8456 |
metrics:
|
| 8457 |
- type: main_score
|
| 8458 |
value: 57.41861450414374
|
|
|
|
| 8460 |
value: 57.41861450414374
|
| 8461 |
- type: v_measure_std
|
| 8462 |
value: 1.1732394227153524
|
| 8463 |
+
- task:
|
| 8464 |
+
type: PairClassification
|
| 8465 |
+
dataset:
|
|
|
|
| 8466 |
name: MTEB TwitterSemEval2015 (default)
|
|
|
|
|
|
|
| 8467 |
type: mteb/twittersemeval2015-pairclassification
|
| 8468 |
+
config: default
|
| 8469 |
+
split: test
|
| 8470 |
+
revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
|
| 8471 |
metrics:
|
| 8472 |
- type: cosine_accuracy
|
| 8473 |
value: 85.65893783155511
|
|
|
|
| 8551 |
value: 64.0855106888361
|
| 8552 |
- type: similarity_recall
|
| 8553 |
value: 71.18733509234828
|
| 8554 |
+
- task:
|
| 8555 |
type: PairClassification
|
| 8556 |
+
dataset:
|
|
|
|
| 8557 |
name: MTEB TwitterURLCorpus (default)
|
|
|
|
|
|
|
| 8558 |
type: mteb/twitterurlcorpus-pairclassification
|
| 8559 |
+
config: default
|
| 8560 |
+
split: test
|
| 8561 |
+
revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
|
| 8562 |
metrics:
|
| 8563 |
- type: cosine_accuracy
|
| 8564 |
value: 88.86754375751931
|
|
|
|
| 8642 |
value: 74.19310344827586
|
| 8643 |
- type: similarity_recall
|
| 8644 |
value: 82.83030489682784
|
|
|
|
|
|
|
| 8645 |
---
|
| 8646 |
# Contextual Document Embeddings (CDE)
|
| 8647 |
|
config.json
CHANGED
|
@@ -1,8 +1,14 @@
|
|
| 1 |
{
|
|
|
|
| 2 |
"architecture": "transductive",
|
| 3 |
"architectures": [
|
| 4 |
-
"
|
| 5 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"biencoder_pooling_strategy": "mean",
|
| 7 |
"cache_dir": null,
|
| 8 |
"config_name": null,
|
|
|
|
| 1 |
{
|
| 2 |
+
"_name_or_path": "/fsx-checkpoints/jxm/cde/2024-09-18-supervised-final-bge--epoch-4/checkpoint-1820",
|
| 3 |
"architecture": "transductive",
|
| 4 |
"architectures": [
|
| 5 |
+
"DatasetTransformer"
|
| 6 |
],
|
| 7 |
+
"attn_implementation": null,
|
| 8 |
+
"auto_map": {
|
| 9 |
+
"AutoConfig": "misc.ContextualModelConfig",
|
| 10 |
+
"AutoModel": "model.DatasetTransformer"
|
| 11 |
+
},
|
| 12 |
"biencoder_pooling_strategy": "mean",
|
| 13 |
"cache_dir": null,
|
| 14 |
"config_name": null,
|
misc.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Iterable, List, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import functools
|
| 5 |
+
import glob
|
| 6 |
+
import json
|
| 7 |
+
import hashlib
|
| 8 |
+
import itertools
|
| 9 |
+
import logging
|
| 10 |
+
import multiprocessing
|
| 11 |
+
import os
|
| 12 |
+
import pickle
|
| 13 |
+
import random
|
| 14 |
+
import requests
|
| 15 |
+
import sys
|
| 16 |
+
import zipfile
|
| 17 |
+
|
| 18 |
+
import datasets
|
| 19 |
+
import numpy as np
|
| 20 |
+
import safetensors
|
| 21 |
+
import torch
|
| 22 |
+
import tqdm
|
| 23 |
+
import transformers
|
| 24 |
+
|
| 25 |
+
from cde.lib.dist import get_num_proc, get_rank
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_cde_cache_dir() -> str:
|
| 29 |
+
script_directory = os.path.normpath(
|
| 30 |
+
os.path.join(
|
| 31 |
+
os.path.dirname(os.path.abspath(__file__)),
|
| 32 |
+
os.pardir, os.pardir,
|
| 33 |
+
)
|
| 34 |
+
)
|
| 35 |
+
return os.path.join(script_directory, "data")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_cache_location_from_kwargs(**kwargs):
|
| 39 |
+
cache_location = os.path.join(
|
| 40 |
+
get_cde_cache_dir(), "cluster"
|
| 41 |
+
)
|
| 42 |
+
os.makedirs(cache_location, exist_ok=True)
|
| 43 |
+
return os.path.join(cache_location, md5_hash_kwargs(**kwargs))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
|
| 47 |
+
qrels_idxs = collections.defaultdict(list)
|
| 48 |
+
qrels_scores = collections.defaultdict(list)
|
| 49 |
+
corpus_ids = np.array(corpus['_id'])
|
| 50 |
+
skipped_qrels = 0
|
| 51 |
+
|
| 52 |
+
for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
|
| 53 |
+
#
|
| 54 |
+
# example:
|
| 55 |
+
# {
|
| 56 |
+
# 'query-id': 1,
|
| 57 |
+
# 'corpus-id': 'b0680508-2019-04-18T13:48:51Z-00002-000',
|
| 58 |
+
# 'score': 2
|
| 59 |
+
# }
|
| 60 |
+
#
|
| 61 |
+
q_id = str(ex['query-id'])
|
| 62 |
+
c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]
|
| 63 |
+
#
|
| 64 |
+
assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"
|
| 65 |
+
#
|
| 66 |
+
if len(c_idxs):
|
| 67 |
+
qrels_idxs[q_id].append(c_idxs[0])
|
| 68 |
+
qrels_scores[q_id].append(ex['score'])
|
| 69 |
+
else:
|
| 70 |
+
skipped_qrels += 1
|
| 71 |
+
#
|
| 72 |
+
|
| 73 |
+
if skipped_qrels > 0:
|
| 74 |
+
logging.warning(f'Warning: Skipped {skipped_qrels}/{len(qrels)} qrels.')
|
| 75 |
+
|
| 76 |
+
return qrels_idxs, qrels_scores
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def process_qrels(
|
| 80 |
+
corpus: datasets.Dataset, qrels: datasets.Dataset,
|
| 81 |
+
use_cache: bool = True
|
| 82 |
+
) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
|
| 83 |
+
dataset_cache_file = '_'.join(
|
| 84 |
+
(corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
|
| 85 |
+
)
|
| 86 |
+
cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
|
| 87 |
+
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
|
| 88 |
+
|
| 89 |
+
if not (use_cache and os.path.exists(cache_file)):
|
| 90 |
+
qrels_idxs, qrels_scores = process_qrels_uncached(
|
| 91 |
+
corpus=corpus, qrels=qrels
|
| 92 |
+
)
|
| 93 |
+
if use_cache:
|
| 94 |
+
pickle.dump((qrels_idxs, qrels_scores), open(cache_file, 'wb'))
|
| 95 |
+
else:
|
| 96 |
+
qrels_idxs, qrels_scores = pickle.load(open(cache_file, 'rb'))
|
| 97 |
+
|
| 98 |
+
return qrels_idxs, qrels_scores
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def strip_extension(filename: str) -> str:
|
| 102 |
+
"""Strips file extension.
|
| 103 |
+
|
| 104 |
+
Ex:
|
| 105 |
+
>> strip_extension('/root/dir/sub/file.ext')
|
| 106 |
+
'/root/dir/sub/file'
|
| 107 |
+
"""
|
| 108 |
+
return os.path.splitext(filename)[0]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def md5_hash(t: Tuple[str]) -> str:
|
| 112 |
+
return hashlib.md5('__'.join(t).encode()).hexdigest()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def md5_hash_kwargs(**kwargs) -> str:
|
| 116 |
+
# We ignore special hf args that start with _ like '__cached__setup_devices'.
|
| 117 |
+
safe_kwargs = {k: str(v) for k,v in kwargs.items() if not k.startswith('_')}
|
| 118 |
+
s = json.dumps(safe_kwargs, sort_keys=True)
|
| 119 |
+
return hashlib.md5(s.encode()).hexdigest()
|
| 120 |
+
|
| 121 |
+
def download_url(url: str, save_path: str, chunk_size: int = 1024):
|
| 122 |
+
"""Download url with progress bar using tqdm
|
| 123 |
+
https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
|
| 124 |
+
Args:
|
| 125 |
+
url (str): downloadable url
|
| 126 |
+
save_path (str): local path to save the downloaded file
|
| 127 |
+
chunk_size (int, optional): chunking of files. Defaults to 1024.
|
| 128 |
+
"""
|
| 129 |
+
r = requests.get(url, stream=True)
|
| 130 |
+
total = int(r.headers.get('Content-Length', 0))
|
| 131 |
+
with open(save_path, 'wb') as fd, tqdm.tqdm(
|
| 132 |
+
desc=save_path,
|
| 133 |
+
total=total,
|
| 134 |
+
unit='iB',
|
| 135 |
+
unit_scale=True,
|
| 136 |
+
unit_divisor=chunk_size,
|
| 137 |
+
) as bar:
|
| 138 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
| 139 |
+
size = fd.write(data)
|
| 140 |
+
bar.update(size)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def unzip(zip_file: str, out_dir: str):
|
| 144 |
+
print("unzipping =>", zip_file)
|
| 145 |
+
zip_ = zipfile.ZipFile(zip_file, "r")
|
| 146 |
+
zip_.extractall(path=out_dir)
|
| 147 |
+
zip_.close()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
|
| 151 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 152 |
+
dataset = url.split("/")[-1]
|
| 153 |
+
zip_file = os.path.join(out_dir, dataset)
|
| 154 |
+
|
| 155 |
+
if not os.path.isfile(zip_file):
|
| 156 |
+
logging.info("Downloading {} ...".format(dataset))
|
| 157 |
+
download_url(url, zip_file, chunk_size)
|
| 158 |
+
|
| 159 |
+
if not os.path.isdir(zip_file.replace(".zip", "")):
|
| 160 |
+
logging.info("Unzipping {} ...".format(dataset))
|
| 161 |
+
unzip(zip_file, out_dir)
|
| 162 |
+
|
| 163 |
+
return os.path.join(out_dir, dataset.replace(".zip", ""))
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
|
| 167 |
+
if get_rank() == 0:
|
| 168 |
+
return tqdm.tqdm(iterable, **kwargs)
|
| 169 |
+
else:
|
| 170 |
+
return iterable
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
|
| 174 |
+
"""We create a dummy configuration class that will just set properties
|
| 175 |
+
based on whatever kwargs we pass in.
|
| 176 |
+
|
| 177 |
+
When this class is initialized (see experiments.py) we pass in the
|
| 178 |
+
union of all data, model, and training args, all of which should
|
| 179 |
+
get saved to the config json.
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(self, **kwargs):
|
| 183 |
+
for key, value in kwargs.items():
|
| 184 |
+
try:
|
| 185 |
+
json.dumps(value)
|
| 186 |
+
setattr(self, key, value)
|
| 187 |
+
except TypeError:
|
| 188 |
+
# value was not JSON-serializable, skip
|
| 189 |
+
continue
|
| 190 |
+
super().__init__()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def independent_crop(
|
| 194 |
+
input_ids: torch.Tensor, pad_token_id: int,
|
| 195 |
+
l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 196 |
+
"""Returns two independent crops from input_ids.
|
| 197 |
+
|
| 198 |
+
Assumes input_ids has a beginning and end token, like
|
| 199 |
+
[101, ..., 102, 0, 0, 0].
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
input_ids: tensor of IDs
|
| 203 |
+
pad_token_id: ID of pad tokens in input_ids
|
| 204 |
+
l1: length of span 1, cropped
|
| 205 |
+
l2: length of span 2, cropped
|
| 206 |
+
Returns:
|
| 207 |
+
span1: first crop (of length l1)
|
| 208 |
+
span2: second crop (of length l2)
|
| 209 |
+
"""
|
| 210 |
+
# Count tokens until pad.
|
| 211 |
+
if (input_ids == pad_token_id).sum() == 0:
|
| 212 |
+
N = len(input_ids)
|
| 213 |
+
else:
|
| 214 |
+
N = (input_ids == pad_token_id).int().argmax().item()
|
| 215 |
+
|
| 216 |
+
####
|
| 217 |
+
###
|
| 218 |
+
##
|
| 219 |
+
## Contriever: We use the random cropping data
|
| 220 |
+
## augmentation, with documents of 256 tokens and span
|
| 221 |
+
## sizes sampled between 5% and 50% of the document
|
| 222 |
+
## length
|
| 223 |
+
##
|
| 224 |
+
###
|
| 225 |
+
#####
|
| 226 |
+
####### LaPraDor: The maximum lengths set for queries and
|
| 227 |
+
####### documents are 64 and 350...
|
| 228 |
+
#####
|
| 229 |
+
# TODO is this divide-by-two a good idea? (Don't want s1=s2 ever..)
|
| 230 |
+
nl1 = min(N//2, l1)
|
| 231 |
+
nl2 = min(N//2, l2)
|
| 232 |
+
|
| 233 |
+
s1_start = random.randint(1, N-nl1)
|
| 234 |
+
s2_start = random.randint(1, N-nl2)
|
| 235 |
+
|
| 236 |
+
s1_idxs = itertools.chain(
|
| 237 |
+
[0], range(s1_start, s1_start+nl1), [N-1]
|
| 238 |
+
)
|
| 239 |
+
s1 = input_ids[torch.tensor(list(s1_idxs))]
|
| 240 |
+
s2_idxs = itertools.chain(
|
| 241 |
+
[0], range(s2_start, s2_start+nl2), [N-1]
|
| 242 |
+
)
|
| 243 |
+
s2 = input_ids[torch.tensor(list(s2_idxs))]
|
| 244 |
+
return (s1, s2)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def load_dataset_tables(
|
| 248 |
+
files: Iterable[str], num_workers: int = 16
|
| 249 |
+
) -> Iterable[datasets.table.MemoryMappedTable]:
|
| 250 |
+
import concurrent
|
| 251 |
+
from multiprocessing import Pool
|
| 252 |
+
|
| 253 |
+
# num_workers = min(num_workers, len(files))
|
| 254 |
+
num_workers = min(32, len(files))
|
| 255 |
+
|
| 256 |
+
use_threads = True
|
| 257 |
+
if use_threads:
|
| 258 |
+
pool_cls = concurrent.futures.ThreadPoolExecutor
|
| 259 |
+
pool_kwargs = {"max_workers": num_workers}
|
| 260 |
+
else:
|
| 261 |
+
pool_cls = Pool
|
| 262 |
+
pool_kwargs = {"processes": num_workers}
|
| 263 |
+
|
| 264 |
+
with pool_cls(**pool_kwargs) as pool:
|
| 265 |
+
if len(files) > 10:
|
| 266 |
+
files = tqdm_if_main_worker(
|
| 267 |
+
files,
|
| 268 |
+
desc=f"Loading {len(files)} files with {num_workers} workers",
|
| 269 |
+
total=len(files),
|
| 270 |
+
colour="#ffbd88"
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
result = list(
|
| 274 |
+
pool.map(datasets.table.MemoryMappedTable.from_file, files)
|
| 275 |
+
)
|
| 276 |
+
return result
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
|
| 280 |
+
logging.info(f"fast_load_from_disk called with path:", cache_path)
|
| 281 |
+
dataset_info_path = os.path.join(cache_path, "dataset_info.json")
|
| 282 |
+
with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
|
| 283 |
+
dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))
|
| 284 |
+
|
| 285 |
+
dataset_state_path = os.path.join(cache_path, "state.json")
|
| 286 |
+
with open(dataset_state_path, encoding="utf-8") as state_file:
|
| 287 |
+
state = json.load(state_file)
|
| 288 |
+
|
| 289 |
+
files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
|
| 290 |
+
files = sorted(files)
|
| 291 |
+
num_workers = get_num_proc()
|
| 292 |
+
ds_tables = load_dataset_tables(
|
| 293 |
+
files=files,
|
| 294 |
+
num_workers=num_workers
|
| 295 |
+
)
|
| 296 |
+
arrow_table = datasets.table.concat_tables(ds_tables)
|
| 297 |
+
|
| 298 |
+
split = state["_split"]
|
| 299 |
+
split = datasets.splits.Split(split) if split is not None else split
|
| 300 |
+
|
| 301 |
+
# print("returning dataset")
|
| 302 |
+
return datasets.Dataset(
|
| 303 |
+
arrow_table=arrow_table,
|
| 304 |
+
info=dataset_info,
|
| 305 |
+
split=split,
|
| 306 |
+
fingerprint=state["_fingerprint"],
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def tokenize_dataset(
|
| 311 |
+
dataset: datasets.Dataset,
|
| 312 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 313 |
+
max_length: int,
|
| 314 |
+
text_key: str,
|
| 315 |
+
padding_strategy: str
|
| 316 |
+
) -> datasets.Dataset:
|
| 317 |
+
def tokenize_text(ex: Dict) -> Dict:
|
| 318 |
+
tt = tokenizer(
|
| 319 |
+
ex[text_key],
|
| 320 |
+
max_length=max_length,
|
| 321 |
+
truncation=True,
|
| 322 |
+
padding=padding_strategy,
|
| 323 |
+
)
|
| 324 |
+
for k,v in tt.items():
|
| 325 |
+
ex[f"{text_key}_{k}"] = v
|
| 326 |
+
ex["length"] = [len(tt) for tt in ex[f"{text_key}_input_ids"]]
|
| 327 |
+
return ex
|
| 328 |
+
|
| 329 |
+
# generate unique hash for tokenizer
|
| 330 |
+
vocab = tokenizer.vocab
|
| 331 |
+
vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
|
| 332 |
+
vocab_hash = md5_hash(vocab_words)
|
| 333 |
+
|
| 334 |
+
data_fingerprint = '__'.join((
|
| 335 |
+
dataset._fingerprint, str(vocab_hash), str(max_length),
|
| 336 |
+
text_key, padding_strategy
|
| 337 |
+
))
|
| 338 |
+
data_fingerprint = md5_hash(data_fingerprint)
|
| 339 |
+
dataset = dataset.map(
|
| 340 |
+
tokenize_text,
|
| 341 |
+
new_fingerprint=data_fingerprint,
|
| 342 |
+
batched=True,
|
| 343 |
+
load_from_cache_file=True,
|
| 344 |
+
)
|
| 345 |
+
return dataset
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class TensorRunningAverages:
|
| 349 |
+
_store_sum: Dict[str, torch.Tensor]
|
| 350 |
+
_store_total: Dict[str, torch.Tensor]
|
| 351 |
+
|
| 352 |
+
def __init__(self):
|
| 353 |
+
self._store_sum = {}
|
| 354 |
+
self._store_total = {}
|
| 355 |
+
|
| 356 |
+
def __iter__(self) -> Iterable[str]:
|
| 357 |
+
return iter(self._store_sum.keys())
|
| 358 |
+
|
| 359 |
+
def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
|
| 360 |
+
if key not in self._store_sum:
|
| 361 |
+
self.clear(key)
|
| 362 |
+
if isinstance(val, torch.Tensor):
|
| 363 |
+
val = val.item() # tensor -> num
|
| 364 |
+
self._store_sum[key] += val
|
| 365 |
+
self._store_total[key] += 1
|
| 366 |
+
|
| 367 |
+
def get(self, key: str) -> float:
|
| 368 |
+
total = max(self._store_total.get(key).item(), 1.0)
|
| 369 |
+
return (self._store_sum[key] / float(total)).item() or 0.0
|
| 370 |
+
|
| 371 |
+
def clear(self, key: str) -> None:
|
| 372 |
+
self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
|
| 373 |
+
self._store_total[key] = torch.tensor(0, dtype=torch.int32)
|
| 374 |
+
|
| 375 |
+
def clear_all(self) -> None:
|
| 376 |
+
for key in self._store_sum:
|
| 377 |
+
self.clear(key)
|
| 378 |
+
|
| 379 |
+
def get_and_clear_all(self) -> Dict[str, float]:
|
| 380 |
+
metrics = {}
|
| 381 |
+
for key in self:
|
| 382 |
+
metrics[key] = self.get(key)
|
| 383 |
+
self.clear(key)
|
| 384 |
+
return metrics
|
| 385 |
+
|
| 386 |
+
def load_embedder_and_tokenizer(name: str) -> Tuple[
|
| 387 |
+
transformers.PreTrainedModel,
|
| 388 |
+
transformers.PreTrainedTokenizer
|
| 389 |
+
]:
|
| 390 |
+
if name.startswith("nomic") or (name == "bert-base-uncased"):
|
| 391 |
+
from cde.lib.nomic_bert import NomicBertModel
|
| 392 |
+
if name.endswith("--from-scratch"):
|
| 393 |
+
name = name.replace("--from-scratch", "")
|
| 394 |
+
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
|
| 395 |
+
model = NomicBertModel._from_config(config)
|
| 396 |
+
else:
|
| 397 |
+
model = NomicBertModel.from_pretrained(
|
| 398 |
+
name, add_pooling_layer=False
|
| 399 |
+
)
|
| 400 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
|
| 401 |
+
elif name in ["gtr-base", "gtr_base"]:
|
| 402 |
+
model = transformers.AutoModel.from_pretrained(
|
| 403 |
+
"sentence-transformers/gtr-t5-base"
|
| 404 |
+
).encoder
|
| 405 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 406 |
+
"sentence-transformers/gtr-t5-base"
|
| 407 |
+
)
|
| 408 |
+
elif name == "pile-t5-base-encoder":
|
| 409 |
+
model = transformers.AutoModel.from_pretrained(
|
| 410 |
+
"EleutherAI/pile-t5-base"
|
| 411 |
+
).encoder
|
| 412 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 413 |
+
"EleutherAI/pile-t5-base"
|
| 414 |
+
)
|
| 415 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 416 |
+
elif name == "pile-t5-base-decoder":
|
| 417 |
+
model = transformers.AutoModel.from_pretrained(
|
| 418 |
+
"EleutherAI/pile-t5-base"
|
| 419 |
+
).decoder
|
| 420 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 421 |
+
"EleutherAI/pile-t5-base"
|
| 422 |
+
)
|
| 423 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 424 |
+
elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
|
| 425 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 426 |
+
name,
|
| 427 |
+
# torch_dtype=torch.bfloat16,
|
| 428 |
+
attn_implementation="flash_attention_2",
|
| 429 |
+
low_cpu_mem_usage=True,
|
| 430 |
+
# device_map="auto",
|
| 431 |
+
)
|
| 432 |
+
model.padding_side = "right"
|
| 433 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
|
| 434 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 435 |
+
tokenizer.add_eos_token = True
|
| 436 |
+
else:
|
| 437 |
+
model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
|
| 438 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
|
| 439 |
+
|
| 440 |
+
# if use_bettertransformer:
|
| 441 |
+
# from optimum.bettertransformer import BetterTransformer
|
| 442 |
+
# model = BetterTransformer.transform(model)
|
| 443 |
+
return model, tokenizer
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str):
|
| 447 |
+
key += "_"
|
| 448 |
+
return {k.replace(key, ""): v for k,v in inputs.items() if k.startswith(key)}
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def load_model_state_dict_from_path(folder: str) -> Dict:
|
| 452 |
+
checkpoint_folder = transformers.trainer_utils.get_last_checkpoint(folder)
|
| 453 |
+
if checkpoint_folder is None:
|
| 454 |
+
raise FileNotFoundError(f"no checkpoint found in {folder}")
|
| 455 |
+
WEIGHTS_NAME = "model.safetensors"
|
| 456 |
+
weights_path = os.path.join(checkpoint_folder, WEIGHTS_NAME)
|
| 457 |
+
if not os.path.exists(weights_path):
|
| 458 |
+
raise FileNotFoundError(f"no model weights found at {weights_path}")
|
| 459 |
+
return safetensors.torch.load_file(weights_path, device="cpu")
|
| 460 |
+
|
| 461 |
+
def count_cpus() -> int:
|
| 462 |
+
try:
|
| 463 |
+
return len(os.sched_getaffinity(0))
|
| 464 |
+
except AttributeError:
|
| 465 |
+
return multiprocessing.cpu_count()
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
|
| 469 |
+
all_indices = []
|
| 470 |
+
for batch_tensor in tqdm_if_main_worker(list_of_tensors, colour="green", desc="Sampler shuffling per-batch"):
|
| 471 |
+
rand_perm = torch.randperm(len(batch_tensor), generator=g)
|
| 472 |
+
batch_list = batch_tensor[rand_perm].tolist()
|
| 473 |
+
all_indices.extend(batch_list)
|
| 474 |
+
return all_indices
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
# def shuffle_batches_multiproc(g: torch.Generator, list_of_tensors: List[torch.Tensor], num_processes: int = 8) -> List[int]:
|
| 478 |
+
# all_indices = []
|
| 479 |
+
# print(f"Shuffling {len(list_of_tensors)} tensors with {num_processes} workers.")
|
| 480 |
+
# pbar = tqdm_if_main_worker(list_of_tensors, colour="orange", desc=f"Sampler shuffling per-batch (nproc={num_processes})")
|
| 481 |
+
# pool = multiprocessing.Pool(processes=num_processes)
|
| 482 |
+
# chunk_size = len(list_of_tensors) // num_processes
|
| 483 |
+
# chunks = [list_of_tensors[i:i + chunk_size] for i in range(0, len(list_of_tensors), chunk_size)]
|
| 484 |
+
# worker_func = functools.partial(shuffle_batches, g=g)
|
| 485 |
+
# results = pool.map(worker_func, chunks)
|
| 486 |
+
# all_indices = []
|
| 487 |
+
# for result in results:
|
| 488 |
+
# all_indices.extend(result)
|
| 489 |
+
# pbar.update()
|
| 490 |
+
# return all_indices
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def exit_if_running_or_finished_wandb(
|
| 494 |
+
project_name: str,
|
| 495 |
+
exp_group: str, exp_name: str
|
| 496 |
+
) -> None:
|
| 497 |
+
print("Checking if experiment is already running...")
|
| 498 |
+
import wandb
|
| 499 |
+
|
| 500 |
+
api = wandb.Api()
|
| 501 |
+
running_runs = api.runs(
|
| 502 |
+
path="tti-nomic-7",
|
| 503 |
+
filters={
|
| 504 |
+
"display_name": exp_name,
|
| 505 |
+
"state": {"$regex": "Running|Finished"},
|
| 506 |
+
"config.exp_group": exp_group,
|
| 507 |
+
}
|
| 508 |
+
)
|
| 509 |
+
print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")
|
| 510 |
+
|
| 511 |
+
if len(running_runs) > 0:
|
| 512 |
+
print("Exiting because experiment is already running or completed.")
|
| 513 |
+
sys.exit(0)
|
| 514 |
+
|
model.py
ADDED
|
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Union
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import transformers
|
| 7 |
+
|
| 8 |
+
from cde.lib.dist import print0
|
| 9 |
+
from cde.lib.tensor import mean_pool, mean_pool_3d, mean_pool_weighted, last_token_pool
|
| 10 |
+
|
| 11 |
+
from cde.lib import load_embedder_and_tokenizer, ContextualModelConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
|
| 15 |
+
if hasattr(model, 'transformer'):
|
| 16 |
+
if hasattr(model.transformer, 'h'):
|
| 17 |
+
# gpt2
|
| 18 |
+
model.transformer.h = model.transformer.h[:n_layers]
|
| 19 |
+
else:
|
| 20 |
+
model.transformer.layer = model.transformer.layer[:n_layers]
|
| 21 |
+
elif hasattr(model, 'encoder'):
|
| 22 |
+
if hasattr(model.encoder, 'layers'):
|
| 23 |
+
model.encoder.layers = model.encoder.layers[:n_layers]
|
| 24 |
+
else:
|
| 25 |
+
model.encoder.layer = model.encoder.layer[:n_layers]
|
| 26 |
+
else:
|
| 27 |
+
raise RuntimeError(f"unknown how to limit layers of model {type(model)}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def disable_dropout(model: torch.nn.Module):
|
| 31 |
+
dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
|
| 32 |
+
for m in dropout_modules:
|
| 33 |
+
m.p = 0.0
|
| 34 |
+
print0(
|
| 35 |
+
f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def disable_causality(model: torch.nn.Module):
|
| 40 |
+
disabled_modules = 0
|
| 41 |
+
for m in model.modules():
|
| 42 |
+
if hasattr(m, "is_causal"):
|
| 43 |
+
m.is_causal = False
|
| 44 |
+
disabled_modules += 1
|
| 45 |
+
print0(
|
| 46 |
+
f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
class ContextualModelMixin(nn.Module):
|
| 50 |
+
@property
|
| 51 |
+
def num_corpus_tokens(self) -> int:
|
| 52 |
+
return self.transductive_corpus_size * self.transductive_tokens_per_document
|
| 53 |
+
|
| 54 |
+
def contextual_init(self):
|
| 55 |
+
self.n_soft_prompt = 8
|
| 56 |
+
self.prompt_projection = torch.nn.Sequential(
|
| 57 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
| 58 |
+
torch.nn.ReLU(),
|
| 59 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
|
| 60 |
+
)
|
| 61 |
+
self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
|
| 62 |
+
self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
|
| 63 |
+
self.randomize_dataset_sequence_order = True
|
| 64 |
+
self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
|
| 65 |
+
if self.sequence_dropout_prob > 0.0:
|
| 66 |
+
self.sequence_dropout_null_embedding = torch.nn.Parameter(
|
| 67 |
+
torch.randn(self.hidden_size) * 0.01,
|
| 68 |
+
requires_grad = True
|
| 69 |
+
)
|
| 70 |
+
self.output_projection = torch.nn.Sequential(
|
| 71 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
| 72 |
+
torch.nn.ReLU(),
|
| 73 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def _prepare_dataset_embeddings(
|
| 77 |
+
self,
|
| 78 |
+
input_ids: torch.Tensor, dataset_embeddings: torch.Tensor,
|
| 79 |
+
null_dataset_embedding: bool = False,
|
| 80 |
+
) -> torch.Tensor:
|
| 81 |
+
if not isinstance(dataset_embeddings, torch.Tensor):
|
| 82 |
+
dataset_embeddings = torch.tensor(dataset_embeddings)
|
| 83 |
+
|
| 84 |
+
if len(dataset_embeddings.shape) == 2:
|
| 85 |
+
# Auto-expand for a batch.
|
| 86 |
+
dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d)
|
| 87 |
+
dataset_embeddings = dataset_embeddings.to(input_ids.device)
|
| 88 |
+
|
| 89 |
+
batch_size = input_ids.shape[0]
|
| 90 |
+
if (self.transductive_tokens_per_document > 1):
|
| 91 |
+
if self.training:
|
| 92 |
+
# Choose N random documents to fill our context window with.
|
| 93 |
+
# This logic is a little confusing but allows us to sample a
|
| 94 |
+
# different batch *per-document*
|
| 95 |
+
assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
|
| 96 |
+
R = torch.randint(
|
| 97 |
+
low=0,
|
| 98 |
+
high=len(dataset_embeddings),
|
| 99 |
+
size=(batch_size, self.config.transductive_corpus_size),
|
| 100 |
+
device=dataset_embeddings.device
|
| 101 |
+
)
|
| 102 |
+
# TODO make this deterministic somehow for evaluation?
|
| 103 |
+
dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
|
| 104 |
+
else:
|
| 105 |
+
dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
|
| 106 |
+
# print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)
|
| 107 |
+
|
| 108 |
+
if dataset_embeddings.shape[1] > self.num_corpus_tokens:
|
| 109 |
+
# If too many dataset embeddings are passed in, just take the first N until
|
| 110 |
+
# we have the proper number.
|
| 111 |
+
dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]
|
| 112 |
+
|
| 113 |
+
_, corpus_size, _hidden_size = dataset_embeddings.shape
|
| 114 |
+
if _ == 1:
|
| 115 |
+
# Auto-expand for a batch.
|
| 116 |
+
dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))
|
| 117 |
+
|
| 118 |
+
if self.training and self.sequence_dropout_prob > 0.0:
|
| 119 |
+
sequence_dropout_mask = (
|
| 120 |
+
torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
|
| 121 |
+
)
|
| 122 |
+
null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
|
| 123 |
+
dataset_embeddings = torch.where(
|
| 124 |
+
sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
|
| 125 |
+
)
|
| 126 |
+
elif null_dataset_embedding:
|
| 127 |
+
null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
|
| 128 |
+
dataset_embeddings = null_embeddings
|
| 129 |
+
|
| 130 |
+
# print(f"[ContextualModelMixin] dataset_embeddings.shape = {dataset_embeddings.shape}")
|
| 131 |
+
|
| 132 |
+
# backbone_max_seq_length = self.backbone.config.max_trained_positions
|
| 133 |
+
# assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
|
| 134 |
+
soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
|
| 135 |
+
soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
|
| 136 |
+
soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1)) # -> (b, 4+b, d) # soft_prompt.repeat((len(input_ids), 1, 1))
|
| 137 |
+
soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)
|
| 138 |
+
|
| 139 |
+
# print(f"[ContextualModelMixin] soft_prompt.shape = {soft_prompt.shape}")
|
| 140 |
+
|
| 141 |
+
if self.training and self.randomize_dataset_sequence_order:
|
| 142 |
+
randomized_order = torch.stack(
|
| 143 |
+
[
|
| 144 |
+
torch.cat(
|
| 145 |
+
(
|
| 146 |
+
torch.randperm(corpus_size, device=soft_prompt.device),
|
| 147 |
+
torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
|
| 148 |
+
), dim=0)
|
| 149 |
+
for _ in range(batch_size)])
|
| 150 |
+
randomized_order = randomized_order.to(soft_prompt.device)
|
| 151 |
+
soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))
|
| 152 |
+
|
| 153 |
+
return soft_prompt
|
| 154 |
+
|
| 155 |
+
class BiEncoder(transformers.PreTrainedModel):
|
| 156 |
+
embedder: transformers.PreTrainedModel
|
| 157 |
+
def __init__(
|
| 158 |
+
self,
|
| 159 |
+
config, #: transformers.PreTrainedConfig,
|
| 160 |
+
):
|
| 161 |
+
super().__init__(config=config)
|
| 162 |
+
embedder, _ = load_embedder_and_tokenizer(
|
| 163 |
+
config.embedder,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
if config.limit_layers:
|
| 167 |
+
print0(f"Limiting layers to {config.limit_layers}")
|
| 168 |
+
limit_layers(embedder, config.limit_layers)
|
| 169 |
+
|
| 170 |
+
self.embedder = embedder
|
| 171 |
+
# if ("t5" in embedder.config.model_type):
|
| 172 |
+
# print0(f"using torch.compile() on embedder of type `{embedder.config.model_type}`")
|
| 173 |
+
# self.embedder = torch.compile(self.embedder)
|
| 174 |
+
self.hidden_size = self.embedder.config.hidden_size
|
| 175 |
+
# Allow pooling to multiple tokens per document
|
| 176 |
+
self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
|
| 177 |
+
self.mlp = torch.nn.Sequential(
|
| 178 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
| 179 |
+
torch.nn.GELU(),
|
| 180 |
+
torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
|
| 181 |
+
)
|
| 182 |
+
self.temp = config.logit_scale
|
| 183 |
+
|
| 184 |
+
if config.disable_dropout:
|
| 185 |
+
disable_dropout(self)
|
| 186 |
+
self.pooling_strategy = vars(config).get("pooling_strategy", "mean")
|
| 187 |
+
|
| 188 |
+
def forward(
|
| 189 |
+
self,
|
| 190 |
+
input_ids: torch.Tensor,
|
| 191 |
+
attention_mask: torch.Tensor,
|
| 192 |
+
dataset_input_ids: Optional[torch.Tensor] = None,
|
| 193 |
+
dataset_attention_mask: Optional[torch.Tensor] = None,
|
| 194 |
+
token_type_ids = None,
|
| 195 |
+
output_hidden_states: bool = False,
|
| 196 |
+
) -> torch.Tensor:
|
| 197 |
+
"""
|
| 198 |
+
query_embedding (float torch.Tensor) - shape (batch_size, embedding_dim)
|
| 199 |
+
document_embeddings (float torch.Tensor) - shape (corpus_size, embedding_dim)
|
| 200 |
+
where the corpus_size >= batch_size and is structured like this:
|
| 201 |
+
[d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
|
| 202 |
+
for a corpus with three documents and two hard negatives per document
|
| 203 |
+
"""
|
| 204 |
+
# del dataset_input_ids
|
| 205 |
+
# del dataset_attention_mask
|
| 206 |
+
del token_type_ids
|
| 207 |
+
|
| 208 |
+
# from cde.lib.dist import get_rank
|
| 209 |
+
# tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 210 |
+
# if get_rank() == 0:
|
| 211 |
+
# breakpoint()
|
| 212 |
+
# torch.distributed.barrier()
|
| 213 |
+
outputs = (
|
| 214 |
+
self.embedder(
|
| 215 |
+
input_ids=input_ids,
|
| 216 |
+
attention_mask=attention_mask,
|
| 217 |
+
).last_hidden_state
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
if self.transductive_tokens_per_document > 1:
|
| 221 |
+
document_embeddings = None
|
| 222 |
+
batch_size, seq_length, output_dim = outputs.shape
|
| 223 |
+
|
| 224 |
+
if seq_length % self.transductive_tokens_per_document != 0:
|
| 225 |
+
# Pad to nearest multiple
|
| 226 |
+
n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
|
| 227 |
+
outputs = torch.cat(
|
| 228 |
+
(outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device)),
|
| 229 |
+
dim=1
|
| 230 |
+
)
|
| 231 |
+
attention_mask = torch.cat(
|
| 232 |
+
(attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device)),
|
| 233 |
+
dim=1
|
| 234 |
+
)
|
| 235 |
+
seq_length += n_extra_embeds
|
| 236 |
+
print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")
|
| 237 |
+
|
| 238 |
+
# print("ftransductive_tokens_per_document {self.transductive_tokens_per_document} outputs.shape =", outputs.shape)
|
| 239 |
+
|
| 240 |
+
outputs = outputs.reshape(
|
| 241 |
+
(batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
|
| 245 |
+
document_embeddings = mean_pool_3d(outputs, attention_mask)
|
| 246 |
+
|
| 247 |
+
document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
|
| 248 |
+
else:
|
| 249 |
+
if self.pooling_strategy == "mean":
|
| 250 |
+
document_embeddings = mean_pool(outputs, attention_mask)
|
| 251 |
+
else:
|
| 252 |
+
document_embeddings = document_embeddings.max(dim=1)
|
| 253 |
+
output = self.mlp(document_embeddings)
|
| 254 |
+
|
| 255 |
+
if output_hidden_states:
|
| 256 |
+
return {
|
| 257 |
+
"hidden_states": outputs,
|
| 258 |
+
"pooled": output,
|
| 259 |
+
}
|
| 260 |
+
else:
|
| 261 |
+
return output
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
|
| 265 |
+
def __init__(
|
| 266 |
+
self,
|
| 267 |
+
config,
|
| 268 |
+
dataset_backbone: transformers.PreTrainedModel,
|
| 269 |
+
first_stage_hidden_size: int,
|
| 270 |
+
):
|
| 271 |
+
super().__init__(config=config)
|
| 272 |
+
self.backbone = dataset_backbone
|
| 273 |
+
self.backbone_hidden_size = self.backbone.config.hidden_size
|
| 274 |
+
self.hidden_size = first_stage_hidden_size # Input token size
|
| 275 |
+
self.contextual_init()
|
| 276 |
+
disable_causality(self.backbone)
|
| 277 |
+
|
| 278 |
+
self.input_ln = torch.nn.LayerNorm(
|
| 279 |
+
self.backbone_hidden_size,
|
| 280 |
+
eps=1e-5
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Override contextual init
|
| 284 |
+
self.output_projection = torch.nn.Sequential(
|
| 285 |
+
torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
|
| 286 |
+
torch.nn.ReLU(),
|
| 287 |
+
torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
|
| 288 |
+
)
|
| 289 |
+
self._shift_rotary_embedding()
|
| 290 |
+
|
| 291 |
+
@property
|
| 292 |
+
def num_corpus_tokens(self) -> int:
|
| 293 |
+
return self.config.transductive_corpus_size * self.transductive_tokens_per_document
|
| 294 |
+
|
| 295 |
+
@property
|
| 296 |
+
def corpus_token_ratio(self) -> float:
|
| 297 |
+
# How many tokens from the first stage make one token in the second
|
| 298 |
+
# stage?
|
| 299 |
+
return self.backbone_hidden_size / self.hidden_size
|
| 300 |
+
|
| 301 |
+
def corpus_token_pad_size(self, n_tokens: int) -> int:
|
| 302 |
+
return self.hidden_size % self.backbone_hidden_size
|
| 303 |
+
|
| 304 |
+
def _shift_rotary_embedding(self) -> None:
|
| 305 |
+
disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
|
| 306 |
+
# TODO: Can we do this for LLAMA?
|
| 307 |
+
print("Warning: Positional embedding disabling not implemented for LLAMA.")
|
| 308 |
+
|
| 309 |
+
def forward(
|
| 310 |
+
self,
|
| 311 |
+
input_ids: torch.Tensor,
|
| 312 |
+
attention_mask: torch.Tensor,
|
| 313 |
+
dataset_embeddings: torch.Tensor,
|
| 314 |
+
output_hidden_states: bool = False,
|
| 315 |
+
null_dataset_embedding: bool = False,
|
| 316 |
+
) -> torch.Tensor:
|
| 317 |
+
soft_prompt = self._prepare_dataset_embeddings(
|
| 318 |
+
input_ids=input_ids,
|
| 319 |
+
dataset_embeddings=dataset_embeddings,
|
| 320 |
+
null_dataset_embedding=null_dataset_embedding,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Reshape for this model.
|
| 324 |
+
# print("[DatasetConditionedAutoregressive] 1 -> soft_prompt.shape =", soft_prompt.shape)
|
| 325 |
+
num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item()
|
| 326 |
+
soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements))
|
| 327 |
+
num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
|
| 328 |
+
padding = torch.ones((soft_prompt.shape[0], num_padding_elements), device=soft_prompt.device)
|
| 329 |
+
soft_prompt = torch.cat((soft_prompt, padding), dim=1)
|
| 330 |
+
soft_prompt = soft_prompt.reshape(
|
| 331 |
+
(soft_prompt.shape[0], -1, self.backbone_hidden_size)
|
| 332 |
+
)
|
| 333 |
+
soft_prompt = self.input_ln(soft_prompt)
|
| 334 |
+
# print("[DatasetConditionedAutoregressive] 2 -> soft_prompt.shape =", soft_prompt.shape)
|
| 335 |
+
|
| 336 |
+
backbone_attention_mask = torch.ones(
|
| 337 |
+
soft_prompt.shape[0:2],
|
| 338 |
+
dtype=torch.long,
|
| 339 |
+
device=soft_prompt.device,
|
| 340 |
+
)
|
| 341 |
+
token_embeddings = self.backbone.get_input_embeddings()
|
| 342 |
+
inputs_embeds = token_embeddings(input_ids) # (b, s) -> (b, s, d)
|
| 343 |
+
# print("[2] inputs_embeds.shape =", inputs_embeds.shape)
|
| 344 |
+
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
|
| 345 |
+
# print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
|
| 346 |
+
input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
|
| 347 |
+
# print("[3.b] attention_mask.shape =", attention_mask.shape)
|
| 348 |
+
|
| 349 |
+
output = self.backbone(
|
| 350 |
+
inputs_embeds=inputs_embeds,
|
| 351 |
+
attention_mask=input_attention_mask,
|
| 352 |
+
output_hidden_states=True,
|
| 353 |
+
) # (1, 4 + b + s, d)
|
| 354 |
+
# trim soft prompt
|
| 355 |
+
last_hidden_state = output.hidden_states[-1]
|
| 356 |
+
n_soft_prompt_tokens = soft_prompt.shape[1]
|
| 357 |
+
|
| 358 |
+
output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
|
| 359 |
+
output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]
|
| 360 |
+
|
| 361 |
+
# Take last token position
|
| 362 |
+
if vars(self.config).get("pooling_strategy") == "last_token":
|
| 363 |
+
output_pooled = last_token_pool(output_vectors, output_attention_mask)
|
| 364 |
+
elif vars(self.config).get("pooling_strategy") == "mean":
|
| 365 |
+
output_pooled = mean_pool(output_vectors, output_attention_mask)
|
| 366 |
+
else:
|
| 367 |
+
output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)
|
| 368 |
+
|
| 369 |
+
# average with original vectors
|
| 370 |
+
# TODO: Argparse for pooling strategy.
|
| 371 |
+
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
|
| 372 |
+
|
| 373 |
+
if output_hidden_states:
|
| 374 |
+
return {
|
| 375 |
+
"hidden_states": output_vectors,
|
| 376 |
+
"pooled": output,
|
| 377 |
+
}
|
| 378 |
+
else:
|
| 379 |
+
return output
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
|
| 383 |
+
def __init__(
|
| 384 |
+
self,
|
| 385 |
+
config,
|
| 386 |
+
dataset_backbone: transformers.PreTrainedModel,
|
| 387 |
+
):
|
| 388 |
+
super().__init__(config=config)
|
| 389 |
+
self.backbone = dataset_backbone
|
| 390 |
+
self.hidden_size = self.backbone.config.hidden_size
|
| 391 |
+
self.hidden_size = dataset_backbone.config.hidden_size
|
| 392 |
+
# self.input_ln = torch.nn.LayerNorm(
|
| 393 |
+
# self.hidden_size,
|
| 394 |
+
# eps=self.backbone.config.layer_norm_epsilon
|
| 395 |
+
# )
|
| 396 |
+
self.contextual_init()
|
| 397 |
+
self._shift_rotary_embedding()
|
| 398 |
+
|
| 399 |
+
@property
|
| 400 |
+
def num_corpus_tokens(self) -> int:
|
| 401 |
+
return self.config.transductive_corpus_size * self.transductive_tokens_per_document
|
| 402 |
+
|
| 403 |
+
def _shift_rotary_embedding(self) -> None:
|
| 404 |
+
disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
|
| 405 |
+
if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
|
| 406 |
+
# We only want to apply positional embeddings to the
|
| 407 |
+
# *text* portion of the backbone network.
|
| 408 |
+
self.backbone.config.rotary_start_pos = 0.0
|
| 409 |
+
rotary_disabled = 0
|
| 410 |
+
|
| 411 |
+
rotary_start_pos = self.num_corpus_tokens
|
| 412 |
+
for module in self.backbone.modules():
|
| 413 |
+
if hasattr(module, "rotary_emb_dim"):
|
| 414 |
+
module.rotary_start_pos = rotary_start_pos
|
| 415 |
+
rotary_disabled += 1
|
| 416 |
+
print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")
|
| 417 |
+
|
| 418 |
+
def forward(
|
| 419 |
+
self,
|
| 420 |
+
input_ids: torch.Tensor,
|
| 421 |
+
attention_mask: torch.Tensor,
|
| 422 |
+
dataset_embeddings: torch.Tensor,
|
| 423 |
+
output_hidden_states: bool = False,
|
| 424 |
+
null_dataset_embedding: bool = False,
|
| 425 |
+
) -> torch.Tensor:
|
| 426 |
+
# print(f"[DatasetConditionedBiencoder - 0] input_ids.shape => {input_ids.shape} // dataset_embeddings.shape =", dataset_embeddings.shape)
|
| 427 |
+
soft_prompt = self._prepare_dataset_embeddings(
|
| 428 |
+
input_ids=input_ids,
|
| 429 |
+
dataset_embeddings=dataset_embeddings,
|
| 430 |
+
null_dataset_embedding=null_dataset_embedding,
|
| 431 |
+
)
|
| 432 |
+
# print(f"[DatasetConditionedBiencoder - 1] soft_prompt.shape => {soft_prompt.shape}")
|
| 433 |
+
backbone_attention_mask = torch.ones(
|
| 434 |
+
soft_prompt.shape[0:2],
|
| 435 |
+
dtype=torch.long,
|
| 436 |
+
device=soft_prompt.device,
|
| 437 |
+
)
|
| 438 |
+
inputs_embeds = self.backbone.embeddings(input_ids) # (b, s) -> (b, s, d)
|
| 439 |
+
# print("[2] inputs_embeds.shape =", inputs_embeds.shape)
|
| 440 |
+
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
|
| 441 |
+
# print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
|
| 442 |
+
attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
|
| 443 |
+
# print("[3.b] attention_mask.shape =", attention_mask.shape)
|
| 444 |
+
output = self.backbone(
|
| 445 |
+
inputs_embeds=inputs_embeds,
|
| 446 |
+
attention_mask=attention_mask,
|
| 447 |
+
) # (1, 4 + b + s, d)
|
| 448 |
+
# trim soft prompt
|
| 449 |
+
output_vectors = output.last_hidden_state
|
| 450 |
+
|
| 451 |
+
# use only these tokens
|
| 452 |
+
n_soft_prompt_tokens = soft_prompt.shape[1]
|
| 453 |
+
# print("n_soft_prompt_tokens =", n_soft_prompt_tokens)
|
| 454 |
+
|
| 455 |
+
output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
|
| 456 |
+
output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]
|
| 457 |
+
|
| 458 |
+
# print("pooling output_vectors.shape =", output_vectors.shape, "and output_attention_mask.shape =", output_attention_mask.shape)
|
| 459 |
+
output_pooled = mean_pool(output_vectors, output_attention_mask)
|
| 460 |
+
|
| 461 |
+
# average with original vectors
|
| 462 |
+
# TODO: Argparse for pooling strategy.
|
| 463 |
+
# output_vectors = torch.cat((soft_prompt_pooled, output_pooled), dim=1) # (b, d) + (b, d) -> (b, 2d)
|
| 464 |
+
# print("output_pooled.shape =", output_pooled.shape)
|
| 465 |
+
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
|
| 466 |
+
|
| 467 |
+
# print("returning output.shape =", output.shape)
|
| 468 |
+
|
| 469 |
+
if output_hidden_states:
|
| 470 |
+
return {
|
| 471 |
+
"hidden_states": output_vectors,
|
| 472 |
+
"pooled": output,
|
| 473 |
+
}
|
| 474 |
+
else:
|
| 475 |
+
return output
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
|
| 479 |
+
def __init__(
|
| 480 |
+
self,
|
| 481 |
+
config, #: transformers.PreTrainedConfig,
|
| 482 |
+
embedder: transformers.PreTrainedModel,
|
| 483 |
+
):
|
| 484 |
+
super().__init__(config=config)
|
| 485 |
+
self.embedder = embedder
|
| 486 |
+
self.hidden_size = self.embedder.config.hidden_size
|
| 487 |
+
self.contextual_init()
|
| 488 |
+
|
| 489 |
+
def forward(
|
| 490 |
+
self,
|
| 491 |
+
input_ids: torch.Tensor,
|
| 492 |
+
attention_mask: torch.Tensor,
|
| 493 |
+
dataset_input_ids: torch.Tensor,
|
| 494 |
+
dataset_attention_mask: torch.Tensor,
|
| 495 |
+
output_hidden_states: bool = False,
|
| 496 |
+
) -> torch.Tensor:
|
| 497 |
+
R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)
|
| 498 |
+
|
| 499 |
+
dataset_input_ids = dataset_input_ids[R]
|
| 500 |
+
input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)
|
| 501 |
+
|
| 502 |
+
dataset_attention_mask = torch.ones_like(dataset_attention_mask, device=dataset_attention_mask.device)
|
| 503 |
+
input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
|
| 504 |
+
output_attention_mask = torch.cat(
|
| 505 |
+
(torch.zeros_like(dataset_input_ids), attention_mask), dim=1
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
output = self.embedder(
|
| 509 |
+
input_ids=input_ids,
|
| 510 |
+
attention_mask=input_attention_mask,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
output_vectors = output.last_hidden_state
|
| 514 |
+
output_pooled = mean_pool(output_vectors, output_attention_mask)
|
| 515 |
+
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
|
| 516 |
+
|
| 517 |
+
if output_hidden_states:
|
| 518 |
+
S_d = dataset_attention_mask.shape[1]
|
| 519 |
+
output_vectors = output_vectors[:, S_d:, :]
|
| 520 |
+
return {
|
| 521 |
+
"hidden_states": output_vectors,
|
| 522 |
+
"pooled": output,
|
| 523 |
+
}
|
| 524 |
+
else:
|
| 525 |
+
return output
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class DatasetTransformer(transformers.PreTrainedModel):
|
| 529 |
+
config_class = ContextualModelConfig
|
| 530 |
+
embedder: transformers.PreTrainedModel
|
| 531 |
+
dataset_backbone: transformers.PreTrainedModel
|
| 532 |
+
def __init__(
|
| 533 |
+
self,
|
| 534 |
+
config,
|
| 535 |
+
):
|
| 536 |
+
super().__init__(config=config)
|
| 537 |
+
dataset_backbone, _ = load_embedder_and_tokenizer(
|
| 538 |
+
vars(config).get("dataset_backbone", config.embedder)
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
if config.limit_layers:
|
| 542 |
+
print0(f"Limiting layers to {config.limit_layers}")
|
| 543 |
+
limit_layers(dataset_backbone, config.limit_layers)
|
| 544 |
+
|
| 545 |
+
biencoder_config = copy.deepcopy(config)
|
| 546 |
+
biencoder_config.embedding_output_dim = None
|
| 547 |
+
biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
|
| 548 |
+
self.first_stage_model = BiEncoder(
|
| 549 |
+
config=biencoder_config,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
if vars(config).get("autoregressive_backbone", False):
|
| 553 |
+
self.second_stage_model = DatasetConditionedAutoregressive(
|
| 554 |
+
config=config,
|
| 555 |
+
dataset_backbone=dataset_backbone,
|
| 556 |
+
first_stage_hidden_size=self.first_stage_model.hidden_size,
|
| 557 |
+
)
|
| 558 |
+
else:
|
| 559 |
+
self.second_stage_model = DatasetConditionedBiencoder(
|
| 560 |
+
config=config,
|
| 561 |
+
dataset_backbone=dataset_backbone
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
self.temp = config.logit_scale
|
| 565 |
+
if config.disable_dropout:
|
| 566 |
+
disable_dropout(self)
|
| 567 |
+
|
| 568 |
+
transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False)
|
| 569 |
+
if transductive_tie_token_embeddings:
|
| 570 |
+
self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
|
| 571 |
+
self.first_stage_model.embedder.embeddings.word_embeddings.weight
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
def forward(
|
| 575 |
+
self,
|
| 576 |
+
input_ids: torch.Tensor,
|
| 577 |
+
attention_mask: torch.Tensor,
|
| 578 |
+
dataset_input_ids: Optional[torch.Tensor],
|
| 579 |
+
dataset_attention_mask: Optional[torch.Tensor],
|
| 580 |
+
output_hidden_states: bool = False,
|
| 581 |
+
) -> torch.Tensor:
|
| 582 |
+
"""
|
| 583 |
+
input_ids (long torch.Tensor) – ids of input tokens
|
| 584 |
+
attention_mask (bool torch.Tensor)
|
| 585 |
+
"""
|
| 586 |
+
dataset_embeddings = self.first_stage_model(
|
| 587 |
+
input_ids=dataset_input_ids,
|
| 588 |
+
attention_mask=dataset_attention_mask
|
| 589 |
+
)
|
| 590 |
+
return self.second_stage_model(
|
| 591 |
+
input_ids=input_ids,
|
| 592 |
+
attention_mask=attention_mask,
|
| 593 |
+
dataset_embeddings=dataset_embeddings,
|
| 594 |
+
output_hidden_states=output_hidden_states,
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def get_model_class(name: str):
|
| 600 |
+
if name in 'transductive':
|
| 601 |
+
return DatasetTransformer
|
| 602 |
+
elif name == 'biencoder':
|
| 603 |
+
return BiEncoder
|
| 604 |
+
elif name == "dataset_prefix_biencoder":
|
| 605 |
+
return DatasetPrefixBiencoder
|
| 606 |
+
else:
|
| 607 |
+
raise ValueError(f'unknown model cls {name}')
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6ec79407ada665817aebe929bdabbe83eecd816b75f7f26e3bdd8b4c092efb2a
|
| 3 |
+
size 1124594680
|