Commit 9bb59e4 (verified) · Parent(s): 0e90098 · committed by SituatedEmbedding

Update README.md

Files changed (1): README.md (+141 -3)
README.md CHANGED (@@ -1,3 +1,141 @@). The previous file contained only the YAML front matter (license: apache-2.0); the updated README follows.
---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---

This repository contains the SitEmb-v1.5-Qwen3 model, built on Qwen/Qwen3-Embedding-8B. A passage is encoded together with its surrounding context in the form `chunk<|endoftext|>context`, and an optional residual embedding is pooled from the hidden state at that first `<|endoftext|>` token.
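The snippet below (not part of the original usage code; the helper name is illustrative) shows how a passage string in that format can be assembled before encoding:

```python
# Hypothetical helper: concatenate a chunk with its surrounding context,
# separated by the tokenizer's <|endoftext|> token, as expected by the usage code below.
def build_situated_passage(chunk: str, context: str) -> str:
    return f"{chunk}<|endoftext|>{context}"

passages = [
    build_situated_passage(
        chunk="He finally opened the letter.",
        context="Chapter 3: the morning after the storm, the village slowly comes back to life ...",
    )
]
```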
### Transformer Usage
```python
import torch
from tqdm import tqdm                # progress bar (not in the original imports)
from more_itertools import chunked   # simple batching helper (pip install more-itertools)

from transformers import AutoTokenizer, AutoModel

residual = True          # also compute a residual embedding pooled at the first <|endoftext|> token
residual_factor = 0.5    # weight of the residual score when fusing with the regular score

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)

model = AutoModel.from_pretrained(
    "lossisnotanumber/sit_test",
    torch_dtype=torch.bfloat16,   # use torch.float32 if bf16 is not supported on your hardware
    device_map={"": 0},
)
model.eval()


def get_detailed_instruct(task_description, query):
    # Instruction template used by the Qwen embedding models.
    return f'Instruct: {task_description}\nQuery: {query}'


def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
    if pooling in ['cls', 'first']:
        reps = last_hidden_state[:, 0]
    elif pooling in ['mean', 'avg', 'average']:
        masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling in ['last', 'eos']:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
    elif pooling == 'ext':
        if match_idx is None:
            # default to mean pooling when no separator token is given
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
        else:
            # mean-pool only over the tokens before the separator (match_idx)
            for k in range(input_ids.shape[0]):
                sep_index = input_ids[k].tolist().index(match_idx)
                attention_mask[k][sep_index:] = 0
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    else:
        raise ValueError(f'unknown pooling method: {pooling}')
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def first_eos_token_pooling(last_hidden_states, first_eos_position, normalize):
    # Take the hidden state at the first <|endoftext|> token inside each passage,
    # i.e. the boundary between the chunk and its context.
    batch_size = last_hidden_states.shape[0]
    reps = last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), first_eos_position]
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    if residual:
        task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    else:
        task = "Given a web search query, retrieve relevant passages that answer the query"
    sents = [get_detailed_instruct(task, query) for query in queries]
    # Queries are plain instructed texts, so no residual embedding is computed for them.
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)


def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    pas_embs = []
    pas_embs_residual = []
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                # Locate the first <|endoftext|> token inside each passage
                # (the tokenizer's pad/EOS token separates the chunk from its context).
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True)
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                input_starts = [max_len - sum(att) for att in attention_mask]  # left-padding offset
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            with torch.no_grad():
                outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual


query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=["Your query"],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)
candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=["Your chunk<|endoftext|>Your context"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

query2candidate = query_hidden @ candidate_hidden.T  # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    # Fuse the regular similarity with the residual (first-EOS) similarity.
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor

print(query2candidate.tolist())
```
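
To turn the fused score matrix into a ranked retrieval result, one option (a minimal sketch, not part of the original snippet) is `torch.topk` per query:

```python
# query2candidate: [num_queries, num_candidates] fused similarity scores from the snippet above
k = min(5, query2candidate.shape[1])   # top-k, capped at the number of candidates
top_scores, top_indices = torch.topk(query2candidate, k=k, dim=-1)
for q_idx, (cand_idxs, cand_scores) in enumerate(zip(top_indices.tolist(), top_scores.tolist())):
    print(f"Query {q_idx}: candidates {cand_idxs} with scores {cand_scores}")
```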