"""Gradio front-end for an RL-based Job-Shop Scheduling solver backed by a CP environment.

The uploaded instance is solved by running several ``CompiledJssEnvCP``
environments in parallel (through ``WrapperRay``). Actions are sampled from a
TorchScript actor (``actor.pt``) with a per-worker sampling temperature. The
lowest-makespan schedule reported by any environment is kept. That schedule is
returned as text and as a Gantt chart drawn with ``docplex.cp.utils_visu``.
"""
import collections
import json
import multiprocessing as mp
import random
import time
from io import BytesIO

import docplex.cp.utils_visu as visu
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from compiled_jss.CPEnv import CompiledJssEnvCP
from stable_baselines3.common.vec_env import VecEnvWrapper
from torch.distributions import Categorical

from MyVecEnv import WrapperRay


class VecPyTorch(VecEnvWrapper):
    """Pass-through ``VecEnvWrapper`` that also stores a torch device.

    ``reset``, ``step_async`` and ``step_wait`` delegate unchanged to the
    wrapped vectorised env. ``device`` is stored but not used by the wrapper.
    """

    def __init__(self, venv, device):
        super().__init__(venv)
        self.device = device

    def reset(self):
        return self.venv.reset()

    def step_async(self, actions):
        self.venv.step_async(actions)

    def step_wait(self):
        return self.venv.step_wait()


def make_env(seed, instance):
    """Return a zero-argument factory that builds a ``CompiledJssEnvCP`` for *instance*.

    ``seed`` is accepted for API compatibility but is not used.
    """
    def thunk():
        return CompiledJssEnvCP(instance)

    return thunk


def _temperature_vector(num_workers, device, low=0.5, high=2.0):
    """Return a 1-D tensor holding exactly one sampling temperature per worker.

    With 4 or more workers, the temperatures are ``low + i * (high - low) / num_workers``
    for ``i`` in ``0 .. num_workers - 1``. With fewer workers, every temperature is 1.0.
    """
    if num_workers >= 4:
        step = (high - low) / num_workers
        # linspace guarantees exactly num_workers elements. A float-step
        # torch.arange(low, high, step) sizes its output with a rounded ceil and
        # may produce an extra element, which breaks the broadcast with logits.
        return torch.linspace(low, high - step, num_workers, device=device)
    return torch.ones(num_workers, device=device)


def _read_instance(path):
    """Parse an instance file into, per job, a list of ``(machine, duration)`` pairs.

    The first non-blank line (the header, e.g. job and machine counts) is
    skipped. Every later non-blank line describes one job as alternating
    ``machine duration`` integers.

    Raises:
        ValueError: if a job line has an odd number of fields or a non-integer field.
    """
    with open(path) as handle:
        rows = [line.split() for line in handle]
    # Ignore blank lines (e.g. a trailing newline) instead of treating them as jobs.
    rows = [fields for fields in rows if fields]
    jobs_data = []
    for fields in rows[1:]:
        if len(fields) % 2:
            raise ValueError(f"Malformed job line in {path!r}: {' '.join(fields)!r}")
        values = [int(v) for v in fields]
        jobs_data.append(list(zip(values[0::2], values[1::2])))
    return jobs_data


