# NOTE(review): not referenced by the code shown here. `schedule` below
# hard-codes a 6-slot period holding F/B for two micro-batches -- presumably
# this is the same quantity; confirm before relying on it.
pattern_size = 6
from collections import Counter
from dataclasses import dataclass

@dataclass(frozen=True)
class ScheduledNode:
    """A single timed operation of a pipeline schedule.

    Instances are immutable, compare equal field-by-field and are hashable,
    so they can be used as dict keys or set members.
    """

    type: str  # op code; the schedulers in this file use 'F', 'B' or 'W'
    stage: int  # index of the pipeline stage running the op
    minibatch: int  # index of the micro-batch the op belongs to
    start_time: int  # time at which the op starts
    completion_time: int  # time at which the op finishes

def transform_schedule(schedule, f, b, w, c):
    """Convert a slot-grid schedule into per-stage lists of timed ScheduledNode.

    ``schedule`` holds one sequence per pipeline stage; each element is a
    one-character op code (``'F'``, ``'B'`` or ``'W'``) or blank/whitespace
    for an idle slot. The n-th occurrence of an op code on a stage is treated
    as that op for micro-batch ``n`` (0-based).

    Timing model, as implemented below:

    * an op starts no earlier than the op issued just before it on the same
      stage has finished;
    * ``F`` on stage ``s > 0`` also waits for ``F`` of the same micro-batch on
      stage ``s - 1`` plus ``c``;
    * ``B`` on a stage other than the last also waits for ``B`` of the same
      micro-batch on the next stage plus ``c``;
    * ``F``/``B``/``W`` ops take ``f``/``b``/``w`` time units respectively.

    No other dependency is modelled; e.g. ``W`` following ``B`` of the same
    micro-batch relies only on the local order of the grid.

    Args:
        schedule: sized sequence of stages, each an iterable of one-char op codes.
        f, b, w: durations of ``F``, ``B`` and ``W`` ops.
        c: delay added to every cross-stage dependency.

    Returns:
        One list per stage of ``ScheduledNode`` objects, in issue order.
        Empty or all-idle stages yield empty lists.

    Raises:
        KeyError: if a non-blank op code is not ``'F'``, ``'B'`` or ``'W'``.
    """
    num_stages = len(schedule)
    cost = {
        'F': f,
        'B': b,
        'W': w,
    }

    # stage_order[sid]: the stage's ops in issue order as (op, micro-batch).
    # local_prev[(sid, op, mb)]: the (op, micro-batch) issued just before it.
    stage_order = []
    local_prev = {}
    for sid, stage in enumerate(schedule):
        seen = Counter()
        order = []
        for op in stage:
            if not op.strip():  # idle slot
                continue
            mb = seen[op]  # n-th op of this type on the stage -> micro-batch n
            if order:
                local_prev[(sid, op, mb)] = order[-1]
            order.append((op, mb))
            seen[op] += 1
        stage_order.append(order)

    # Memoised completion times keyed by (stage, op, micro-batch). They are
    # computed lazily by recursion, so the recursion depth equals the length
    # of the dependency chain that has not been memoised yet.
    time_map = {}

    def get_time(stage, op, mb):
        """Return the completion time of ``op`` for micro-batch ``mb`` on ``stage``."""
        key = (stage, op, mb)
        if key in time_map:
            return time_map[key]
        start = 0
        if key in local_prev:
            start = get_time(stage, *local_prev[key])
        if op == 'F' and stage > 0:
            start = max(start, get_time(stage - 1, op, mb) + c)
        if op == 'B' and stage + 1 < num_stages:
            start = max(start, get_time(stage + 1, op, mb) + c)
        time_map[key] = start + cost[op]
        return time_map[key]

    result = []
    for sid, order in enumerate(stage_order):
        nodes = []
        for op, mb in order:
            end = get_time(sid, op, mb)
            nodes.append(ScheduledNode(op.upper(), sid, mb, end - cost[op], end))
        result.append(nodes)
    return result




