|
from dataclasses import dataclass |
|
from expand import Series, BatchExpander, Expansion, Batch, TokenCandidates, BatchCandidates, CompletedSequence, CompletedBatch, expand |
|
|
|
# Fixture data: the complete set of token sequences that expand_series may
# continue. Every sequence starts with token 1, and several share longer
# prefixes, so a given prefix can have more than one possible next token.
possible_sequences = [
    [1, 21, 31, 41],
    [1, 21, 31, 42],
    [1, 21, 32, 41, 51],
    [1, 22, 33, 41],
    [1, 22, 34, 41],
]
|
|
|
def expand_series(series: Series) -> list[Expansion]:
    """Return the possible next-token expansions for ``series``.

    A sequence from ``possible_sequences`` contributes its next token when
    the value returned by ``series.get_all_tokens()`` equals a leading slice
    of that sequence and the sequence is strictly longer. Duplicate next
    tokens are collapsed, keeping first-seen order, and every expansion is
    created with a cost of -1.0.
    """
    prefix = series.get_all_tokens()
    prefix_len = len(prefix)

    # A dict is used as an insertion-ordered set of candidate next tokens.
    next_tokens = {}
    for sequence in possible_sequences:
        if sequence[:prefix_len] == prefix and len(sequence) > prefix_len:
            next_tokens[sequence[prefix_len]] = None

    return [Expansion(token=token, cost=-1.0) for token in next_tokens]
|
|
|
class PredefinedSequenceExpander(BatchExpander):
    """Batch expander whose candidates are produced by ``expand_series``."""

    def expand(self, batch: Batch) -> BatchCandidates:
        """Build one ``TokenCandidates`` entry per series in ``batch.items``.

        Entries are produced in the same order as the series appear in the
        batch; each entry pairs the series with its ``expand_series`` result.
        """
        return BatchCandidates(items=[
            TokenCandidates(series=series, expansions=expand_series(series))
            for series in batch.items
        ])
|
|
|
# Module-level expander instance shared by the tests below.
expander = PredefinedSequenceExpander()
|
|
|
def test_expander_zero_budget():
    """Expect candidates 21 and 22 for a series [1] with budget 0.0."""
    series = Series(id=0, tokens=[1], budget=0.0)

    actual = expander.expand(Batch(items=[series]))

    wanted_expansions = [Expansion(token=token, cost=-1.0) for token in (21, 22)]
    wanted = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=wanted_expansions)]
    )
    assert wanted == actual
|
|
|
def test_expander_budget_one():
    """Expect candidates 21 and 22 for a series [1] with budget 1.0."""
    series = Series(id=0, tokens=[1], budget=1.0)

    actual = expander.expand(Batch(items=[series]))

    wanted_expansions = [Expansion(token=token, cost=-1.0) for token in (21, 22)]
    wanted = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=wanted_expansions)]
    )
    assert wanted == actual
|
|
|
def test_expander_budget_two():
    """Expect candidates 21 and 22 for a series [1] with budget 2.0."""
    series = Series(id=0, tokens=[1], budget=2.0)

    actual = expander.expand(Batch(items=[series]))

    wanted_expansions = [Expansion(token=token, cost=-1.0) for token in (21, 22)]
    wanted = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=wanted_expansions)]
    )
    assert wanted == actual
|
|
|
def test_expander_budget_one_no_expansion():
    """Expect an empty expansion list for tokens [1, 20].

    No entry of ``possible_sequences`` starts with [1, 20].
    """
    series = Series(id=0, tokens=[1, 20], budget=1.0)

    actual = expander.expand(Batch(items=[series]))

    wanted = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=[])]
    )
    assert wanted == actual
|
|
|
def test_expander_budget_one_two_tokens():
    """Expect candidates 33 and 34 for a series with tokens [1, 22]."""
    series = Series(id=0, tokens=[1, 22], budget=1.0)

    actual = expander.expand(Batch(items=[series]))

    wanted_expansions = [Expansion(token=token, cost=-1.0) for token in (33, 34)]
    wanted = BatchCandidates(
        items=[TokenCandidates(series=series, expansions=wanted_expansions)]
    )
    assert wanted == actual
