SearchAndRescue

Bases: Environment

A multi-agent search environment

Environment modelling a collection of agents collectively searching for a set of targets in a 2D space. Agents are rewarded (individually) for coming within a fixed range of a target that has not already been detected. Agents observe their local environment (i.e. the locations of other agents and targets) via a simple segmented view model. The environment area is a uniform square space with wrapped boundaries.

An episode will terminate if all targets have been located by the team of searching agents.

  • observation: Observation
    • searcher_views: jax array (float) of shape (num_searchers, channels, num_vision). Individual local views of the positions of other agents and targets, where channels can be used to differentiate between agent and target types. Each entry in the view indicates the distance to another agent/target along a ray from the agent, and is -1.0 if nothing is in range along the ray. The view model can be customised by implementing the ObservationFn interface.
    • targets_remaining: (float) Number of targets remaining to be found, scaled to the range [0, 1] (i.e. a value of 1.0 indicates all targets are still to be found).
    • step: (int) Current simulation step.
    • positions: jax array (float) of shape (num_searchers, 2). Search agent positions.

  • action: jax array (float) of shape (num_searchers, 2). Array of individual agent actions. Each agent's action rotates and accelerates/decelerates the agent as [rotation, acceleration], with each value in the range [-1, 1]. These values are then scaled to update agent velocities within the given parameters (i.e. a value of ±1 corresponds to the maximum rotation/acceleration).

  • reward: jax array (float) of shape (num_searchers,). Array of individual agent rewards. A reward of +1 is granted when an agent comes into contact range with a target that has not yet been found and the target is within the searcher's view cone. If multiple agents newly find the same target within the same step, the reward is split between the locating agents by default. Also by default, rewards decrease linearly over time, with zero reward granted at the environment time limit. These defaults can be modified by flags in IndividualRewardFn, or further customised by implementing the RewardFn interface.

  • state: State

    • searchers: AgentState
      • pos: jax array (float) of shape (num_searchers, 2) in the range [0, env_size].
      • heading: jax array (float) of shape (num_searchers,) in the range [0, 2π].
      • speed: jax array (float) of shape (num_searchers,) in the range [min_speed, max_speed].
    • targets: TargetState
      • pos: jax array (float) of shape (num_targets, 2) in the range [0, env_size].
      • vel: jax array (float) of shape (num_targets, 2).
      • found: jax array (bool) of shape (num_targets,) flag indicating if target has been located by an agent.
    • key: jax array (uint32) of shape (2,)
    • step: int representing the current simulation step.
import jax
from jumanji.environments import SearchAndRescue

env = SearchAndRescue()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
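
Following on from the example above, the fields of the observation and reward described earlier can be inspected directly (a minimal sketch; the field names are those listed in the observation description):

# Per-agent segmented views: (num_searchers, channels, num_vision)
print(timestep.observation.searcher_views.shape)
# Fraction of targets still to be found (1.0 means none found yet)
print(timestep.observation.targets_remaining)
# Per-agent positions (num_searchers, 2) and individual rewards (num_searchers,)
print(timestep.observation.positions.shape, timestep.reward.shape)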

Instantiates a SearchAndRescue environment

Parameters:

  • target_contact_range (float, default 0.02): Range at which a searcher will 'find' a target.
  • searcher_max_rotate (float, default 0.25): Maximum rotation searcher agents can turn within a step. Should be a value in [0, 1], representing a fraction of π radians.
  • searcher_max_accelerate (float, default 0.005): Magnitude of the maximum acceleration/deceleration a searcher agent can apply within a step.
  • searcher_min_speed (float, default 0.005): Minimum speed a searcher agent can move at.
  • searcher_max_speed (float, default 0.02): Maximum speed a searcher agent can move at.
  • time_limit (int, default 400): Maximum number of environment steps allowed for the search.
  • viewer (Optional[Viewer[State]], default None): Viewer used for rendering. Defaults to SearchAndRescueViewer.
  • target_dynamics (Optional[TargetDynamics], default None): Target object dynamics model, implemented as a TargetDynamics interface. Defaults to RandomWalk.
  • generator (Optional[Generator], default None): Initial state Generator instance. Defaults to RandomGenerator with 40 targets and 2 searchers, with targets uniformly distributed across the environment.
  • reward_fn (Optional[RewardFn], default None): Reward aggregation function. Defaults to IndividualRewardFn, where agents split rewards if they locate a target simultaneously, and rewards linearly decrease to zero over time.
  • observation (Optional[ObservationFn], default None): Agent observation view generation function. Defaults to AgentAndTargetObservationFn, where all targets (found and unfound) and other searching agents are included in the generated view.
Source code in jumanji/environments/swarms/search_and_rescue/env.py
def __init__(
    self,
    target_contact_range: float = 0.02,
    searcher_max_rotate: float = 0.25,
    searcher_max_accelerate: float = 0.005,
    searcher_min_speed: float = 0.005,
    searcher_max_speed: float = 0.02,
    time_limit: int = 400,
    viewer: Optional[Viewer[State]] = None,
    target_dynamics: Optional[TargetDynamics] = None,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    observation: Optional[ObservationFn] = None,
) -> None:
    """Instantiates a `SearchAndRescue` environment

    Args:
        target_contact_range: Range at which a searcher will 'find' a target.
        searcher_max_rotate: Maximum rotation searcher agents can
            turn within a step. Should be a value from [0,1]
            representing a fraction of π-radians.
        searcher_max_accelerate: Magnitude of the maximum
            acceleration/deceleration a searcher agent can apply within a step.
        searcher_min_speed: Minimum speed a searcher agent can move at.
        searcher_max_speed: Maximum speed a searcher agent can move at.
        time_limit: Maximum number of environment steps allowed for search.
        viewer: `Viewer` used for rendering. Defaults to `SearchAndRescueViewer`.
        target_dynamics: Target object dynamics model, implemented as a
            `TargetDynamics` interface. Defaults to `RandomWalk`.
        generator: Initial state `Generator` instance. Defaults to `RandomGenerator`
            with 40 targets and 2 searchers, with targets uniformly distributed
            across the environment.
        reward_fn: Reward aggregation function. Defaults to `IndividualRewardFn` where
            agents split rewards if they locate a target simultaneously, and
            rewards linearly decrease to zero over time.
        observation: Agent observation view generation function. Defaults to
            `AgentAndTargetObservationFn` where all targets (found and unfound)
            and other searching agents are included in the generated view.
    """

    self.target_contact_range = target_contact_range

    self.searcher_params = AgentParams(
        max_rotate=searcher_max_rotate,
        max_accelerate=searcher_max_accelerate,
        min_speed=searcher_min_speed,
        max_speed=searcher_max_speed,
    )
    self.time_limit = time_limit
    self._target_dynamics = target_dynamics or RandomWalk(acc_std=0.0001, vel_max=0.002)
    self.generator = generator or RandomGenerator(num_targets=40, num_searchers=2)
    self._viewer = viewer or SearchAndRescueViewer()
    self._reward_fn = reward_fn or IndividualRewardFn()
    self._observation_fn = observation or AgentAndTargetObservationFn(
        num_vision=128,
        searcher_vision_range=0.4,
        target_vision_range=0.1,
        view_angle=0.4,
        agent_radius=target_contact_range,
        env_size=self.generator.env_size,
    )
    super().__init__()
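
For reference, a sketch of instantiating the environment with non-default parameters (the values here are arbitrary illustrations, not recommendations):

from jumanji.environments import SearchAndRescue

# Tighter detection range, faster searchers, and a shorter episode;
# all other arguments keep the defaults listed above.
env = SearchAndRescue(
    target_contact_range=0.01,
    searcher_max_speed=0.05,
    time_limit=200,
)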

action_spec cached property #

Returns the action spec.

2D array of individual agent actions. Each agent's action is an array representing [rotation, acceleration] in the range [-1, 1].

Returns:

  • action_spec (BoundedArray): Action array spec.
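
For example, uniformly random actions can be sampled within the spec bounds (a minimal sketch, assuming the standard shape, minimum and maximum attributes of a bounded array spec):

import jax

spec = env.action_spec
actions = jax.random.uniform(
    jax.random.PRNGKey(0),
    shape=spec.shape,      # (num_searchers, 2)
    minval=spec.minimum,   # -1.0
    maxval=spec.maximum,   # 1.0
)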

observation_spec cached property #

Returns the observation spec.

Local searcher agent views representing the distance to the closest neighbouring agents and targets in the environment.

Returns:

  • observation_spec (Spec[Observation]): Search-and-rescue observation spec.

reward_spec cached property #

Returns the reward spec.

Array of individual rewards for each agent.

Returns:

  • reward_spec (BoundedArray): Reward array spec.

animate(states, interval=100, save_path=None) #

Create an animation from a sequence of environment states.

Parameters:

  • states (Sequence[State], required): Sequence of environment states corresponding to consecutive timesteps.
  • interval (int, default 100): Delay between frames in milliseconds.
  • save_path (Optional[str], default None): Path where the animation file should be saved. If None, the animation will not be saved.

Returns:

  • FuncAnimation: Animation that can be saved as a GIF, MP4, or rendered with HTML.

Source code in jumanji/environments/swarms/search_and_rescue/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 100,
    save_path: Optional[str] = None,
) -> FuncAnimation:
    """Create an animation from a sequence of environment states.

    Args:
        states: sequence of environment states corresponding to consecutive
            timesteps.
        interval: delay between frames in milliseconds.
        save_path: the path where the animation file should be saved. If it
            is None, the plot will not be saved.

    Returns:
        Animation that can be saved as a GIF, MP4, or rendered with HTML.
    """
    return self._viewer.animate(states, interval=interval, save_path=save_path)
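
As an illustration, a sketch of collecting a short rollout and animating it (the zero-valued actions from generate_value are purely for demonstration; a trained policy would normally be used instead):

import jax

key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
step_fn = jax.jit(env.step)

# Collect a sequence of consecutive states to animate
states = [state]
for _ in range(100):
    action = env.action_spec.generate_value()
    state, timestep = step_fn(state, action)
    states.append(state)

animation = env.animate(states, interval=100, save_path=None)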

close() #

Perform any necessary cleanup.

Source code in jumanji/environments/swarms/search_and_rescue/env.py
def close(self) -> None:
    """Perform any necessary cleanup."""
    self._viewer.close()

render(state) #

Render a frame of the environment for a given state using matplotlib.

Parameters:

  • state (State, required): State object.
Source code in jumanji/environments/swarms/search_and_rescue/env.py
def render(self, state: State) -> None:
    """Render a frame of the environment for a given state using matplotlib.

    Args:
        state: State object.
    """
    self._viewer.render(state)

reset(key) #

Initialise the searcher and target states.

Parameters:

  • key (PRNGKey, required): Random key used to reset the environment.

Returns:

  • state (State): Initial environment state.
  • timestep (TimeStep[Observation]): TimeStep with individual search agent views.

Source code in jumanji/environments/swarms/search_and_rescue/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Initialise searcher and target initial states.

    Args:
        key: Random key used to reset the environment.

    Returns:
        state: Initial environment state.
        timestep: TimeStep with individual search agent views.
    """
    state = self.generator(key, self.searcher_params)
    timestep = restart(observation=self._state_to_observation(state), shape=(self.num_agents,))
    return state, timestep
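
Since reset is a pure function of the random key, it can typically be vmapped to initialise a batch of independent environment states (a sketch, assuming the default configuration):

import jax

keys = jax.random.split(jax.random.PRNGKey(0), 8)
# Batched states and timesteps with a leading dimension of 8
batch_state, batch_timestep = jax.vmap(env.reset)(keys)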

step(state, actions) #

Environment update.

Update searcher velocities and consequently their positions, mark found targets, and generate rewards and local observations.

Parameters:

  • state (State, required): Environment state.
  • actions (Array, required): 2D array of searcher steering actions.

Returns:

  • state (State): Updated searcher and target positions and velocities.
  • timestep (TimeStep[Observation]): Transition timestep with individual agent local observations.

Source code in jumanji/environments/swarms/search_and_rescue/env.py
def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Environment update.

    Update searcher velocities and consequently their positions,
    mark found targets, and generate rewards and local observations.

    Args:
        state: Environment state.
        actions: 2d array of searcher steering actions.

    Returns:
        state: Updated searcher and target positions and velocities.
        timestep: Transition timestep with individual agent local observations.
    """
    key, target_key = jax.random.split(state.key, num=2)
    searchers = update_state(
        self.generator.env_size, self.searcher_params, state.searchers, actions
    )
    targets = self._target_dynamics(target_key, state.targets, self.generator.env_size)

    # Searchers return an array of flags of any targets they are in range of,
    #  and that have not already been located, result shape here is (n-searcher, n-targets)
    targets_found = spatial(
        utils.searcher_detect_targets,
        reduction=esquilax.reductions.logical_or((self.generator.num_targets,)),
        i_range=self.target_contact_range,
        dims=self.generator.env_size,
    )(
        self._observation_fn.view_angle,
        searchers,
        (jnp.arange(self.generator.num_targets), targets),
        pos=searchers.pos,
        pos_b=targets.pos,
        env_size=self.generator.env_size,
        n_targets=self.generator.num_targets,
    )

    rewards = self._reward_fn(targets_found, state.step, self.time_limit)

    targets_found = jnp.any(targets_found, axis=0)
    # Targets need to remain found if they already have been
    targets_found = jnp.logical_or(targets_found, state.targets.found)

    state = State(
        searchers=searchers,
        targets=TargetState(pos=targets.pos, vel=targets.vel, found=targets_found),
        key=key,
        step=state.step + 1,
    )
    observation = self._state_to_observation(state)
    observation = jax.lax.stop_gradient(observation)
    timestep = jax.lax.cond(
        jnp.logical_or(state.step >= self.time_limit, jnp.all(targets_found)),
        partial(termination, shape=(self.num_agents,)),
        partial(transition, shape=(self.num_agents,)),
        rewards,
        observation,
    )
    return state, timestep
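
As a final sketch, a full episode can be rolled out functionally, for example with jax.lax.scan (the placeholder zero actions from the spec stand in for a policy mapping observations to actions):

import jax

def rollout_step(state, _):
    action = env.action_spec.generate_value()
    state, timestep = env.step(state, action)
    return state, timestep.reward

key = jax.random.PRNGKey(0)
state, _ = env.reset(key)
final_state, rewards = jax.lax.scan(rollout_step, state, None, length=env.time_limit)
# rewards has shape (time_limit, num_searchers)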