def process_warmup_without_increasing_peak_mem(schedules, m):
    """Pull ops into earlier idle slots without exceeding the current peak memory.

    ``schedules`` is a list of equally long per-stage lists of one-char slots:
    ``'F'``, ``'B'``, ``'W'`` or ``' '`` (idle). Memory on a stage is modelled
    as a running count in which ``F`` adds one unit, ``W`` releases one and
    ``B`` leaves it unchanged. ``peak_mem`` is the largest running count over
    all stages.

    Slots are visited column by column, and stage by stage inside a column.
    The n-th op of a type on a stage is treated as micro-batch n. Each op is
    moved to the first idle slot after the latest of:

    * the previous op of the same type on the same stage;
    * for ``W``: the ``B`` of the same micro-batch on this stage;
    * for ``F`` on stage > 0: the ``F`` of the same micro-batch on the
      previous stage;
    * for ``B``: the ``B`` of the same micro-batch on the next stage, or, on
      the last stage, the ``F`` of the same micro-batch.

    ``B`` and ``W`` move there directly. ``F`` only moves back over slots
    whose running memory is below ``peak_mem``, so no slot exceeds
    ``peak_mem``. The bound is the global peak, so a stage whose own peak is
    lower can end up holding more memory than before.

    Args:
        schedules: per-stage slot lists; modified in place.
        m: number of micro-batches. A stage may hold at most ``m + 1`` ops of
            each type (IndexError otherwise).

    Returns:
        ``schedules`` itself, after modification.
    """
    if not schedules:
        return schedules
    num_stages = len(schedules)
    width = len(schedules[0])
    kinds = ('F', 'B', 'W')

    # mem[sid][i]: running memory of stage sid after slot i; peak over all stages.
    mem = [[0] * width for _ in range(num_stages)]
    peak_mem = 0
    for sid in range(num_stages):
        cur = 0
        for i, slot in enumerate(schedules[sid]):
            if slot == 'F':
                cur += 1
            elif slot == 'W':
                cur -= 1
            mem[sid][i] = cur
            peak_mem = max(peak_mem, cur)

    # loc[sid][n][kind]: slot of the n-th (1-based) `kind` op on stage sid,
    # or -1 if it has not been visited yet. cntr counts ops visited so far.
    loc = [[dict.fromkeys(kinds, -1) for _ in range(m + 2)] for _ in range(num_stages)]
    cntr = [dict.fromkeys(kinds, 0) for _ in range(num_stages)]

    for i in range(width):
        for sid in range(num_stages):
            row = schedules[sid]
            op = row[i]
            if op == ' ':
                continue
            cntr[sid][op] += 1
            cnt = cntr[sid][op]

            # Latest slot of any op this one has to follow.
            pos = loc[sid][cnt - 1][op] if cnt > 1 else -1
            if op == 'W':
                pos = max(pos, loc[sid][cnt]['B'])
            if op == 'F' and sid > 0:
                pos = max(pos, loc[sid - 1][cnt]['F'])
            if op == 'B':
                if sid != num_stages - 1:
                    pos = max(pos, loc[sid + 1][cnt]['B'])
                else:
                    pos = max(pos, loc[sid][cnt]['F'])

            # First idle slot after it, stopping at the op's current slot.
            # Bounds check first: pos can be i + 1 when a dependency sits in
            # this very column, and i may be the last slot.
            pos += 1
            while pos < i and row[pos] != ' ':
                pos += 1
            if pos == i:
                loc[sid][cnt][op] = i
                continue

            if op in ('B', 'W'):
                row[pos] = op
                row[i] = ' '
                if op == 'W':
                    # The memory is now released earlier.
                    for j in range(pos, i):
                        mem[sid][j] -= 1
                loc[sid][cnt][op] = pos
                continue

            # F: walk back only over slots that stay below the peak after +1,
            # then forward again to the first idle slot.
            place = i
            while place > pos and mem[sid][place - 1] < peak_mem:
                place -= 1
            while place < i and row[place] != ' ':
                place += 1
            if place == i:
                loc[sid][cnt][op] = i
                continue
            row[place] = op
            row[i] = ' '
            for j in range(place, i):
                mem[sid][j] += 1
            loc[sid][cnt][op] = place
    return schedules

def schedule(p, m, cost):
    """Build, compact and time a pipeline schedule.

    Places F/B ops for ``m`` micro-batches on ``p`` stages in a fixed slot
    grid, then fills idle slots after backward ops with ``W`` ops. It compacts
    the grid with ``process_warmup_without_increasing_peak_mem`` and converts
    it to timed nodes with ``transform_schedule``.

    Args:
        p: number of pipeline stages.
        m: number of micro-batches.
        cost: ``(f, b, w, c)`` tuple forwarded to ``transform_schedule``.

    Returns:
        The per-stage node lists returned by ``transform_schedule``.
    """
    width = 6 * m + 2 * p + 6
    grid = [[' '] * width for _ in range(p)]
    for sid, row in enumerate(grid):
        # Stage `sid` starts forwards at slot `sid` and backwards at slot
        # `2 * p - 1 - sid`. Every 6 slots hold F/B for two micro-batches;
        # the second micro-batch of a pair is shifted by 2 slots.
        first_f = sid
        first_b = 2 * p - 1 - sid
        for pair in range((m + 1) // 2):
            shift = 6 * pair
            for offset, mb in ((0, 2 * pair), (2, 2 * pair + 1)):
                if mb < m:
                    row[first_f + offset + shift] = 'F'
                    row[first_b + offset + shift] = 'B'
        # Greedily turn idle slots into W, one per B seen but not yet matched.
        pending_w = 0
        for i, slot in enumerate(row):
            if slot == 'B':
                pending_w += 1
            elif slot == ' ' and pending_w > 0:
                pending_w -= 1
                row[i] = 'W'
    grid = process_warmup_without_increasing_peak_mem(grid, m)
    return transform_schedule(grid, *cost)