dzenzzz committed
Commit 9e9178e · 1 Parent(s): beb4147

revert changes

__pycache__/app.cpython-311.pyc ADDED
Binary file (2.05 kB)
 
__pycache__/neural_searcher.cpython-311.pyc ADDED
Binary file (2.84 kB)
 
app.py CHANGED
@@ -13,11 +13,11 @@ app = FastAPI()
 
 neural_searcher = NeuralSearcher(collection_name=os.getenv('COLLECTION_NAME'))
 
-REQUEST_TIMEOUT_ERROR = 1
+REQUEST_TIMEOUT_ERROR = 30
 
 @app.get("/api/search")
-def search(q: str):
-    data = neural_searcher.search(text=q)
+async def search(q: str):
+    data = await neural_searcher.search(text=q)
     return data
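With the handler now declared async and awaiting the searcher, the endpoint is exercised the same way as before. A minimal smoke-test sketch, assuming the app is served locally (host, port, and query string below are illustrative, not part of the commit):

import requests

# Assumes something like: uvicorn app:app --port 8000
resp = requests.get(
    "http://localhost:8000/api/search",
    params={"q": "osnovni sud"},  # sample query, made up for illustration
    timeout=30,
)
print(resp.json())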
ner.py ADDED
@@ -0,0 +1,97 @@
+from transformers import AutoModelForTokenClassification, AutoTokenizer
+import torch
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+tokenizer = AutoTokenizer.from_pretrained("kalusev/NER4Legal_SRB", use_auth_token=True)
+model = AutoModelForTokenClassification.from_pretrained("kalusev/NER4Legal_SRB", use_auth_token=True).to(device)
+
+id_to_label = {
+    0: 'O',
+    1: 'B-COURT',
+    2: 'B-DATE',
+    3: 'B-DECISION',
+    4: 'B-LAW',
+    5: 'B-MONEY',
+    6: 'B-OFFICIAL GAZZETE',
+    7: 'B-PERSON',
+    8: 'B-REFERENCE',
+    9: 'I-COURT',
+    10: 'I-LAW',
+    11: 'I-MONEY',
+    12: 'I-OFFICIAL GAZZETE',
+    13: 'I-PERSON',
+    14: 'I-REFERENCE'
+}
+
+def perform_ner(text):
+    try:
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+        with torch.no_grad():
+            outputs = model(**inputs)
+        logits = outputs.logits
+        predictions = torch.argmax(logits, dim=2).squeeze().tolist()
+
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            print("Switching to CPU due to memory constraints.")
+            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+            with torch.no_grad():
+                outputs = model.cpu()(**inputs)  # Run model on CPU
+            logits = outputs.logits
+            predictions = torch.argmax(logits, dim=2).squeeze().tolist()
+        else:
+            raise e
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
+    labels = [id_to_label[pred] for pred in predictions]
+
+    results = [
+        (token, label)
+        for token, label in zip(tokens, labels)
+        if token not in tokenizer.all_special_tokens
+    ]
+    return results
+
+text = """1
+osnovni sud u bijelom polju je vrsio veliku nuzdu
+"""
+
+def merge_entities(token_label_pairs):
+    merged_words, merged_labels = [], []
+    current_word, current_label = "", None
+
+    for token, label in token_label_pairs:
+        if token.startswith("##"):
+            current_word += token[2:]
+        else:
+            if current_word:
+                merged_words.append(current_word)
+                merged_labels.append(current_label)
+
+            current_word, current_label = token, label
+
+    if current_word:
+        merged_words.append(current_word)
+        merged_labels.append(current_label)
+
+    final_words, final_labels = [], []
+
+    for i, (word, label) in enumerate(zip(merged_words, merged_labels)):
+        if final_labels and (
+            label == final_labels[-1] or
+            (label.startswith("I-") and final_labels[-1].endswith(label[2:])) or
+            (label.startswith("B-") and final_labels[-1].endswith(label[2:]))
+        ):
+
+            final_words[-1] += " " + word
+        else:
+            final_words.append(word)
+            final_labels.append(label)
+
+    return final_words, final_labels
+
+results = perform_ner(text)
+
+words, labels = merge_entities(results)
+
+for i, b in zip(words, labels):
+    print(i + " ### " + b)
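A note on the helper above: merge_entities first re-attaches WordPiece continuation tokens ("##…") to their head token, then concatenates consecutive tokens belonging to the same entity type. A small sketch on hand-written token/label pairs (the pairs are made up; the function body is copied from ner.py above, with the unused enumerate index dropped, so the snippet runs standalone):

def merge_entities(token_label_pairs):
    merged_words, merged_labels = [], []
    current_word, current_label = "", None

    # Pass 1: glue "##" subword pieces back onto their head token.
    for token, label in token_label_pairs:
        if token.startswith("##"):
            current_word += token[2:]
        else:
            if current_word:
                merged_words.append(current_word)
                merged_labels.append(current_label)
            current_word, current_label = token, label

    if current_word:
        merged_words.append(current_word)
        merged_labels.append(current_label)

    # Pass 2: merge consecutive tokens of the same entity type into one span.
    final_words, final_labels = [], []
    for word, label in zip(merged_words, merged_labels):
        if final_labels and (
            label == final_labels[-1] or
            (label.startswith("I-") and final_labels[-1].endswith(label[2:])) or
            (label.startswith("B-") and final_labels[-1].endswith(label[2:]))
        ):
            final_words[-1] += " " + word
        else:
            final_words.append(word)
            final_labels.append(label)

    return final_words, final_labels

# "##ni" is re-attached to "osnov", then "sud" (I-COURT) is merged into the
# preceding B-COURT span.
pairs = [("osnov", "B-COURT"), ("##ni", "B-COURT"), ("sud", "I-COURT"), ("u", "O")]
print(merge_entities(pairs))  # (['osnovni sud', 'u'], ['B-COURT', 'O'])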
neural_searcher.py CHANGED
@@ -5,45 +5,47 @@ from sentence_transformers import SentenceTransformer
 import os
 
 class NeuralSearcher:
 
     def __init__(self, collection_name):
         self.collection_name = collection_name
         self.dense_model = SentenceTransformer(os.getenv('DENSE_MODEL'), device="cpu")
         self.sparse_model = SparseTextEmbedding(os.getenv('SPARSE_MODEL'))
         self.late_interaction_model = LateInteractionTextEmbedding(os.getenv('LATE_INTERACTION_MODEL'))
-        self.qdrant_client = QdrantClient(os.getenv('QDRANT_URL'), api_key=os.getenv('QDRANT_API_KEY'), https=True)
+        self.qdrant_client = QdrantClient(os.getenv('QDRANT_URL'), api_key=os.getenv('QDRANT_API_KEY'))
 
-    def search(self, text: str):
+    async def search(self, text: str):
 
         dense_query = self.dense_model.encode(text).tolist()
         sparse_query = next(self.sparse_model.query_embed(text))
-        late_query = next(self.late_interaction_model.query_embed(text))
+        # late_query = next(self.late_interaction_model.query_embed(text))
 
         prefetch = [
             models.Prefetch(
                 query=dense_query,
                 using=os.getenv('DENSE_MODEL'),
-                limit=100
+                limit=200
             ),
             models.Prefetch(
                 query=models.SparseVector(**sparse_query.as_object()),
                 using=os.getenv('SPARSE_MODEL'),
-                limit=100
+                limit=200
             )
         ]
 
-        search_result = self.qdrant_client.query_points_groups(
+        search_result = self.qdrant_client.query_points(
             collection_name=self.collection_name,
-            group_by="dbid",
             prefetch=prefetch,
-            group_size=3,
-            query=late_query,
-            using=os.getenv('LATE_INTERACTION_MODEL'),
+            query=models.FusionQuery(
+                fusion=models.Fusion.RRF,
+            ),
+            # using=os.getenv('LATE_INTERACTION_MODEL'),
             with_payload=True,
-            score_threshold=0.8,
-            limit=10
-        ).groups
-
-        for group in search_result:
-            print(group)
-        return search_result
+            limit=10
+        ).points
+
+        data = []
+
+        for hit in search_result:
+            data.append(hit.payload)
+
+        return data
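This hunk swaps the grouped late-interaction rerank (query_points_groups with a ColBERT-style query) for server-side reciprocal rank fusion of the dense and sparse prefetch lists via models.FusionQuery(fusion=models.Fusion.RRF). A rough standalone sketch of what RRF computes (this is the standard formula with the conventional k = 60; Qdrant's internal constant may differ, and the document ids are made up):

def rrf_order(rankings, k=60):
    # Each doc accumulates sum(1 / (k + rank)) over every ranked list it appears in,
    # so items ranked highly by either retriever float to the top of the fused list.
    scores = {}
    for ranked_ids in rankings:
        for rank, doc_id in enumerate(ranked_ids, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

dense_hits = ["a", "b", "c"]   # ids as ranked by the dense prefetch
sparse_hits = ["b", "d", "a"]  # ids as ranked by the sparse prefetch
print(rrf_order([dense_hits, sparse_hits]))  # ['b', 'a', 'd', 'c']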