# JobShopCPRL / app.py
# Pierre Tassel
# wip
# d746b98
# raw
# history blame
# 6.43 kB
import collections
import random
import time
import multiprocessing as mp
import json
from PIL import Image
from compiled_jss.CPEnv import CompiledJssEnvCP
from stable_baselines3.common.vec_env import VecEnvWrapper
from torch.distributions import Categorical
import torch
import numpy as np
from MyVecEnv import WrapperRay
import gradio as gr
import docplex.cp.utils_visu as visu
import matplotlib.pyplot as plt
class VecPyTorch(VecEnvWrapper):
    """Pass-through wrapper around a vectorised environment.

    Every call is forwarded unchanged to the wrapped ``venv``; the only extra
    state kept by the wrapper is the ``device`` handed to the constructor.
    """

    def __init__(self, venv, device):
        """Wrap ``venv`` and remember ``device`` on the instance."""
        super().__init__(venv)
        self.device = device

    def reset(self):
        """Reset the wrapped environments and return what they return."""
        observations = self.venv.reset()
        return observations

    def step_async(self, actions):
        """Forward ``actions`` to the wrapped environments without waiting."""
        self.venv.step_async(actions)

    def step_wait(self):
        """Return the result of the wrapped environments' pending step."""
        step_result = self.venv.step_wait()
        return step_result
def make_env(seed, instance):
    """Return a zero-argument factory that builds an environment for ``instance``.

    ``seed`` is accepted for interface compatibility but is not used here.
    The environment is only constructed when the returned callable is invoked.
    """

    def thunk():
        # Deferred construction: nothing is instantiated until this is called.
        return CompiledJssEnvCP(instance)

    return thunk
def _sample_best_solution(instance_path, num_workers, device):
    """Roll out the actor on ``num_workers`` environments and keep the best result.

    Returns a ``(solution, elapsed_seconds)`` pair. ``solution`` is the decoded
    JSON payload of the info dict with the smallest ``'makespan'`` seen, or an
    empty list when no info dict reported a makespan.
    """
    with torch.inference_mode():
        actor = torch.jit.load('actor.pt', map_location=device)
        actor.eval()
        start_time = time.time()
        fn_env = [make_env(0, instance_path) for _ in range(num_workers)]
        ray_wrapper_env = WrapperRay(lambda n: fn_env[n](),
                                     num_workers, 1, device)
        envs = VecPyTorch(ray_wrapper_env, device)
        # NOTE(review): `envs` is never closed here; confirm whether WrapperRay
        # needs an explicit close() to release its workers.

        # One sampling temperature per worker, spread over [0.5, 2.0).
        # Built from an integer range so the vector always has exactly
        # `num_workers` entries (a float-step arange can round to one extra
        # element), and built once because it does not change between steps.
        if num_workers >= 4:
            temperature = 0.5 + torch.arange(num_workers, device=device) * (1.5 / num_workers)
        else:
            temperature = torch.ones(num_workers, device=device)
        temperature = temperature[:, None]

        best_cost = float('inf')
        best_solution = []
        obs = envs.reset()
        total_episode = 0
        while total_episode < envs.num_envs:
            logits = actor(obs['interval_rep'], obs['attention_interval_mask'], obs['job_resource_mask'],
                           obs['action_mask'], obs['index_interval'], obs['start_end_tokens'])
            logits = logits / temperature
            probs = Categorical(logits=logits).probs
            # Draw a full ordering of the actions for each environment
            # (as many samples as there are columns in `probs`).
            actions = torch.multinomial(probs, probs.shape[1]).cpu().numpy()
            obs, reward, done, infos = envs.step(actions)
            total_episode += done.sum()
            for info in infos:
                if 'makespan' in info and int(info['makespan']) < best_cost:
                    best_cost = int(info['makespan'])
                    best_solution = json.loads(info['solution'])
        elapsed = time.time() - start_time
    return best_solution, elapsed


