
Neuron ref and monitor optimization #446


Merged: 8 commits, Jan 6, 2021
50 changes: 34 additions & 16 deletions bindsnet/network/monitors.py
@@ -28,36 +28,47 @@ def __init__(
         state_vars: Iterable[str],
         time: Optional[int] = None,
         batch_size: int = 1,
+        device: str = "cpu",
     ):
         # language=rst
         """
         Constructs a ``Monitor`` object.

         :param obj: An object to record state variables from during network simulation.
-        :param state_vars: Iterable of strings indicating names of state variables to
-            record.
+        :param state_vars: Iterable of strings indicating names of state variables to record.
         :param time: If not ``None``, pre-allocate memory for state variable recording.
+        :param device: Allows the monitor to be on a device separate from the ``Network``'s device.
         """
         super().__init__()

         self.obj = obj
         self.state_vars = state_vars
         self.time = time
         self.batch_size = batch_size
+        self.device = device
+
+        # If time is not specified, the monitor accumulates the logs without bound,
+        # so it is kept on the CPU.
+        if self.time is None:
+            self.device = "cpu"

-        # Deal with time later, the same underlying list is used
-        self.recording = {v: [] for v in self.state_vars}
+        self.recording = []
         self.reset_state_variables()
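A minimal usage sketch of the new ``device`` argument (the ``network`` object, layer name ``"X"``, and state variable ``"v"`` are assumed for illustration, not taken from this diff):

    from bindsnet.network.monitors import Monitor

    # Assumed setup: `network` is an existing bindsnet Network with a layer "X".
    monitor = Monitor(
        obj=network.layers["X"],
        state_vars=["v"],  # membrane voltage
        time=500,          # rolling window of 500 time steps
        device="cuda",     # recordings live on the GPU, separate from the network
    )
    network.add_monitor(monitor, name="X_voltage")

Note that when ``time`` is ``None`` the constructor overrides ``device`` to ``"cpu"``, presumably so an unbounded log cannot exhaust GPU memory.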

     def get(self, var: str) -> torch.Tensor:
         # language=rst
         """
         Return recording to user.

         :param var: State variable recording to return.
-        :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is
-            the shape of the recorded state variable.
+        :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is the shape of the recorded
+            state variable.
+            Note: if ``time`` is ``None``, ``get`` returns the logs and empties the monitor's buffer.
+
         """
-        return torch.cat(self.recording[var], 0)
+        return_logs = torch.cat(self.recording[var], 0)
+        if self.time is None:
+            self.recording[var] = []
+        return return_logs
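A sketch of the new drain-on-read behavior when ``time`` is ``None``, continuing the assumed setup above:

    unbounded = Monitor(obj=network.layers["X"], state_vars=["v"], time=None)
    # ... simulate, calling unbounded.record() once per time step ...
    v = unbounded.get("v")  # concatenated snapshots, e.g. [steps, batch, n_neurons]
    assert unbounded.recording["v"] == []  # buffer is emptied by get() when time is None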

     def record(self) -> None:
         # language=rst
@@ -66,20 +77,27 @@ def record(self) -> None:
         """
         for v in self.state_vars:
             data = getattr(self.obj, v).unsqueeze(0)
-            self.recording[v].append(data.detach().clone())
-
-        # remove the oldest element (first in the list)
-        if self.time is not None:
-            for v in self.state_vars:
-                if len(self.recording[v]) > self.time:
-                    self.recording[v].pop(0)
+            # self.recording[v].append(data.detach().clone().to(self.device))
+            self.recording[v].append(
+                torch.empty_like(data, device=self.device, requires_grad=False).copy_(
+                    data, non_blocking=True
+                )
+            )
+            # remove the oldest element (first in the list)
+            if self.time is not None:
+                self.recording[v].pop(0)
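The replaced ``data.detach().clone()`` always allocated on the source device; the new call allocates the destination tensor directly on the monitor's device and copies into it. ``non_blocking=True`` can overlap the transfer with computation, though in PyTorch that only applies to copies involving pinned host memory; otherwise it degrades to a synchronous copy. A standalone sketch of the two patterns:

    import torch

    data = torch.rand(1, 32, 100)  # stand-in for one recorded snapshot
    device = "cpu"                 # the monitor's target device

    # Old pattern: clone on the source device, then (optionally) move.
    old = data.detach().clone().to(device)

    # New pattern: allocate once on the target device, copy asynchronously if possible.
    new = torch.empty_like(data, device=device, requires_grad=False).copy_(
        data, non_blocking=True
    )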

     def reset_state_variables(self) -> None:
         # language=rst
         """
-        Resets recordings to empty ``torch.Tensor``s.
+        Resets recordings to empty ``List``s.
         """
-        self.recording = {v: [] for v in self.state_vars}
+        if self.time is None:
+            self.recording = {v: [] for v in self.state_vars}
+        else:
+            self.recording = {
+                v: [[] for i in range(self.time)] for v in self.state_vars
+            }
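When ``time`` is set, the buffer is pre-filled with ``time`` placeholder lists, so the unconditional ``pop(0)`` in ``record`` keeps each list at a fixed length, in effect a ring buffer. A sketch of the invariant (``layer`` is an assumed ``Nodes`` object exposing a tensor attribute ``s``):

    m = Monitor(obj=layer, state_vars=["s"], time=100)
    for _ in range(250):
        m.record()
        assert len(m.recording["s"]) == 100  # one append, one pop per step

Until ``time`` real snapshots have been recorded, the list still contains empty placeholders, so ``get`` (which calls ``torch.cat``) appears usable only once the window has filled.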


class NetworkMonitor(AbstractMonitor):
6 changes: 3 additions & 3 deletions bindsnet/network/nodes.py
@@ -420,7 +420,7 @@ class LIFNodes(Nodes):
     # language=rst
     """
     Layer of `leaky integrate-and-fire (LIF) neurons
-    <http://icwww.epfl.ch/~gerstner/SPNM/node26.html#SECTION02311000000000000000>`_.
+    <http://web.archive.org/web/20190318204706/http://icwww.epfl.ch/~gerstner/SPNM/node26.html#SECTION02311000000000000000>`_.
     """

     def __init__(
@@ -683,7 +683,7 @@ class CurrentLIFNodes(Nodes):
     # language=rst
     """
     Layer of `current-based leaky integrate-and-fire (LIF) neurons
-    <http://icwww.epfl.ch/~gerstner/SPNM/node26.html#SECTION02313000000000000000>`_.
+    <http://web.archive.org/web/20190318204706/http://icwww.epfl.ch/~gerstner/SPNM/node26.html#SECTION02313000000000000000>`_.
     Total synaptic input current is modeled as a decaying memory of input spikes multiplied by synaptic strengths.
     """

@@ -1148,7 +1148,7 @@ def set_batch_size(self, batch_size) -> None:
 class IzhikevichNodes(Nodes):
     # language=rst
     """
-    Layer of Izhikevich neurons.
+    Layer of `Izhikevich neurons <https://www.izhikevich.org/publications/spikes.htm>`_.
     """

def __init__(
6 changes: 3 additions & 3 deletions bindsnet/network/topology.py
@@ -160,7 +160,7 @@ def __init__(
                 w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
         else:
             if self.wmin != -np.inf or self.wmax != np.inf:
-                w = torch.clamp(w, self.wmin, self.wmax)
+                w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)

         self.w = Parameter(w, requires_grad=False)
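The ``torch.as_tensor`` wrapper matters when a user supplies ``w`` as a NumPy array or nested list: ``torch.clamp`` expects a ``Tensor``, and ``as_tensor`` converts without copying when the input already has a compatible layout. A standalone illustration:

    import numpy as np
    import torch

    w = np.random.rand(10, 5)                      # user-supplied NumPy weights
    w = torch.clamp(torch.as_tensor(w), 0.0, 1.0)  # valid now: w is a Tensor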

@@ -381,10 +381,10 @@ def normalize(self) -> None:
         if self.norm is not None:
             # get a view and modify in place
             w = self.w.view(
-                self.w.size(0) * self.w.size(1), self.w.size(2) * self.w.size(3)
+                self.w.shape[0] * self.w.shape[1], self.w.shape[2] * self.w.shape[3]
             )

-            for fltr in range(w.size(0)):
+            for fltr in range(w.shape[0]):
                 w[fltr] *= self.norm / w[fltr].sum(0)
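``Tensor.shape[i]`` and ``Tensor.size(i)`` are equivalent in PyTorch, so this change is stylistic. The normalization itself scales each flattened filter so its weights sum to ``self.norm``; a minimal reproduction on a standalone tensor (the 4-D layout is an assumption based on the view above):

    import torch

    w = torch.rand(2, 3, 4, 4)  # e.g. [in_channels, out_channels, k1, k2]
    norm = 1.0
    flat = w.view(w.shape[0] * w.shape[1], w.shape[2] * w.shape[3])  # shares storage with w
    for fltr in range(flat.shape[0]):
        flat[fltr] *= norm / flat[fltr].sum(0)
    # every row of `flat`, and hence every filter of `w`, now sums to `norm`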

def reset_state_variables(self) -> None:
Binary file added examples/mnist/plots/assaiments/assaiments.png
Binary file added examples/mnist/plots/performance/performance.png
Binary file added examples/mnist/plots/weights/weights.1.png