def _render_gantt(current_solution, jobs_data):
    """Draw the Jobs and Machines Gantt panels and return the chart as a PIL image.

    ``current_solution[j][t]`` is used as the start time of task ``t`` of job
    ``j``. ``jobs_data[j][t]`` gives that task's ``(machine, duration)``.
    """
    visu.timeline('Solution for job-shop, solved using ')
    visu.panel('Jobs')
    machine_solution = collections.defaultdict(list)
    for job_id, starts in enumerate(current_solution):
        job_intervals = []
        for task_id, start in enumerate(starts):
            machine, duration = jobs_data[job_id][task_id]
            end = start + duration
            job_intervals.append((start, end, machine, f'M{machine}'))
            # Group by the task's machine (first field of the pair), not its duration.
            machine_solution[machine].append((start, end, machine, f'J{job_id}'))
        visu.sequence(name=f'J{job_id}', intervals=job_intervals)

    visu.panel('Machines')
    for machine_id in sorted(machine_solution):
        visu.sequence(name=f'M{machine_id}', intervals=machine_solution[machine_id])

    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["font.size"] = "30"
    plt.gca().set_aspect('equal')
    plt.rcParams["figure.figsize"] = (45, 50)
    buffer = BytesIO()
    visu.show(pngfile=buffer)
    buffer.seek(0)
    image = Image.open(buffer)
    # The PNG is now in the buffer. Close any figures still open in pyplot so
    # repeated requests in this long-running app do not accumulate them.
    plt.close('all')
    return image


def solve(file):
    """Solve the uploaded job-shop instance.

    Args:
        file: The uploaded instance. Only its ``.name`` (a filesystem path) is used.

    Returns:
        A tuple ``(pretty_output, image, elapsed)``:

        - ``pretty_output`` lists each job's task start times, one line per job.
        - ``image`` is the Gantt chart.
        - ``elapsed`` is the solve wall-clock time, formatted as ``"<seconds> seconds"``.
    """
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    num_workers = min(mp.cpu_count(), 32)
    with torch.inference_mode():
        device = torch.device('cpu')
        actor = torch.jit.load('actor.pt', map_location=device)
        actor.eval()
        start_time = time.time()
        fn_env = [make_env(0, file.name) for _ in range(num_workers)]
        ray_wrapper_env = WrapperRay(lambda n: fn_env[n](), num_workers, 1, device)
        envs = VecPyTorch(ray_wrapper_env, device)
        # The temperature does not change across steps, so compute it once.
        temperature = _temperature_vector(num_workers, device)
        current_solution_cost = float('inf')
        current_solution = []
        obs = envs.reset()
        total_episode = 0
        while total_episode < envs.num_envs:
            logits = actor(obs['interval_rep'], obs['attention_interval_mask'],
                           obs['job_resource_mask'], obs['action_mask'],
                           obs['index_interval'], obs['start_end_tokens'])
            logits = logits / temperature[:, None]
            probs = Categorical(logits=logits).probs
            # Draw probs.shape[1] samples per row without replacement.
            actions = torch.multinomial(probs, probs.shape[1]).cpu().numpy()
            obs, _reward, done, infos = envs.step(actions)
            total_episode += done.sum()
            for info in infos:
                if 'makespan' in info and int(info['makespan']) < current_solution_cost:
                    current_solution_cost = int(info['makespan'])
                    current_solution = json.loads(info['solution'])
        total_time = time.time() - start_time

    # Text output uses the values exactly as decoded from JSON (before the int conversion).
    pretty_output = "".join(
        f"Job {job_id}: {starts}\n" for job_id, starts in enumerate(current_solution)
    )
    jobs_data = _read_instance(file.name)
    current_solution = [[int(x) for x in starts] for starts in current_solution]
    image = _render_gantt(current_solution, jobs_data)
    return pretty_output, image, str(total_time) + " seconds"


title = "Job-Shop Scheduling CP RL"
description = "A Job-Shop Scheduling Reinforcement Learning based solver, using an underlying CP model as an " \
              "environment. "
# NOTE(review): the copy of this file under review had this literal spread over raw
# newlines (likely lost markup); the visible text is kept verbatim — confirm.
article = "\n\nArticle Under Review\n\n"
examples = ['ta01', 'dmu01.txt', 'la01.txt']
iface = gr.Interface(fn=solve,
                     inputs=gr.File(label="Instance File"),
                     outputs=[gr.Text(label="Solution"),
                              gr.Image(label="Solution's Gantt Chart"),
                              gr.Text(label="Elapsed Time")],
                     title=title,
                     description=description,
                     article=article,
                     examples=examples)
iface.launch(enable_queue=True)