def _parse_instance(file):
    """Read the instance file and return, per job, a list of ``(machine, op_time)``.

    The first non-blank line is the header and is skipped; every following
    non-blank line is one job written as alternating machine / duration
    integers. Blank lines are ignored so they cannot shift job indices.

    Raises:
        ValueError: if a job line holds an odd number of values.
    """
    file.seek(0)
    jobs_data = []
    header_seen = False
    line_cnt = 0
    while True:
        line_str = file.readline()
        if not line_str:
            break
        line_cnt += 1
        tokens = line_str.split()
        if not tokens:
            continue
        if not header_seen:
            header_seen = True
            continue
        if len(tokens) % 2:
            raise ValueError(
                f"line {line_cnt} of the instance file has an odd number of values; "
                "expected (machine, duration) pairs"
            )
        jobs_data.append([(int(machine), int(op_time))
                          for machine, op_time in zip(tokens[0::2], tokens[1::2])])
    return jobs_data


def _render_gantt(solution, jobs_data):
    """Draw the per-job and per-machine timelines and return them as a PIL image.

    ``solution[job][task]`` is used as the start of the task and
    ``jobs_data[job][task]`` supplies its ``(machine, op_time)``.
    """
    # NOTE(review): the title text looks truncated; it is kept verbatim.
    visu.timeline('Solution for job-shop, solved using ')
    visu.panel('Jobs')
    for job_id, starts in enumerate(solution):
        intervals = []
        for task_id, start in enumerate(starts):
            machine, op_time = jobs_data[job_id][task_id]
            intervals.append((start, start + op_time, machine, f'M{machine}'))
        visu.sequence(name=f'J{job_id}', intervals=intervals)

    visu.panel('Machines')
    machine_solution = collections.defaultdict(list)
    for job_id, starts in enumerate(solution):
        for task_id, start in enumerate(starts):
            # Group by the machine id (first field), not by the duration.
            machine, op_time = jobs_data[job_id][task_id]
            machine_solution[machine].append((start, start + op_time,
                                              machine, f'J{job_id}'))
    for machine_id in sorted(machine_solution):
        visu.sequence(name=f'M{machine_id}',
                      intervals=machine_solution[machine_id])

    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["font.size"] = "30"
    plt.gca().set_aspect('equal')
    plt.rcParams["figure.figsize"] = (45, 50)
    from io import BytesIO
    buffer = BytesIO()
    visu.show(pngfile=buffer)
    return Image.open(buffer)


def solve(file):
    """Solve the uploaded job-shop instance and build the outputs for the UI.

    Args:
        file: uploaded file object; its ``name`` is passed to the environment
            factory and its contents are re-read (``seek`` / ``readline``) to
            recover machines and durations for the chart.

    Returns:
        A tuple ``(pretty_output, image, elapsed)``: one text line per job
        with that job's solution values, the Gantt chart as a PIL image, and
        the rollout time formatted as ``"<seconds> seconds"``.
    """
    # Fixed seeds so repeated runs sample the same actions.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    num_workers = min(mp.cpu_count(), 32)
    device = torch.device('cpu')

    current_solution, total_time = _sample_best_solution(file.name, num_workers, device)

    # The text output shows the values exactly as decoded from JSON,
    # before they are cast to integers for plotting.
    pretty_output = "".join(
        f"Job {job_id}: {job_solution}\n"
        for job_id, job_solution in enumerate(current_solution)
    )

    jobs_data = _parse_instance(file)
    current_solution = [[int(x) for x in y] for y in current_solution]
    reloadedPILImage = _render_gantt(current_solution, jobs_data)
    return pretty_output, reloadedPILImage, str(total_time) + " seconds"
# Static texts shown on the web page.
title = "Job-Shop Scheduling CP RL"
description = (
    "A Job-Shop Scheduling Reinforcement Learning based solver, using an underlying CP model as an "
    "environment. "
)
article = "<p style='text-align: center'>Article Under Review</p>"
# Example instance files offered in the UI.
examples = ['ta01', 'dmu01.txt', 'la01.txt']

# One file input; three outputs matching the tuple returned by `solve`.
iface = gr.Interface(
    fn=solve,
    inputs=gr.File(label="Instance File"),
    outputs=[
        gr.Text(label="Solution"),
        gr.Image(label="Solution's Gantt Chart"),
        gr.Text(label="Elapsed Time"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch(enable_queue=True)