|
|
|
def test_expander_budget_one_two_tokens_two_series():
    """Expect per-series candidates, in batch order, for a two-series batch.

    Tokens [1, 21, 31] should yield 41 and 42; tokens [1, 22] should yield
    33 and 34.
    """
    first = Series(id=0, tokens=[1, 21, 31], budget=1.0)
    second = Series(id=1, tokens=[1, 22], budget=1.0)

    actual = expander.expand(Batch(items=[first, second]))

    wanted = BatchCandidates(items=[
        TokenCandidates(
            series=first,
            expansions=[Expansion(token=token, cost=-1.0) for token in (41, 42)],
        ),
        TokenCandidates(
            series=second,
            expansions=[Expansion(token=token, cost=-1.0) for token in (33, 34)],
        ),
    ])
    assert wanted == actual
|
|
|
def test_expand_01():
    """Expect two single-token completions per series when both budgets are 1.0."""
    result = expand(
        Batch(items=[
            Series(id=0, tokens=[1, 21], budget=1.0),
            Series(id=1, tokens=[1, 22], budget=1.0),
        ]),
        expander,
    )

    # Expected series are constructed fresh rather than reused from the input.
    wanted = CompletedBatch(items=[
        CompletedSequence(
            series=Series(id=0, tokens=[1, 21], budget=1.0),
            expansions=[
                [Expansion(token=token, cost=-1.0)] for token in (31, 32)
            ],
        ),
        CompletedSequence(
            series=Series(id=1, tokens=[1, 22], budget=1.0),
            expansions=[
                [Expansion(token=token, cost=-1.0)] for token in (33, 34)
            ],
        ),
    ])
    assert result == wanted
|
|
|
def test_expand_02():
    """Expect two-token paths for the budget-2.0 series and one-token paths for the other."""
    result = expand(
        Batch(items=[
            Series(id=0, tokens=[1, 21], budget=2.0),
            Series(id=1, tokens=[1, 22], budget=1.0),
        ]),
        expander,
    )

    # Token paths expected for the first series, in order.
    first_paths = ((31, 41), (31, 42), (32, 41))

    # Expected series are constructed fresh rather than reused from the input.
    wanted = CompletedBatch(items=[
        CompletedSequence(
            series=Series(id=0, tokens=[1, 21], budget=2.0),
            expansions=[
                [Expansion(token=token, cost=-1.0) for token in path]
                for path in first_paths
            ],
        ),
        CompletedSequence(
            series=Series(id=1, tokens=[1, 22], budget=1.0),
            expansions=[
                [Expansion(token=token, cost=-1.0)] for token in (33, 34)
            ],
        ),
    ])
    assert result == wanted
|
|
|
def test_expand_03():
    """Expect full paths for the budget-3.0 series and none for the budget-0.0 series.

    The expected paths for the first series have two or three tokens; the
    second series is expected to complete with an empty expansion list.
    """
    result = expand(
        Batch(items=[
            Series(id=0, tokens=[1, 21], budget=3.0),
            Series(id=1, tokens=[1, 22], budget=0.0),
        ]),
        expander,
    )

    # Token paths expected for the first series, in order.
    first_paths = ((31, 41), (31, 42), (32, 41, 51))

    # Expected series are constructed fresh rather than reused from the input.
    wanted = CompletedBatch(items=[
        CompletedSequence(
            series=Series(id=0, tokens=[1, 21], budget=3.0),
            expansions=[
                [Expansion(token=token, cost=-1.0) for token in path]
                for path in first_paths
            ],
        ),
        CompletedSequence(
            series=Series(id=1, tokens=[1, 22], budget=0.0),
            expansions=[],
        ),
    ])
    assert result == wanted
|
|