
HyperNEAT substrate size and classification tasks #34

@LuukvanKeeken

Thanks for developing this nice repository! I have two questions/issues (probably due to my inexperience with JAX), and I was wondering if you might have some suggestions or answers.

I've been doing some experiments trying out HyperNEAT on some simple tasks. One issue I'm running into is that my GPU doesn't seem to be able to handle large networks due to memory limits. A population of 1000 with a substrate of two hidden layers of 100 nodes each is already near the maximum. Decreasing the population size helps somewhat, but I can't get anywhere near the tens of thousands of neurons I was hoping for, and of course the population also needs to stay sufficiently large for the evolutionary process.
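For a rough sense of scale, here's the back-of-envelope estimate I've been using. It assumes a dense N x N float32 connection matrix per network, vmapped over the whole population at once; I don't know the internals well enough to be sure that's what actually happens, so treat the numbers as a sketch:

def substrate_bytes(pop_size, num_nodes):
    # Hypothetical: one dense N x N float32 connection matrix per individual,
    # held in memory simultaneously for the whole population.
    return pop_size * num_nodes**2 * 4  # 4 bytes per float32

# ~210 nodes: two hidden layers of 100 plus small input/output layers
print(f"{substrate_bytes(1000, 210) / 1e9:.2f} GB")     # ~0.18 GB
# ~10,000 nodes: the scale I was hoping for
print(f"{substrate_bytes(1000, 10_000) / 1e9:.2f} GB")  # ~400 GB

If the representation really is dense, the quadratic term in the node count would explain why I hit the limit so quickly.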

Aside from using a larger GPU or a multi-GPU setup, do you have any suggestions for getting around these substrate size limits?

A second question relates to implementing classification with a kind of population encoding: each class is represented by some number of output nodes, and the class whose group of nodes has the largest sum of activations is selected as the network's prediction. I've started simple, with a binary classification task and just one neuron per class, i.e. just two neurons in the second-to-last layer. For the single output node I've implemented two aggregation functions: one filters the NaNs out of the input, leaving just the two activations coming from the previous layer's two nodes; the other performs an argmax, returning either 0 or 1 as the network output.

import jax.numpy as jnp

def filter_nans_(z):
    # Keep only the non-NaN entries; size=2 assumes exactly two valid
    # activations (the two class neurons). With fewer than two non-NaN
    # values, fill_value=0 falls back to index 0, i.e. z[0], which may
    # itself be NaN.
    mask = ~jnp.isnan(z)
    idxs = jnp.nonzero(mask, size=2, fill_value=0)[0]
    return z[idxs]

def argmax_(z):
    # Index of the most active class neuron, cast to float for the network output.
    return jnp.asarray(jnp.argmax(z), dtype=jnp.float32)
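For concreteness, this is how I expect the pair to behave on a toy input (the values below are made up):

z = jnp.array([jnp.nan, 0.3, jnp.nan, 0.7])  # raw inputs arriving at the output node
vals = filter_nans_(z)  # -> [0.3, 0.7], the two class activations
pred = argmax_(vals)    # -> 1.0, i.e. class 1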

This actually seems to work, but sometimes I run into a peculiar issue where the genome saved as self.best_genome in the pipeline suddenly changes. It even looks as if an entire additional generation step is performed (I'm not sure about that), changing multiple fitnesses in the population, even though the fitness target has already been reached at that point. As a result, the prints from the pipeline's analysis() function show good maximum fitness values and the evolution stops because the fitness target has been reached, but if I then run pipeline.show() on the supposedly best genome returned by pipeline.auto_run(), the performance is much worse. The strange part is that this change doesn't happen every time a new best genome is found, and it only ever seems to happen when I use my custom argmax_ and filter_nans_ functions. With the same random seed it happens consistently in exactly the same generation, with the same fitness values etc. each time.
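For reference, this is roughly how I run things and observe the mismatch; I'm paraphrasing the usage pattern from the README, so the exact signatures may differ:

state = pipeline.setup()
state, best = pipeline.auto_run(state)  # stops early: analysis() reports the fitness target reached
pipeline.show(state, best)              # ...but the returned best genome performs much worse here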

Is this behaviour something that you've seen before?
