Spaces:
Runtime error
Runtime error
| import datetime as dt | |
| import os | |
| import sqlite3 | |
| from types import SimpleNamespace | |
| import pytest | |
| from fastapi.testclient import TestClient | |
def is_roughly_now(datetime_str):
    """Check if a datetime string is roughly from now.

    Args:
        datetime_str: A timestamp in a format accepted by
            ``datetime.fromisoformat``. A timestamp without timezone
            information is interpreted as UTC; a timestamp that carries an
            offset is compared using that offset.

    Returns:
        True if the timestamp lies within 3 seconds of the current time, in
        either direction; False otherwise.
    """
    now = dt.datetime.now(dt.timezone.utc)
    parsed = dt.datetime.fromisoformat(datetime_str)
    if parsed.tzinfo is None:
        # naive and aware datetimes cannot be subtracted, so treat naive
        # values as UTC; aware values keep their own offset
        parsed = parsed.replace(tzinfo=dt.timezone.utc)
    # abs() so that a timestamp far in the future is not accepted as "now"
    return abs((now - parsed).total_seconds()) < 3
class TestWebservice:
    """Tests for the gistillery webservice endpoints.

    The tests talk to the FastAPI app through a ``TestClient`` and emulate the
    background worker with ``process_jobs`` and a registry of dummy tools.
    """

    @pytest.fixture(autouse=True)
    def db_file(self, tmp_path):
        """Point ``DB_FILE_NAME`` at a SQLite file inside ``tmp_path``.

        Marked autouse because no test requests this fixture by name.
        NOTE(review): presumably gistillery reads ``DB_FILE_NAME`` to locate
        its database -- confirm against ``gistillery.db``.
        """
        filename = tmp_path / "test-db.sqlite"
        os.environ["DB_FILE_NAME"] = str(filename)

    @pytest.fixture
    def cursor(self):
        """Yield a cursor obtained from ``gistillery.db.get_db_cursor``."""
        from gistillery.db import get_db_cursor

        with get_db_cursor() as db_cursor:
            yield db_cursor

    @pytest.fixture
    def client(self):
        """Return a ``TestClient`` for the app after calling ``/clear``."""
        from gistillery.webservice import app

        client = TestClient(app)
        client.get("/clear")
        return client

    @pytest.fixture
    def registry(self):
        """Return a ``ToolRegistry`` populated with dummy models."""
        # use dummy models
        from gistillery.tools import Summarizer, Tagger
        from gistillery.preprocessing import RawTextProcessor
        from gistillery.registry import ToolRegistry

        class DummySummarizer(Summarizer):
            """Returns the first 10 characters of the input"""

            def get_name(self):
                return "dummy summarizer"

            def __call__(self, x):
                return x[:10]

        class DummyTagger(Tagger):
            """Returns the first 3 words of the input"""

            def get_name(self):
                return "dummy tagger"

            def __call__(self, x):
                return ["#" + word for word in x.split(maxsplit=4)[:3]]

        registry = ToolRegistry()
        registry.register_processor(RawTextProcessor())

        # arguments don't matter for dummy summarizer and tagger
        summarizer = DummySummarizer()
        registry.register_summarizer(summarizer)

        tagger = DummyTagger()
        registry.register_tagger(tagger)
        return registry

    def process_jobs(self, registry):
        """Process all pending jobs with the given registry.

        This is a plain helper (not a fixture) that tests call directly.
        """
        # emulate work of the background worker
        from gistillery.worker import check_pending_jobs, process_job

        jobs = check_pending_jobs()
        for job in jobs:
            process_job(job, registry)

    def test_status(self, client):
        resp = client.get("/status")
        assert resp.status_code == 200
        assert resp.json() == "OK"

    def test_recent_empty(self, client):
        resp = client.get("/recent")
        assert resp.json() == []

    def test_recent_tag_empty(self, client):
        resp = client.get("/recent/general")
        assert resp.json() == []

    def test_submitted_job_status_pending(self, client, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "pending",
        }
        assert is_roughly_now(last_updated)

    def test_submitted_job_status_not_found(self, client, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        other_job_id = "def5678"
        resp = client.get(f"/check_job_status/{other_job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": other_job_id,
            "status": "not found",
        }
        assert last_updated is None

    def test_submitted_job_failed(self, client, registry, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        # patch gistillery.worker._process_job to raise an exception
        def raise_(ex):
            raise ex

        # make the job processing fail
        monkeypatch.setattr(
            "gistillery.worker._process_job",
            lambda job, registry: raise_(RuntimeError("something went wrong")),
        )
        self.process_jobs(registry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "failed",
        }

    def test_submitted_job_status_done(self, client, registry, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        self.process_jobs(registry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "done",
        }
        assert is_roughly_now(last_updated)

    def test_status_pending_jobs(self, client, registry, monkeypatch):
        resp = client.get("/check_job_status/")
        output = resp.json()
        assert output == "No pending jobs found"

        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex="abc0"))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        resp = client.get("/check_job_status/")
        output = resp.json()
        expected = "Found 1 pending job(s): abc0"
        assert output == expected

        for i in range(1, 10):
            # bind i as a default so the patched uuid4 does not depend on the
            # loop variable's later values (late-binding closure)
            monkeypatch.setattr(
                "uuid.uuid4", lambda i=i: SimpleNamespace(hex=f"abc{i}")
            )
            client.post(
                "/submit", json={"author": "ben", "content": "this is a test"}
            )

        resp = client.get("/check_job_status/")
        output = resp.json()
        expected = "Found 10 pending job(s): abc0, abc1, abc2, ..."
        assert output == expected

    def test_recent_with_entries(self, client, registry):
        # submit 2 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "this would be something else"},
        )
        self.process_jobs(registry)
        resp = client.get("/recent").json()

        # results are sorted by recency but since dummy models are so fast, the
        # date in the db could be the same, so we sort by author
        resp = sorted(resp, key=lambda x: x["author"])
        assert len(resp) == 2

        resp0 = resp[0]
        assert resp0["author"] == "maxi"
        assert resp0["summary"] == "this is a "
        assert resp0["tags"] == sorted(["#this", "#is", "#a"])

        resp1 = resp[1]
        assert resp1["author"] == "mini"
        assert resp1["summary"] == "this would"
        assert resp1["tags"] == sorted(["#this", "#would", "#be"])

    def test_recent_source_snippet_shortened(self, client, registry):
        # submit 2 entries
        client.post("/submit", json={"author": "alice", "content": "this is short"})
        client.post(
            "/submit",
            json={"author": "bob", "content": "this is long " * 100},
        )
        self.process_jobs(registry)
        resp = client.get("/recent").json()

        resp = sorted(resp, key=lambda x: x["author"])
        assert resp[0]["source_snippet"] == "this is short"
        expected_shortened = (
            "this is long this is long this is long this is long this is long th"
            "...ng this is long this is long this is long this is long this is long "
        )
        assert resp[1]["source_snippet"] == expected_shortened

    def test_recent_tag_with_entries(self, client, registry):
        # submit 2 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "this would be something else"},
        )
        self.process_jobs(registry)

        # the "this" tag is in both entries
        resp = client.get("/recent/this").json()
        assert len(resp) == 2

        # the "would" tag is in only one entry
        resp = client.get("/recent/would").json()
        assert len(resp) == 1

        resp0 = resp[0]
        assert resp0["author"] == "mini"
        assert resp0["summary"] == "this would"
        assert resp0["tags"] == sorted(["#this", "#would", "#be"])

    def test_recent_multiple_entries(self, client, registry):
        # submit 3 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "aardvark ant antelope"}
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "bat bear bee"},
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "camel canary cat"},
        )
        self.process_jobs(registry)

        # the "ant" tag is in only one entry
        resp = client.get("/recent/ant").json()
        assert len(resp) == 1

        # "ant" and "bee" are in two entries
        resp = client.get("/recent/ant,bee").json()
        assert len(resp) == 2

        # "ant" and "bee" and "cat" are in three entries
        resp = client.get("/recent/cat,ant,bee").json()
        assert len(resp) == 3

    def test_tag_count(self, client, registry):
        # submit 3 entries
        client.post(
            "/submit", json={"author": "ben", "content": "aardvark ant antelope"}
        )
        client.post(
            "/submit",
            json={"author": "ben", "content": "aardvark ant bat"},
        )
        client.post(
            "/submit",
            json={"author": "ben", "content": "aardvark camel canary"},
        )
        self.process_jobs(registry)

        resp = client.get("/tag_counts").json()
        expected = {
            "#aardvark": 3,
            "#ant": 2,
            "#antelope": 1,
            "#bat": 1,
            "#camel": 1,
            "#canary": 1,
        }
        assert resp == expected

    def test_clear(self, client, cursor, registry):
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        self.process_jobs(registry)
        assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 1

        client.get("/clear")
        assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 0

    def test_inputs_stored(self, client, cursor, registry):
        client.post("/submit", json={"author": "ben", "content": " this is a test\n"})
        self.process_jobs(registry)
        rows = cursor.execute("SELECT * FROM inputs").fetchall()
        assert len(rows) == 1
        assert rows[0].input == "this is a test"

    def test_submit_url(self, client, cursor, registry, monkeypatch):
        class MockClient:
            """Mock httpx Client, return www.example.com content"""

            def get(self, url):
                # The continuation lines of this literal start at column 0 on
                # purpose: any leading whitespace would become part of the
                # mocked page text.
                return SimpleNamespace(
                    text=''' <!doctype html>\n<html>\n<head>\n <title>Example
Domain</title>\n\n <meta charset="utf-8" />\n <meta
http-equiv="Content-type" content="text/html; charset=utf-8"
/>\n <meta name="viewport" content="width=device-width,
initial-scale=1" />\n <style type="text/css">\n body {\n
background-color: #f0f0f2;\n margin: 0;\n padding: 0;\n
font-family: -apple-system, system-ui, BlinkMacSystemFont,
"Segoe UI", "Open Sans", "Helvetica Neue", Helvetica, Arial,
sans-serif;\n \n }\n div {\n width: 600px;\n margin: 5em
auto;\n padding: 2em;\n background-color: #fdfdff;\n
border-radius: 0.5em;\n box-shadow: 2px 3px 7px 2px
rgba(0,0,0,0.02);\n }\n a:link, a:visited {\n color:
#38488f;\n text-decoration: none;\n }\n @media (max-width:
700px) {\n div {\n margin: 0 auto;\n width: auto;\n }\n }\n
</style> \n</head>\n\n<body>\n<div>\n <h1>Example
Domain</h1>\n <p>This domain is for use in illustrative
examples in documents. You may use this\n domain in
literature without prior coordination or asking for
permission.</p>\n <p><a
href="https://www.iana.org/domains/example">More
information...</a></p>\n</div>\n</body>\n</html>\n'''
                )

        monkeypatch.setattr("gistillery.preprocessing.Client", MockClient)

        from gistillery.preprocessing import DefaultUrlProcessor

        # register url processor, put it before the default processor
        registry.register_processor(DefaultUrlProcessor(), last=False)
        client.post(
            "/submit",
            json={
                "author": "ben",
                "content": "https://en.wikipedia.org/wiki/non-existing-page",
            },
        )
        self.process_jobs(registry)

        rows = cursor.execute("SELECT * FROM inputs").fetchall()
        assert len(rows) == 1

        expected = "\n".join(
            [
                'https://en.wikipedia.org/wiki/non-existing-page',
                '',
                'This domain is for use in illustrative',
                'examples in documents. You may use this',
                'domain in',
                'literature without prior coordination or asking for',
                'permission.',
                'More',
                'information...',
            ]
        )
        assert rows[0].input == expected

    def test_backup(self, client, tmp_path):
        # submit an entry, create a backup, check that the backup contains the entry
        from gistillery.db import namedtuple_factory

        client.post("/submit", json={"author": "Pie Test", "content": "this is a pie"})
        resp = client.get("/backup")
        assert resp.status_code == 200

        with open(tmp_path / "backup.db", "wb") as f:
            f.write(resp.content)

        conn = sqlite3.connect(tmp_path / "backup.db")
        try:
            conn.row_factory = namedtuple_factory
            backup_cursor = conn.cursor()
            res = backup_cursor.execute("select * from entries").fetchall()
        finally:
            # always release the handle on the backup file, even if the query fails
            conn.close()

        assert len(res) == 1
        assert is_roughly_now(res[0].created_at)
        assert res[0].author == "Pie Test"