Skip to content

Commit 1ad2e1f

Browse files
authored
Have a dedicated neighborhood property and a get_neighborhood method on Cell (#2309)
1 parent 25925a2 commit 1ad2e1f

File tree

4 files changed

+46
-10
lines changed

4 files changed

+46
-10
lines changed

benchmarks/Schelling/schelling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, model, agent_type, radius, homophily):
2424
def step(self):
2525
"""Run one step of the agent."""
2626
similar = 0
27-
neighborhood = self.cell.neighborhood(radius=self.radius)
27+
neighborhood = self.cell.get_neighborhood(radius=self.radius)
2828
for neighbor in neighborhood.agents:
2929
if neighbor.type == self.type:
3030
similar += 1

benchmarks/WolfSheep/wolf_sheep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, model, energy, p_reproduce, energy_from_food):
3333

3434
def random_move(self):
3535
"""Move to a random neighboring cell."""
36-
self.move_to(self.cell.neighborhood().select_random_cell())
36+
self.move_to(self.cell.neighborhood.select_random_cell())
3737

3838
def spawn_offspring(self):
3939
"""Create offspring."""

mesa/experimental/cell_space/cell.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from functools import cache
5+
from functools import cache, cached_property
66
from random import Random
77
from typing import TYPE_CHECKING
88

@@ -34,6 +34,7 @@ class Cell:
3434
"capacity",
3535
"properties",
3636
"random",
37+
"__dict__",
3738
]
3839

3940
# def __new__(cls,
@@ -131,10 +132,33 @@ def is_full(self) -> bool:
131132
def __repr__(self): # noqa
132133
return f"Cell({self.coordinate}, {self.agents})"
133134

135+
@cached_property
136+
def neighborhood(self) -> CellCollection:
137+
"""Returns the direct neigborhood of the cell.
138+
139+
This is equivalent to cell.get_neighborhood(radius=1)
140+
141+
"""
142+
return self.get_neighborhood()
143+
134144
# FIXME: Revisit caching strategy on methods
135145
@cache # noqa: B019
136-
def neighborhood(self, radius: int = 1, include_center: bool = False):
137-
"""Returns a list of all neighboring cells."""
146+
def get_neighborhood(
147+
self, radius: int = 1, include_center: bool = False
148+
) -> CellCollection:
149+
"""Returns a list of all neighboring cells for the given radius.
150+
151+
For getting the direct neighborhood (i.e., radius=1) you can also use
152+
the `neighborhood` property.
153+
154+
Args:
155+
radius (int): the radius of the neighborhood
156+
include_center (bool): include the center of the neighborhood
157+
158+
Returns:
159+
a list of all neighboring cells
160+
161+
"""
138162
return CellCollection[Cell](
139163
self._neighborhood(radius=radius, include_center=include_center),
140164
random=self.random,

tests/test_cell_space.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,33 +280,45 @@ def test_cell_neighborhood():
280280
height = 10
281281
grid = OrthogonalVonNeumannGrid((width, height), torus=False, capacity=None)
282282
for radius, n in zip(range(1, 4), [2, 5, 9]):
283-
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
283+
if radius == 1:
284+
neighborhood = grid._cells[(0, 0)].neighborhood
285+
else:
286+
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
284287
assert len(neighborhood) == n
285288

286289
## Moore
287290
width = 10
288291
height = 10
289292
grid = OrthogonalMooreGrid((width, height), torus=False, capacity=None)
290293
for radius, n in zip(range(1, 4), [3, 8, 15]):
291-
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
294+
if radius == 1:
295+
neighborhood = grid._cells[(0, 0)].neighborhood
296+
else:
297+
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
292298
assert len(neighborhood) == n
293299

294300
with pytest.raises(ValueError):
295-
grid._cells[(0, 0)].neighborhood(radius=0)
301+
grid._cells[(0, 0)].get_neighborhood(radius=0)
296302

297303
# hexgrid
298304
width = 10
299305
height = 10
300306
grid = HexGrid((width, height), torus=False, capacity=None)
301307
for radius, n in zip(range(1, 4), [2, 6, 11]):
302-
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
308+
if radius == 1:
309+
neighborhood = grid._cells[(0, 0)].neighborhood
310+
else:
311+
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
303312
assert len(neighborhood) == n
304313

305314
width = 10
306315
height = 10
307316
grid = HexGrid((width, height), torus=False, capacity=None)
308317
for radius, n in zip(range(1, 4), [5, 10, 17]):
309-
neighborhood = grid._cells[(1, 0)].neighborhood(radius=radius)
318+
if radius == 1:
319+
neighborhood = grid._cells[(1, 0)].neighborhood
320+
else:
321+
neighborhood = grid._cells[(1, 0)].get_neighborhood(radius=radius)
310322
assert len(neighborhood) == n
311323

312324
# networkgrid

0 commit comments

Comments
 (0)