File size: 2,970 Bytes
dd2bdcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import re
import sys
import traceback
from typing import NoReturn

import pytest

from .._util import (
    bytesify,
    LocalProtocolError,
    ProtocolError,
    RemoteProtocolError,
    Sentinel,
    validate,
)


def test_ProtocolError() -> None:
    """ProtocolError is an abstract base: constructing it directly raises TypeError."""
    pytest.raises(TypeError, ProtocolError, "abstract base class")


def test_LocalProtocolError() -> None:
    """Check LocalProtocolError's message and status hint, and its conversion
    to RemoteProtocolError via _reraise_as_remote_protocol_error()."""
    # Without an explicit hint, error_status_hint is 400.
    try:
        raise LocalProtocolError("foo")
    except LocalProtocolError as e:
        assert str(e) == "foo"
        assert e.error_status_hint == 400

    # An explicit hint is stored as given.
    try:
        raise LocalProtocolError("foo", error_status_hint=418)
    except LocalProtocolError as e:
        assert str(e) == "foo"
        assert e.error_status_hint == 418

    def thunk() -> NoReturn:
        raise LocalProtocolError("a", error_status_hint=420)

    # pytest.raises (rather than an outer try/except) makes the test fail if
    # _reraise_as_remote_protocol_error() ever returns instead of raising;
    # with a bare try/except the assertions below would silently be skipped.
    with pytest.raises(RemoteProtocolError) as reraised:
        try:
            thunk()
        except LocalProtocolError as exc1:
            orig_traceback = "".join(traceback.format_tb(sys.exc_info()[2]))
            exc1._reraise_as_remote_protocol_error()

    exc2 = reraised.value
    # The converted exception is exactly RemoteProtocolError (not a subclass)
    # and keeps the original args and status hint.
    assert type(exc2) is RemoteProtocolError
    assert exc2.args == ("a",)
    assert exc2.error_status_hint == 420
    # The re-raised exception's traceback must still end with the frames
    # captured at the original raise.
    new_traceback = "".join(traceback.format_tb(reraised.tb))
    assert new_traceback.endswith(orig_traceback)


def test_validate() -> None:
    """validate() returns the named groups of a full match and raises
    LocalProtocolError when the pattern does not match the entire input."""
    pattern = re.compile(rb"(?P<group1>[0-9]+)\.(?P<group2>[0-9]+)")

    # No match at all.
    with pytest.raises(LocalProtocolError):
        validate(pattern, b"0.")

    matched = validate(pattern, b"0.1")
    assert matched == {"group1": b"0", "group2": b"1"}

    # A successful *partial* match is still an error: the pattern has to cover
    # the whole input, including a trailing newline.
    for partial in (b"0.1xx", b"0.1\n"):
        with pytest.raises(LocalProtocolError):
            validate(pattern, partial)


def test_validate_formatting() -> None:
    """The error raised by validate() carries the caller's message; the
    message is only str.format()-ed when extra arguments are supplied."""
    pattern = re.compile(rb"foo")

    # (extra positional args to validate, text expected in the error message)
    cases = [
        (("oops",), "oops"),
        (("oops {}",), "oops {}"),
        (("oops {} xx", 10), "oops 10 xx"),
    ]
    for extra_args, expected in cases:
        with pytest.raises(LocalProtocolError) as excinfo:
            validate(pattern, b"", *extra_args)
        assert expected in str(excinfo.value)


def test_make_sentinel() -> None:
    """A Sentinel class is its own type, reprs as its name, is hashable and
    self-equal, and is distinct from every other sentinel."""

    class S(Sentinel, metaclass=Sentinel):
        pass

    # The sentinel is an instance of itself.
    assert type(S) is S
    assert type(S).__name__ == "S"
    assert repr(S) == "S"
    # Equal to itself and usable as a set member.
    assert S == S
    assert S in {S}

    class S2(Sentinel, metaclass=Sentinel):
        pass

    assert repr(S2) == "S2"
    # Two separately defined sentinels never compare equal or share a type.
    assert S != S2
    assert S not in {S2}
    assert type(S) is not type(S2)


def test_bytesify() -> None:
    """bytesify() turns bytes, bytearray and encodable str into bytes and
    rejects text it cannot encode as well as non-string types."""
    for value in (b"123", bytearray(b"123"), "123"):
        assert bytesify(value) == b"123"

    # "\u1234" is outside the encoding bytesify() uses for str input.
    with pytest.raises(UnicodeEncodeError):
        bytesify("\u1234")

    # Integers are not accepted.
    with pytest.raises(TypeError):
        bytesify(10)