Add testing phase #386

Merged · 1 commit merged on Jun 19, 2020
examples/mnist/SOM_LM-SNNs.py (77 additions & 10 deletions)
@@ -62,12 +62,16 @@

# Sets up Gpu use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on Device = ", device)
if torch.cuda.is_available():
if gpu and torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
else:
torch.manual_seed(seed)
if gpu:
gpu = False
device = 'cpu'

torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)

# Determines number of workers to use
if n_workers == -1:
@@ -102,7 +106,7 @@
)

# Record spikes during the simulation.
spike_record = torch.zeros(update_interval, time, n_neurons).cpu()
spike_record = torch.zeros(update_interval, int(time/dt), n_neurons).cpu()

# Neuron assignments and spike proportions.
n_classes = 10
@@ -114,18 +118,18 @@
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
som_voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time)
som_voltage_monitor = Monitor(network.layers["Y"], ["v"], time=int(time/dt))
network.add_monitor(som_voltage_monitor, name="som_voltage")

# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=int(time/dt))
network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in set(network.layers) - {"X"}:
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=int(time/dt))
network.add_monitor(voltages[layer], name="%s_voltages" % layer)
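The recurring change in this file from time=time to time=int(time/dt) reflects that the monitors and the spike_record buffer hold one row per simulation step, and the number of steps is the duration divided by the timestep. A small sketch of that bookkeeping, with illustrative values in place of the script's argparse parameters:

```python
import torch

# Duration is given in milliseconds and the network advances in steps of dt ms,
# so per-step recordings need time / dt rows, not time rows.
time, dt = 250, 0.5                  # illustrative values
n_steps = int(time / dt)             # 500 steps here; 250 steps when dt = 1.0

n_neurons = 100                      # illustrative
spike_record = torch.zeros(1, n_steps, n_neurons)  # (samples, steps, neurons)
print(spike_record.shape)            # torch.Size([1, 500, 100])
```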

inpt_ims, inpt_axes = None, None
@@ -164,7 +168,7 @@

for step, batch in enumerate(tqdm(dataloader)):
# Get next input sample.
inputs = {"X": batch["encoded_image"].view(time, 1, 1, 28, 28).to(device)}
inputs = {"X": batch["encoded_image"].view(int(time/dt), 1, 1, 28, 28).to(device)}

if step > 0:
if step % update_inhibation_weights == 0:
@@ -243,10 +247,10 @@
if temp_spikes.sum().sum() < 2:
inputs["X"] *= (
poisson(
datum=factor * batch["image"].clamp(min=0), dt=dt, time=time
datum=factor * batch["image"].clamp(min=0), dt=dt, time=int(time/dt)
)
.to(device)
.view(time, 1, 1, 28, 28)
.view(int(time/dt), 1, 1, 28, 28)
)
factor *= factor
else:
@@ -256,7 +260,7 @@
exc_voltages = som_voltage_monitor.get("v")

# Add to spikes recording.
spike_record[step % update_interval] = temp_spikes.detach().clone().cpu()
# spike_record[step % update_interval] = temp_spikes.detach().clone().cpu()
spike_record[step % update_interval].copy_(temp_spikes, non_blocking=True)
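Replacing the detach().clone().cpu() assignment with copy_ reuses the pre-allocated CPU buffer instead of allocating a fresh tensor every step; with non_blocking=True the device-to-host transfer can overlap other work when the source sits on the GPU. A toy sketch of the pattern, with illustrative shapes standing in for the script's tensors:

```python
import torch

update_interval, n_steps, n_neurons = 250, 250, 100  # illustrative sizes
spike_record = torch.zeros(update_interval, n_steps, n_neurons)  # CPU buffer, as in the script

temp_spikes = torch.rand(n_steps, n_neurons)  # stand-in for spikes["Y"].get("s")
step = 0

# In-place copy into the existing row; no new allocation per training step.
spike_record[step % update_interval].copy_(temp_spikes, non_blocking=True)
```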

# Optionally plot various simulation information.
@@ -291,3 +295,66 @@

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")


# Load MNIST data.
test_dataset = MNIST(
PoissonEncoder(time=time, dt=dt),
None,
root=os.path.join("..", "..", "data", "MNIST"),
download=True,
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
),
)
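For context on the test-set construction above: the Lambda transform scales pixel values by intensity before Poisson encoding, so brighter pixels are intended to translate into proportionally higher input rates. A self-contained sketch of that scaling step only, leaving out the wrapped BindsNET MNIST/PoissonEncoder machinery:

```python
from PIL import Image
from torchvision import transforms

intensity = 128  # illustrative; the script takes this from argparse

# Same composition as in the example: convert to a tensor, then scale so a
# fully bright pixel becomes `intensity` and a dark pixel stays at 0.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
)

img = Image.new("L", (28, 28), color=255)        # stand-in for an MNIST digit
scaled = transform(img)
print(scaled.min().item(), scaled.max().item())  # 128.0 128.0 for an all-white image
```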

# Sequence of accuracy estimates.
accuracy = {"all": 0, "proportion": 0}

# Record spikes during the simulation.
spike_record = torch.zeros(1, int(time/dt), n_neurons)

# Train the network.
print("\nBegin testing\n")
network.train(mode=False)
start = t()

for step, batch in enumerate(tqdm(test_dataset)):
# Get next input sample.
inputs = {"X": batch["encoded_image"].view(int(time/dt), 1, 1, 28, 28)}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}

# Run the network on the input.
network.run(inputs=inputs, time=time, input_time_dim=1)

# Add to spikes recording.
spike_record[0] = spikes["Y"].get("s").squeeze()

# Convert the array of labels into a tensor
label_tensor = torch.tensor(batch["label"])

# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
)

# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
accuracy["proportion"] += float(torch.sum(label_tensor.long() == proportion_pred).item())

network.reset_state_variables() # Reset state variables.

print("\nAll activity accuracy: %.2f" % (accuracy["all"] / test_dataset.test_labels.shape[0]))
print("Proportion weighting accuracy: %.2f \n" % ( accuracy["proportion"] / test_dataset.test_labels.shape[0]))


print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")
examples/mnist/batch_eth_mnist.py (84 additions & 9 deletions)
@@ -44,7 +44,7 @@
parser.add_argument("--test", dest="train", action="store_false")
parser.add_argument("--plot", dest="plot", action="store_true")
parser.add_argument("--gpu", dest="gpu", action="store_true")
parser.set_defaults(plot=False, gpu=False, train=True)
parser.set_defaults(plot=False, gpu=False)

args = parser.parse_args()

@@ -73,6 +73,8 @@
torch.cuda.manual_seed_all(seed)
else:
torch.manual_seed(seed)
if gpu:
gpu = False

# Determines number of workers to use
if n_workers == -1:
@@ -119,20 +121,20 @@
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time)
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time)
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=int(time/dt))
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=int(time/dt))
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=int(time/dt))
network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in set(network.layers) - {"X"}:
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=int(time/dt))
network.add_monitor(voltages[layer], name="%s_voltages" % layer)

inpt_ims, inpt_axes = None, None
@@ -142,7 +144,7 @@
perf_ax = None
voltage_axes, voltage_ims = None, None

spike_record = torch.zeros(update_interval, time, n_neurons)
spike_record = torch.zeros(update_interval, int(time/dt), n_neurons)

# Train the network.
print("\nBegin training.\n")
@@ -156,15 +158,15 @@
start = t()

# Create a dataloader to iterate and batch data
dataloader = DataLoader(
train_dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_workers,
pin_memory=gpu,
)

for step, batch in enumerate(tqdm(dataloader)):
for step, batch in enumerate(tqdm(train_dataloader)):
# Get next input sample.
inputs = {"X": batch["encoded_image"]}
if gpu:
@@ -262,7 +264,7 @@
weights_im = plot_weights(square_weights, im=weights_im)
assigns_im = plot_assignments(square_assignments, im=assigns_im)
perf_ax = plot_performance(
accuracy, x_scale=update_steps * batch_size, ax=perf_ax
accuracy, ax=perf_ax
)
voltage_ims, voltage_axes = plot_voltages(
voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line"
@@ -272,5 +274,78 @@

network.reset_state_variables() # Reset state variables.

if step % update_steps == 0 and step > 0:
break

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")



# Load MNIST data.
test_dataset = MNIST(
PoissonEncoder(time=time, dt=dt),
None,
root=os.path.join(ROOT_DIR, "data", "MNIST"),
download=True,
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
),
)

# Create a dataloader to iterate and batch data
test_dataloader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_workers,
pin_memory=gpu,
)

# Sequence of accuracy estimates.
accuracy = {"all": 0, "proportion": 0}

# Train the network.
print("\nBegin testing\n")
network.train(mode=False)
start = t()

for step, batch in enumerate(tqdm(test_dataloader)):
# Get next input sample.
inputs = {"X": batch["encoded_image"]}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}

# Run the network on the input.
network.run(inputs=inputs, time=time, input_time_dim=1)

# Add to spikes recording.
spike_record = spikes["Ae"].get("s").permute((1, 0, 2))

# Convert the array of labels into a tensor
label_tensor = torch.tensor(batch["label"])

# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
)

# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
accuracy["proportion"] += float(torch.sum(label_tensor.long() == proportion_pred).item())

network.reset_state_variables() # Reset state variables.

print("\nAll activity accuracy: %.2f" % (accuracy["all"] / test_dataset.test_labels.shape[0]))
print("Proportion weighting accuracy: %.2f \n" % ( accuracy["proportion"] / test_dataset.test_labels.shape[0]))


print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")