Spaces:
Sleeping
Sleeping
File size: 4,053 Bytes
04595e7 9a45764 04595e7 9a45764 04595e7 9a45764 04595e7 9a45764 04595e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import warnings
from datetime import datetime
from typing import Any, Optional
from qdrant_client import QdrantClient
from qdrant_client.http.models import QueryResponse
from qdrant_client.models import FieldCondition, Filter, MatchValue, Range
from article_embedding.embed import StellaEmbedder
from article_embedding.utils import env_str
# Ignore every FutureWarning-category warning issued after this point.
# NOTE(review): presumably aimed at noise from the embedding stack — confirm.
warnings.simplefilter(action="ignore", category=FutureWarning)
def as_timestamp(date: datetime | str) -> float:
if isinstance(date, datetime):
return date.timestamp()
return datetime.strptime(date, "%Y-%m-%d").timestamp()
def make_date_condition(
    *,
    field: str = "published",
    date_from: datetime | str | None = None,
    date_to: datetime | str | None = None,
) -> FieldCondition | None:
    """Build a range condition restricting ``field`` to a half-open date window.

    Args:
        field: Payload key the range is applied to.
        date_from: Inclusive lower bound (``gte``); skipped when falsy.
        date_to: Exclusive upper bound (``lt``); skipped when falsy.

    Returns:
        A ``FieldCondition`` carrying the requested bounds as timestamps, or
        ``None`` when neither bound was supplied.
    """
    # Lower bound is converted first, then the upper bound; falsy values
    # (``None`` or an empty string) contribute nothing.
    bounds = {
        operator: as_timestamp(value)
        for operator, value in (("gte", date_from), ("lt", date_to))
        if value
    }
    if not bounds:
        return None
    return FieldCondition(key=field, range=Range(**bounds))
def make_topic_condition(topic_id: str) -> FieldCondition:
    """Build a condition matching the ``topics[]`` payload key against ``topic_id``.

    Args:
        topic_id: Value handed to ``MatchValue`` for the comparison.

    Returns:
        The ``FieldCondition`` wrapping that match.
    """
    topic_match = MatchValue(value=topic_id)
    return FieldCondition(key="topics[]", match=topic_match)
class Query:
    """Embed free-text queries with ``StellaEmbedder`` and search a Qdrant index.

    The embedding model is created once per process and shared by every
    instance; ``singleton()`` additionally caches one default-configured
    ``Query``.
    """

    # Lazily filled caches used by ``singleton()`` and
    # ``embedding_model_singleton()`` respectively.
    _instance: Optional["Query"] = None
    _embedding_model_instance: Optional[StellaEmbedder] = None

    def __init__(self, index: str = "wsws", client: QdrantClient | None = None) -> None:
        """Bind the shared embedder, a Qdrant client and the index name.

        Args:
            index: Name passed as the first argument to ``query_points``.
            client: Client to reuse. When ``None``, a new ``QdrantClient`` is
                created from the value ``env_str("QDRANT_URL")`` returns.
        """
        self.embedding_model = Query.embedding_model_singleton()
        self.qdrant = QdrantClient(env_str("QDRANT_URL")) if client is None else client
        self.index = index

    @staticmethod
    def embedding_model_singleton() -> StellaEmbedder:
        """Return the shared ``StellaEmbedder``, creating it on first use."""
        if Query._embedding_model_instance is None:
            Query._embedding_model_instance = StellaEmbedder()
        return Query._embedding_model_instance

    @staticmethod
    def singleton() -> "Query":
        """Return the shared default-constructed ``Query``, creating it on first use."""
        if Query._instance is None:
            Query._instance = Query()
        return Query._instance

    def embed(self, query: str) -> Any:
        """Return the embedding of a single query string.

        The model is called with a one-element batch and the first (only)
        result is returned.
        """
        return self.embedding_model.embed([query])[0]

    def query(
        self,
        query: str,
        query_filter: Filter | None = None,
        limit: int = 10,
    ) -> QueryResponse:
        """Search the index for points closest to the embedding of ``query``.

        Args:
            query: Free-text query to embed.
            query_filter: Optional filter forwarded unchanged to Qdrant.
            limit: Maximum number of points requested.

        Returns:
            The response of ``QdrantClient.query_points``.
        """
        # Go through embed() so the text-to-vector step lives in one place.
        return self.qdrant.query_points(
            self.index,
            query=self.embed(query),
            query_filter=query_filter,
            limit=limit,
        )
def _sheet_hyperlink(url: str, label: str) -> str:
    """Return a spreadsheet ``=hyperlink(...)`` formula showing ``label`` for ``url``.

    Double quotes inside either argument are doubled, which is how a literal
    quote is written inside a spreadsheet string. Without this, a title that
    contains a quote ends the string early and the formula is rejected.
    """

    def escape(text: object) -> str:
        return str(text).replace('"', '""')

    return f'=hyperlink("{escape(url)}", "{escape(label)}")'


def main() -> None:
    """Run the sample date-bounded queries and append the hits to a Google Sheet.

    For each ``(date_from, date_to, sentence)`` entry the sentence is searched
    with a date filter; every hit is printed and collected as a spreadsheet
    row. All rows are then appended to the first worksheet of the
    "COVID-19 Compilation" spreadsheet.
    """
    # Imported here because only the script entry point needs them.
    import gspread
    from dotenv import load_dotenv
    from gspread.utils import ValueInputOption

    data = [
        ("2021-01-01", "2021-05-01", "The COVID winter wave, the emergence of the Delta variant and the January 6th coup"),
        (
            "2021-05-01",
            "2021-09-01",
            "The COVID vaccine rollout, Biden declaring independence from COVID while the Delta wave continues",
        ),
        (
            "2021-09-01",
            "2022-01-01",
            "The emergence of the COVID Omicron variant and the embrace of herd immunity by the ruling class",
        ),
    ]
    # Load .env before Query() reads its settings through env_str.
    load_dotenv()
    query = Query()
    rows: list[list[str]] = []
    for date_from, date_to, sentence in data:
        result = query.query(
            sentence,
            query_filter=Filter(should=make_date_condition(date_from=date_from, date_to=date_to)),
        )
        # One heading row per query, followed by its hits and a blank spacer row.
        rows.append([sentence])
        for point in result.points:
            doc = point.payload
            assert doc is not None
            score = f"{point.score * 100:.1f}%"
            url = f'https://www.wsws.org{doc["path"]}'
            print(f'{score} {url} - {doc["title"]}')
            rows.append(
                [
                    score,
                    # "published" is stored as a POSIX timestamp.
                    datetime.fromtimestamp(doc["published"]).strftime("%Y/%m/%d"),
                    ", ".join(doc["authors"]),
                    _sheet_hyperlink(url, doc["title"]),
                ]
            )
        rows.append([])
    gc = gspread.auth.oauth(credentials_filename=env_str("GOOGLE_CREDENTIALS"))
    sh = gc.open("COVID-19 Compilation")
    ws = sh.get_worksheet(0)
    # user_entered makes the sheet parse the =hyperlink(...) cells as formulas.
    ws.append_rows(rows, value_input_option=ValueInputOption.user_entered)


if __name__ == "__main__":
    main()
|