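# Gradio demo: rolls out a pre-trained RL policy ("actor.pt") over several
# CP-based Job-Shop Scheduling environments in parallel and returns the best
# schedule found, together with a Gantt chart of the solution.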
import collections
import json
import multiprocessing as mp
import random
import time
from io import BytesIO

import docplex.cp.utils_visu as visu
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from compiled_jss.CPEnv import CompiledJssEnvCP
from MyVecEnv import WrapperRay
from PIL import Image
from stable_baselines3.common.vec_env import VecEnvWrapper
from torch.distributions import Categorical
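
# Minimal VecEnv wrapper: it only records the target device and forwards
# reset/step to the underlying Ray-backed vectorised environment unchanged.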
class VecPyTorch(VecEnvWrapper):
    def __init__(self, venv, device):
        super(VecPyTorch, self).__init__(venv)
        self.device = device

    def reset(self):
        return self.venv.reset()

    def step_async(self, actions):
        self.venv.step_async(actions)

    def step_wait(self):
        return self.venv.step_wait()
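
# Environment factory: each worker builds its own CP-based JSS environment
# from the uploaded instance file. The seed argument is currently unused.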
def make_env(seed, instance):
    def thunk():
        _env = CompiledJssEnvCP(instance)
        return _env

    return thunk
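
# Solve the uploaded instance: load the TorchScript actor, roll it out once per
# worker (each worker sampling at a different softmax temperature), and keep
# the schedule with the lowest makespan.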
def solve(file):
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    num_workers = min(mp.cpu_count(), 32)
    with torch.inference_mode():
        device = torch.device('cpu')
        actor = torch.jit.load('actor.pt', map_location=device)
        actor.eval()
        start_time = time.time()
        fn_env = [make_env(0, file.name) for _ in range(num_workers)]
        ray_wrapper_env = WrapperRay(lambda n: fn_env[n](), num_workers, 1, device)
        envs = VecPyTorch(ray_wrapper_env, device)
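
        # Run one episode per parallel environment; each worker samples actions
        # at its own temperature so the rollouts explore diverse schedules.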
        current_solution_cost = float('inf')
        current_solution = ''
        obs = envs.reset()
        total_episode = 0
        while total_episode < envs.num_envs:
            logits = actor(obs['interval_rep'], obs['attention_interval_mask'], obs['job_resource_mask'],
                           obs['action_mask'], obs['index_interval'], obs['start_end_tokens'])
            # one temperature per worker: low temperatures stay close to the
            # greedy policy, higher ones explore more
            if num_workers >= 4:
                temperature = torch.arange(0.5, 2.0, step=(1.5 / num_workers), device=device)
            else:
                temperature = torch.ones(num_workers, device=device)
            logits = logits / temperature[:, None]
            probs = Categorical(logits=logits).probs
            # randomly sample actions according to the temperature-scaled probabilities
            actions = torch.multinomial(probs, probs.shape[1]).cpu().numpy()
            obs, reward, done, infos = envs.step(actions)
            total_episode += done.sum()
            for env_idx, info in enumerate(infos):
                # keep the best (lowest-makespan) completed schedule seen so far
                if 'makespan' in info and int(info['makespan']) < current_solution_cost:
                    current_solution_cost = int(info['makespan'])
                    current_solution = json.loads(info['solution'])
        total_time = time.time() - start_time
    pretty_output = ""
    for job_id in range(len(current_solution)):
        pretty_output += f"Job {job_id}: {current_solution[job_id]}\n"
    jobs_data = []
    file.seek(0)
    line_str: str = file.readline()
    line_cnt: int = 1
    while line_str:
        data = []
        split_data = line_str.split()
        if line_cnt == 1:
            # header line: number of jobs and number of machines
            jobs_count, machines_count = int(split_data[0]), int(split_data[1])
        else:
            # one line per job: alternating machine id and processing time
            i = 0
            while i < len(split_data):
                machine, op_time = int(split_data[i]), int(split_data[i + 1])
                data.append((machine, op_time))
                i += 2
            jobs_data.append(data)
        line_str = file.readline()
        line_cnt += 1
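
    # Build the Gantt chart with docplex's visualisation helpers: one panel for
    # the jobs and one for the machines.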
    visu.timeline('Solution for job-shop')
    visu.panel('Jobs')
    # convert the start times in current_solution to integers
    current_solution = [[int(x) for x in y] for y in current_solution]
    for job_id in range(len(current_solution)):
        visu.sequence(name=f'J{job_id}',
                      intervals=[(current_solution[job_id][task_id],
                                  current_solution[job_id][task_id] + jobs_data[job_id][task_id][1],
                                  jobs_data[job_id][task_id][0],
                                  f'M{jobs_data[job_id][task_id][0]}')
                                 for task_id in range(len(current_solution[job_id]))])
    visu.panel('Machines')
    machine_solution = collections.defaultdict(list)
    for job_id in range(len(current_solution)):
        for task_id in range(len(current_solution[job_id])):
            # group intervals by the machine that processes the operation
            machine = jobs_data[job_id][task_id][0]
            machine_solution[machine].append((current_solution[job_id][task_id],
                                              current_solution[job_id][task_id] + jobs_data[job_id][task_id][1],
                                              machine, f'J{job_id}'))
    # sort the machines by id so the panel rows appear in order
    machine_solution = {k: machine_solution[k] for k in sorted(machine_solution.keys())}
    for machine_id in machine_solution:
        visu.sequence(name=f'M{machine_id}',
                      intervals=machine_solution[machine_id])
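
    # Render the chart into an in-memory PNG and hand it back to Gradio as a
    # PIL image.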
plt.rcParams["font.family"] = "Times New Roman" | |
plt.rcParams["font.size"] = "30" | |
plt.gca().set_aspect('equal') | |
plt.rcParams["figure.figsize"] = (45, 50) | |
from io import BytesIO | |
buffer = BytesIO() | |
visu.show(pngfile=buffer) | |
reloadedPILImage = Image.open(buffer) | |
return pretty_output, reloadedPILImage, str(total_time) + " seconds" | |
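
# Gradio front-end: upload an instance file, get back the textual schedule,
# its Gantt chart and the solving time.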
title = "Job-Shop Scheduling CP RL" | |
description = "A Job-Shop Scheduling Reinforcement Learning based solver, using an underlying CP model as an " \ | |
"environment. " | |
article = "<p style='text-align: center'>Article Under Review</p>" | |
examples = ['ta01', 'dmu01.txt', 'la01.txt'] | |
iface = gr.Interface(fn=solve, inputs=gr.File(label="Instance File"), outputs=[gr.Text(label="Solution"), gr.Image(label="Solution's Gantt Chart"), gr.Text(label="Elapsed Time")], title=title, description=description, article=article, examples=examples) | |
iface.launch(enable_queue=True) | |