
JobShop

Bases: Environment[State, MultiDiscreteArray, Observation]

The Job Shop Scheduling Problem, as described in [1], is one of the best-known combinatorial optimization problems. We are given num_jobs jobs, each consisting of at most max_num_ops operations (ops), which need to be processed on num_machines machines. Each op has a specific machine on which it must be processed and a duration, which is at most max_op_duration. The goal is to minimise the total length of the schedule, also known as the makespan.

[1] https://developers.google.com/optimization/scheduling/job_shop.
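To make the objective concrete: once every op has a start time, the makespan is the latest finish time (start plus duration) over all ops. The sketch below computes it with JAX from start times, durations and a boolean mask marking which entries are real ops. The helper `makespan` and its `valid` argument are illustrative assumptions, not part of the Jumanji API.

```python
import jax.numpy as jnp


def makespan(start_times: jnp.ndarray, durations: jnp.ndarray, valid: jnp.ndarray) -> jnp.ndarray:
    """Latest finish time over all valid ops (hypothetical helper, not a Jumanji API).

    All arguments have shape (num_jobs, max_num_ops); `valid` marks real ops,
    so padded entries are ignored.
    """
    finish_times = start_times + durations
    # Padded ops contribute 0, so they never increase the makespan
    return jnp.max(jnp.where(valid, finish_times, 0))


# Two jobs with up to two ops each; job 1 has only one real op.
start = jnp.array([[0, 3], [0, 0]])
dur = jnp.array([[3, 2], [4, 0]])
valid = jnp.array([[True, True], [True, False]])
print(makespan(start, dur, valid))  # 5
```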

  • observation: Observation

    • ops_machine_ids: jax array (int32) of shape (num_jobs, max_num_ops), the id of the machine each operation must be processed on.
    • ops_durations: jax array (int32) of shape (num_jobs, max_num_ops), the processing time of each operation.
    • ops_mask: jax array (bool) of shape (num_jobs, max_num_ops), indicating which operations have yet to be scheduled.
    • machines_job_ids: jax array (int32) of shape (num_machines,), the id of the job (or no-op) that each machine is processing.
    • machines_remaining_times: jax array (int32) of shape (num_machines,), specifying, for each machine, the number of time steps until it is available.
    • action_mask: jax array (bool) of shape (num_machines, num_jobs + 1), indicating which job(s) (or no-op) can legally be scheduled on each machine.
  • action: jax array (int32) of shape (num_machines,), the id of the job (or num_jobs for a no-op) to schedule on each machine.

  • reward: jax array (float) of shape (). A reward of -1 is given at each time step. If all machines are simultaneously idle or the agent selects an invalid action, the agent instead receives a large penalty of -num_jobs * max_num_ops * max_op_duration, which is an upper bound on the makespan.

  • episode termination:

    • Finished schedule: all operations (and thus all jobs) have been processed.
    • Illegal action: the agent ignores the action mask and takes an illegal action.
    • Simultaneously idle: all machines are inactive at the same time.
  • state: State

    • ops_machine_ids: same as observation.
    • ops_durations: same as observation.
    • ops_mask: same as observation.
    • machines_job_ids: same as observation.
    • machines_remaining_times: same as observation.
    • action_mask: same as observation.
    • step_count: jax array (int32) of shape (), the number of time steps in the episode so far.
    • scheduled_times: jax array (int32) of shape (num_jobs, max_num_ops), specifying the timestep at which every op (scheduled so far) was scheduled.
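One way to see why the penalty bounds the makespan: the episode ends as soon as all machines are idle. Assuming a non-idle machine is one that is processing an op, some op is in progress at every step before the end. The number of steps therefore cannot exceed the total processing time, which is at most num_jobs * max_num_ops * max_op_duration. With the default RandomGenerator sizes (20 jobs, up to 8 ops per job, durations up to 6), the penalty is 20 * 8 * 6 = 960 in magnitude, i.e. -960. The sketch below implements the per-step reward rule in JAX. `step_reward` is a hypothetical helper that mirrors the description above, not a function exported by Jumanji.

```python
import jax.numpy as jnp


def step_reward(invalid, all_machines_idle, num_jobs: int, max_num_ops: int, max_op_duration: int):
    """Per-step reward following the rule described above (hypothetical helper)."""
    # Upper bound on the makespan: total work if every op had the maximum duration
    penalty = -num_jobs * max_num_ops * max_op_duration
    return jnp.where(
        invalid | all_machines_idle,
        jnp.array(penalty, jnp.float32),
        jnp.array(-1.0, jnp.float32),
    )


# Default RandomGenerator sizes: 20 jobs, 8 ops per job, max duration 6
print(step_reward(jnp.array(False), jnp.array(False), 20, 8, 6))  # -1.0
print(step_reward(jnp.array(True), jnp.array(False), 20, 8, 6))   # -960.0
```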
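import jax
from jumanji.environments import JobShop

env = JobShop()
key = jax.random.PRNGKey(0)
# Sample a new problem instance and get the first timestep
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# A value with the right shape and dtype; not necessarily legal under the action mask
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)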
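An illegal action ends the episode with the large penalty, so a policy should only choose entries allowed by observation.action_mask. The sketch below samples a uniformly random legal job (or no-op) for every machine, independently per machine. It assumes each row of the mask has at least one True entry. `sample_legal_action` is a hypothetical helper, not a Jumanji function.

```python
import jax
import jax.numpy as jnp


def sample_legal_action(key: jax.Array, action_mask: jax.Array) -> jax.Array:
    """Pick one legal job id (or the no-op, index num_jobs) per machine.

    action_mask: bool array of shape (num_machines, num_jobs + 1).
    Assumes every row contains at least one True entry.
    """
    # Illegal entries get -inf logits, so they are never sampled
    logits = jnp.where(action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(key, logits, axis=-1).astype(jnp.int32)


# Toy mask: 2 machines, 3 jobs + no-op (last column)
mask = jnp.array([
    [False, True, False, True],   # machine 0: job 1 or no-op
    [False, False, False, True],  # machine 1: only the no-op
])
action = sample_legal_action(jax.random.PRNGKey(0), mask)
print(action.shape)  # (2,)
```

In a rollout, this would be called on timestep.observation.action_mask before each call to jax.jit(env.step)(state, action), in place of env.action_spec.generate_value().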

Instantiate a JobShop environment.

Parameters:

  • generator (Optional[Generator], default: None): Generator whose __call__ instantiates an environment instance. Implemented options are ['ToyGenerator', 'RandomGenerator']. Defaults to RandomGenerator with 20 jobs, 10 machines, up to 8 ops for any given job, and a max operation duration of 6.
  • viewer (Optional[Viewer[State]], default: None): Viewer used for rendering. Defaults to JobShopViewer.
Source code in jumanji/environments/packing/job_shop/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiate a `JobShop` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            Implemented options are ['ToyGenerator', 'RandomGenerator'].
            Defaults to `RandomGenerator` with 20 jobs, 10 machines, up to 8 ops
            for any given job, and a max operation duration of 6.
        viewer: `Viewer` used for rendering. Defaults to `JobShopViewer`.
    """
    self.generator = generator or RandomGenerator(
        num_jobs=20,
        num_machines=10,
        max_num_ops=8,
        max_op_duration=6,
    )
    self.num_jobs = self.generator.num_jobs
    self.num_machines = self.generator.num_machines
    self.max_num_ops = self.generator.max_num_ops
    self.max_op_duration = self.generator.max_op_duration
    super().__init__()

    # Define the "job id" of a no-op action as the number of jobs
    self.no_op_idx = self.num_jobs

    # Create viewer used for rendering
    self._viewer = viewer or JobShopViewer(
        "JobShop",
        self.num_jobs,
        self.num_machines,
        self.max_num_ops,
        self.max_op_duration,
    )

action_spec: specs.MultiDiscreteArray cached property

Specifications of the action in the JobShop environment. The action assigns each machine a job id in 0, 1, ..., num_jobs, where the last value (num_jobs) corresponds to a no-op.

Returns:

  • action_spec (MultiDiscreteArray): a specs.MultiDiscreteArray spec.
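Each entry of the action is therefore an integer in [0, num_jobs], with num_jobs reserved for the no-op (the constructor stores it as no_op_idx). The sketch below decodes an action vector into a readable per-machine assignment. `describe_action` and its output format are hypothetical and chosen only for illustration.

```python
from typing import List, Sequence


def describe_action(action: Sequence[int], num_jobs: int) -> List[str]:
    """Map each machine's entry to 'job j' or 'no-op' (hypothetical helper)."""
    no_op_idx = num_jobs  # the last value in 0..num_jobs means no-op
    descriptions = []
    for machine_id, job_id in enumerate(action):
        if not 0 <= job_id <= num_jobs:
            raise ValueError(f"machine {machine_id}: {job_id} is outside 0..{num_jobs}")
        label = "no-op" if job_id == no_op_idx else f"job {job_id}"
        descriptions.append(f"machine {machine_id}: {label}")
    return descriptions


# 3 machines and 20 jobs, so 20 means no-op
print(describe_action([4, 20, 0], num_jobs=20))
# ['machine 0: job 4', 'machine 1: no-op', 'machine 2: job 0']
```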

observation_spec: specs.Spec[Observation] cached property

Specifications of the observation of the JobShop environment.

Returns:

  • Spec[Observation]: Spec containing the specifications for all the Observation fields:

    • ops_machine_ids: BoundedArray (int32) of shape (num_jobs, max_num_ops).
    • ops_durations: BoundedArray (int32) of shape (num_jobs, max_num_ops).
    • ops_mask: BoundedArray (bool) of shape (num_jobs, max_num_ops).
    • machines_job_ids: BoundedArray (int32) of shape (num_machines,).
    • machines_remaining_times: BoundedArray (int32) of shape (num_machines,).
    • action_mask: BoundedArray (bool) of shape (num_machines, num_jobs + 1).
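As a quick reference, the sketch below builds a zero-filled placeholder observation with the shapes and dtypes listed above, using the default generator sizes. This can be handy, for example, when initialising a network before any environment call. A plain dictionary stands in for the real Observation class, and `dummy_observation` is a hypothetical helper.

```python
import jax.numpy as jnp


def dummy_observation(num_jobs: int, num_machines: int, max_num_ops: int) -> dict:
    """Zero-filled arrays matching the observation spec shapes (hypothetical helper)."""
    ops_shape = (num_jobs, max_num_ops)
    return {
        "ops_machine_ids": jnp.zeros(ops_shape, jnp.int32),
        "ops_durations": jnp.zeros(ops_shape, jnp.int32),
        "ops_mask": jnp.zeros(ops_shape, bool),
        "machines_job_ids": jnp.zeros((num_machines,), jnp.int32),
        "machines_remaining_times": jnp.zeros((num_machines,), jnp.int32),
        # One column per job plus a final column for the no-op
        "action_mask": jnp.zeros((num_machines, num_jobs + 1), bool),
    }


obs = dummy_observation(num_jobs=20, num_machines=10, max_num_ops=8)
for name, value in obs.items():
    print(name, value.shape, value.dtype)
```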

animate(states, interval=200, save_path=None)

Creates an animated GIF of the JobShop environment based on a sequence of states.

Parameters:

  • states (Sequence[State], required): sequence of environment states corresponding to consecutive timesteps.
  • interval (int, default: 200): delay between frames in milliseconds.
  • save_path (Optional[str], default: None): the path where the animation file should be saved. If it is None, the animation will not be saved.

Returns:

  • FuncAnimation: the animation.FuncAnimation object that was created.

Source code in jumanji/environments/packing/job_shop/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the Jobshop environment based on the sequence of states.

    Args:
        states: sequence of environment states corresponding to consecutive timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states, interval, save_path)

close()

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/packing/job_shop/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Render the given state of the environment. This rendering shows which job (or no-op) is running on each machine for the current time step and previous time steps.

Parameters:

  • state (State, required): State object containing the current environment state.
Source code in jumanji/environments/packing/job_shop/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Render the given state of the environment. This rendering shows which job (or no-op)
    is running on each machine for the current time step and previous time steps.

    Args:
        state: `State` object containing the current environment state.
    """
    return self._viewer.render(state)

reset(key)

Resets the environment by creating a new problem instance and initialising the state and timestep.

Parameters:

  • key (PRNGKey, required): random key used to reset the environment.

Returns:

  • state (State): the environment state after the reset.
  • timestep (TimeStep[Observation]): the first timestep returned by the environment after the reset.

Source code in jumanji/environments/packing/job_shop/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment by creating a new problem instance and initialising the state
    and timestep.

    Args:
        key: random key used to reset the environment.

    Returns:
        state: the environment state after the reset.
        timestep: the first timestep returned by the environment after the reset.
    """
    # Generate a new problem instance
    state = self.generator(key)

    # Create the action mask and update the state
    state.action_mask = self._create_action_mask(
        state.machines_job_ids,
        state.machines_remaining_times,
        state.ops_machine_ids,
        state.ops_mask,
    )

    # Get the observation and the timestep
    obs = self._observation_from_state(state)
    timestep = restart(observation=obs)

    return state, timestep
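reset delegates mask construction to the private _create_action_mask, whose body is not shown on this page. The sketch below is a hypothetical reconstruction under explicit assumptions, listed in its docstring. It illustrates how the four inputs passed above could combine into a (num_machines, num_jobs + 1) mask. The real rules may differ, and `sketch_action_mask` is not a Jumanji function.

```python
import jax.numpy as jnp


def sketch_action_mask(machines_job_ids, machines_remaining_times, ops_machine_ids, ops_mask):
    """Hypothetical action mask; NOT Jumanji's _create_action_mask.

    Assumed rules: a machine accepts a new job only when its remaining time is 0;
    job j may go on machine m only if j's next unscheduled op must run on m and j
    is not currently being processed; the no-op is always allowed.
    """
    num_jobs = ops_mask.shape[0]
    num_machines = machines_job_ids.shape[0]
    # Next op of each job = first True in ops_mask (argmax gives 0 for all-False rows)
    next_op = jnp.argmax(ops_mask, axis=-1)
    has_next_op = jnp.any(ops_mask, axis=-1)
    next_machine = ops_machine_ids[jnp.arange(num_jobs), next_op]
    # A job running on a busy machine cannot start its next op yet
    busy = machines_remaining_times > 0
    job_running = jnp.any(
        (machines_job_ids[None, :] == jnp.arange(num_jobs)[:, None]) & busy[None, :], axis=-1
    )
    job_ok = has_next_op & ~job_running  # shape (num_jobs,)
    machine_free = machines_remaining_times == 0  # shape (num_machines,)
    legal_jobs = (
        (next_machine[None, :] == jnp.arange(num_machines)[:, None])
        & job_ok[None, :]
        & machine_free[:, None]
    )  # shape (num_machines, num_jobs)
    no_op = jnp.ones((num_machines, 1), bool)
    return jnp.concatenate([legal_jobs, no_op], axis=-1)


# Two jobs, two machines, two ops each; both machines idle (job id 2 = no-op)
mask = sketch_action_mask(
    machines_job_ids=jnp.array([2, 2]),
    machines_remaining_times=jnp.array([0, 0]),
    ops_machine_ids=jnp.array([[0, 1], [1, 0]]),
    ops_mask=jnp.ones((2, 2), bool),
)
print(mask)
```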

step(state, action)

Updates the status of all machines and operations and increments the time step, returning the updated environment state and the timestep (which contains the new observation). The reward and the termination flag are computed from three terminal conditions:

  • The action provided by the agent is invalid.
  • The schedule has finished.
  • All machines do a no-op, leaving every machine simultaneously idle.

Parameters:

  • state (State, required): the environment state.
  • action (Array, required): the action to take.

Returns:

  • state (State): the updated environment state.
  • timestep (TimeStep[Observation]): the updated timestep.
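The first thing step does is check legality. It uses advanced indexing to gather, for every machine m, the mask entry action_mask[m, action[m]], and the action is invalid if any gathered entry is False. The standalone illustration below applies the same expression to a toy mask; the mask values are made up for the example, and `is_invalid` is a hypothetical helper.

```python
import jax.numpy as jnp

# Toy mask: 2 machines, 2 jobs + no-op (last column)
action_mask = jnp.array([
    [True, False, True],  # machine 0 may take job 0 or the no-op
    [False, True, True],  # machine 1 may take job 1 or the no-op
])
num_machines = action_mask.shape[0]


def is_invalid(action: jnp.ndarray) -> jnp.ndarray:
    # Row m, column action[m]: legality of each machine's choice
    chosen = action_mask[jnp.arange(num_machines), action]
    return ~jnp.all(chosen)


print(is_invalid(jnp.array([0, 1])))  # False
print(is_invalid(jnp.array([1, 2])))  # True: job 1 is not legal on machine 0
```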

Source code in jumanji/environments/packing/job_shop/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Updates the status of all machines, the status of the operations, and increments the
    time step. It updates the environment state and the timestep (which contains the new
    observation). It calculates the reward based on the three terminal conditions:
        - The action provided by the agent is invalid.
        - The schedule has finished.
        - All machines do a no-op that leads to all machines being simultaneously idle.

    Args:
        state: the environment state.
        action: the action to take.

    Returns:
        state: the updated environment state.
        timestep: the updated timestep.
    """
    # Check the action is legal
    invalid = ~jnp.all(state.action_mask[jnp.arange(self.num_machines), action])  # type: ignore

    # Obtain the id for every job's next operation
    op_ids = jnp.argmax(state.ops_mask, axis=-1)

    # Update the status of all machines
    (
        updated_machines_job_ids,
        updated_machines_remaining_times,
    ) = self._update_machines(
        action,
        op_ids,
        state.machines_job_ids,
        state.machines_remaining_times,
        state.ops_durations,
    )

    # Update the status of operations that have been scheduled
    updated_ops_mask, updated_scheduled_times = self._update_operations(
        action,
        op_ids,
        state.step_count,
        state.scheduled_times,
        state.ops_mask,
    )

    # Update the action_mask
    updated_action_mask = self._create_action_mask(
        updated_machines_job_ids,
        updated_machines_remaining_times,
        state.ops_machine_ids,
        updated_ops_mask,
    )

    # Increment the time step
    updated_step_count = jnp.array(state.step_count + 1, jnp.int32)

    # Check if all machines are idle simultaneously
    all_machines_idle = jnp.all(
        (updated_machines_job_ids == self.no_op_idx) & (updated_machines_remaining_times == 0)
    )

    # Check if the schedule has finished
    all_operations_scheduled = ~jnp.any(updated_ops_mask)
    schedule_finished = all_operations_scheduled & jnp.all(
        updated_machines_remaining_times == 0
    )

    # Update the state and extract the next observation
    next_state = State(
        ops_machine_ids=state.ops_machine_ids,
        ops_durations=state.ops_durations,
        ops_mask=updated_ops_mask,
        machines_job_ids=updated_machines_job_ids,
        machines_remaining_times=updated_machines_remaining_times,
        action_mask=updated_action_mask,
        step_count=updated_step_count,
        scheduled_times=updated_scheduled_times,
        key=state.key,
    )
    next_obs = self._observation_from_state(next_state)

    # Compute terminal condition
    done = invalid | all_machines_idle | schedule_finished

    # Compute reward
    reward = jnp.where(
        invalid | all_machines_idle,
        jnp.array(-self.num_jobs * self.max_num_ops * self.max_op_duration, float),
        jnp.array(-1, float),
    )

    timestep = jax.lax.cond(
        done,
        termination,
        transition,
        reward,
        next_obs,
    )

    return next_state, timestep
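Two expressions in step are worth unpacking. First, jnp.argmax over a boolean row returns the index of the first True, so argmax(ops_mask, axis=-1) picks each job's next unscheduled op. A fully scheduled job (an all-False row) also yields 0, so that index is only meaningful while the job still has ops left. Second, the termination flags reduce over machines. The toy arrays below are made up for the illustration; the flag expressions are copied from step.

```python
import jax.numpy as jnp

ops_mask = jnp.array([
    [False, True, True],    # job 0: op 0 done, next op is 1
    [True, True, False],    # job 1: next op is 0
    [False, False, False],  # job 2: fully scheduled (argmax still returns 0)
])
next_op_ids = jnp.argmax(ops_mask, axis=-1)
print(next_op_ids)  # [1 0 0]

no_op_idx = 3  # equals num_jobs
machines_job_ids = jnp.array([3, 3])
machines_remaining_times = jnp.array([0, 0])

# Same expressions as in step()
all_machines_idle = jnp.all(
    (machines_job_ids == no_op_idx) & (machines_remaining_times == 0)
)
schedule_finished = ~jnp.any(ops_mask) & jnp.all(machines_remaining_times == 0)
print(all_machines_idle, schedule_finished)  # True False
```

Here both machines idle while ops remain, so this step would end the episode with the large penalty rather than as a finished schedule.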