File size: 4,053 Bytes
04595e7
 
 
 
 
 
 
 
 
9a45764
04595e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a45764
04595e7
9a45764
04595e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a45764
04595e7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import warnings
from datetime import datetime
from typing import Any, Optional

from qdrant_client import QdrantClient
from qdrant_client.http.models import QueryResponse
from qdrant_client.models import FieldCondition, Filter, MatchValue, Range

from article_embedding.embed import StellaEmbedder
from article_embedding.utils import env_str

# Suppress every FutureWarning process-wide as soon as this module is imported.
# NOTE(review): presumably aimed at noise from third-party dependencies — confirm
# which library emits them before narrowing or removing this filter.
warnings.simplefilter(action="ignore", category=FutureWarning)


def as_timestamp(date: datetime | str) -> float:
    if isinstance(date, datetime):
        return date.timestamp()
    return datetime.strptime(date, "%Y-%m-%d").timestamp()


def make_date_condition(
    *, field: str = "published", date_from: datetime | str | None = None, date_to: datetime | str | None = None
) -> FieldCondition | None:
    """Build a range condition on a timestamp field, or ``None`` if unbounded.

    Args:
        field: Key of the field the range applies to.
        date_from: Inclusive lower bound (``gte``); skipped when falsy.
        date_to: Exclusive upper bound (``lt``); skipped when falsy.

    Returns:
        A ``FieldCondition`` wrapping a ``Range`` with whichever bounds were
        supplied, or ``None`` when neither bound was given.
    """
    # Only bounds that were actually provided end up in the Range arguments.
    bounds = {
        operator: as_timestamp(value)
        for operator, value in (("gte", date_from), ("lt", date_to))
        if value
    }
    if not bounds:
        return None
    return FieldCondition(key=field, range=Range(**bounds))


def make_topic_condition(topic_id: str) -> FieldCondition:
    """Build a condition that matches *topic_id* against the ``topics[]`` key."""
    topic_match = MatchValue(value=topic_id)
    return FieldCondition(key="topics[]", match=topic_match)


class Query:
    """Text search against a Qdrant collection using a ``StellaEmbedder``.

    The embedder is shared by every ``Query`` through a class-level cache;
    ``singleton()`` additionally caches one default-configured ``Query``.
    """

    # Lazily populated caches used by the two singleton accessors below.
    _instance: Optional["Query"] = None
    _embedding_model_instance: Optional[StellaEmbedder] = None

    def __init__(self, index: str = "wsws", client: QdrantClient | None = None) -> None:
        """Bind the shared embedder, a Qdrant client and a collection name.

        Args:
            index: Name of the collection passed to ``query_points``.
            client: Client to use; when ``None`` a new ``QdrantClient`` is
                created from the ``QDRANT_URL`` setting read via ``env_str``.
        """
        self.embedding_model = Query.embedding_model_singleton()
        if client is None:
            client = QdrantClient(env_str("QDRANT_URL"))
        self.qdrant = client
        self.index = index

    @staticmethod
    def embedding_model_singleton() -> StellaEmbedder:
        """Return the shared ``StellaEmbedder``, constructing it on first use."""
        model = Query._embedding_model_instance
        if model is None:
            model = StellaEmbedder()
            Query._embedding_model_instance = model
        return model

    @staticmethod
    def singleton() -> "Query":
        """Return the shared default ``Query``, constructing it on first use."""
        shared = Query._instance
        if shared is None:
            shared = Query()
            Query._instance = shared
        return shared

    def embed(self, query: str) -> Any:
        """Embed a single query string and return its embedding."""
        # The embedder is called with a one-element batch; unwrap the only result.
        batch = self.embedding_model.embed([query])
        return batch[0]

    def query(
        self,
        query: str,
        query_filter: Filter | None = None,
        limit: int = 10,
    ) -> QueryResponse:
        """Embed *query* and search this instance's collection with it.

        Args:
            query: Text to embed and search for.
            query_filter: Optional filter forwarded to ``query_points``.
            limit: Maximum number of points requested.

        Returns:
            The response returned by ``QdrantClient.query_points``.
        """
        batch = self.embedding_model.embed([query])
        return self.qdrant.query_points(
            self.index,
            query=batch[0],
            query_filter=query_filter,
            limit=limit,
        )


def hyperlink_formula(url: str, label: str) -> str:
    """Build a spreadsheet ``=hyperlink(...)`` formula linking *label* to *url*.

    Double quotes inside either argument are doubled, which is how a quote is
    escaped inside a spreadsheet string literal. Without this, a title that
    contains ``"`` would end the literal early and yield a broken formula.

    Args:
        url: Link target.
        label: Text displayed in the cell.

    Returns:
        The formula text, e.g. ``=hyperlink("https://…", "Title")``.
    """
    safe_url = url.replace('"', '""')
    safe_label = label.replace('"', '""')
    return f'=hyperlink("{safe_url}", "{safe_label}")'


if __name__ == "__main__":
    # Script mode: run one search per period below, print the hits and append
    # them to the first worksheet of the "COVID-19 Compilation" spreadsheet.
    import gspread
    from dotenv import load_dotenv
    from gspread.utils import ValueInputOption

    # Each entry is (date_from, date_to, query sentence).
    data = [
        ("2021-01-01", "2021-05-01", "The COVID winter wave, the emergence of the Delta variant and the January 6th coup"),
        (
            "2021-05-01",
            "2021-09-01",
            "The COVID vaccine rollout, Biden declaring independence from COVID while the Delta wave continues",
        ),
        (
            "2021-09-01",
            "2022-01-01",
            "The emergence of the COVID Omicron variant and the embrace of herd immunity by the ruling class",
        ),
    ]

    # Load .env before constructing Query, which reads QDRANT_URL via env_str.
    load_dotenv()
    query = Query()
    rows: list[list[str]] = []
    for date_from, date_to, sentence in data:
        # Limit hits to points whose "published" value lies in [date_from, date_to).
        result = query.query(
            sentence,
            query_filter=Filter(should=make_date_condition(date_from=date_from, date_to=date_to)),
        )
        # One header row with the sentence, then one row per hit, then a blank row.
        rows.append([sentence])
        for point in result.points:
            doc = point.payload
            assert doc is not None
            score = f"{point.score * 100:.1f}%"
            url = f'https://www.wsws.org{doc["path"]}'
            print(f'{score} {url} - {doc["title"]}')
            rows.append(
                [
                    score,
                    datetime.fromtimestamp(doc["published"]).strftime("%Y/%m/%d"),
                    ", ".join(doc["authors"]),
                    # Titles may contain double quotes; the helper escapes them.
                    hyperlink_formula(url, str(doc["title"])),
                ]
            )
        rows.append([])

    gc = gspread.auth.oauth(credentials_filename=env_str("GOOGLE_CREDENTIALS"))
    sh = gc.open("COVID-19 Compilation")
    ws = sh.get_worksheet(0)
    # user_entered makes the sheet parse the "=hyperlink(...)" strings as formulas.
    ws.append_rows(rows, value_input_option=ValueInputOption.user_entered)