4 changes: 3 additions & 1 deletion .gitignore
@@ -1,2 +1,4 @@
/target/
/.vscode/
best-agent.json
fitness-plot.svg
744 changes: 743 additions & 1 deletion Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -37,3 +37,4 @@ serde-big-array = { version = "0.5.1", optional = true }
[dev-dependencies]
bincode = "1.3.3"
serde_json = "1.0.114"
plotters = "0.3.5"
164 changes: 164 additions & 0 deletions examples/plot.rs
@@ -0,0 +1,164 @@
use std::{
    error::Error,
    sync::{Arc, Mutex},
};

use neat::*;
use plotters::prelude::*;
use rand::prelude::*;

#[derive(RandomlyMutable, DivisionReproduction, Clone)]
struct AgentDNA {
    network: NeuralNetworkTopology<2, 1>,
}

impl Prunable for AgentDNA {}

impl GenerateRandom for AgentDNA {
    fn gen_random(rng: &mut impl Rng) -> Self {
        Self {
            network: NeuralNetworkTopology::new(0.01, 3, rng),
        }
    }
}

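// Fitness: the network is asked to approximate `n.log(base)` for random inputs;
// each of the 100 samples adds `5 / |answer - expected|` to the score.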
fn fitness(g: &AgentDNA) -> f32 {
    let network = NeuralNetwork::from(&g.network);
    let mut fitness = 0.;
    let mut rng = rand::thread_rng();

    for _ in 0..100 {
        let n = rng.gen::<f32>() * 10000.;
        let base = rng.gen::<f32>() * 10.;
        let expected = n.log(base);

        let [answer] = network.predict([n, base]);
        network.flush_state();

        fitness += 5. / (answer - expected).abs();
    }

    fitness
}

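// Wraps another nextgen function so the high/median/low fitness of every generation
// can be recorded before the population is handed off to it.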
struct PlottingNG<F: NextgenFn<AgentDNA>> {
    performance_stats: Arc<Mutex<Vec<PerformanceStats>>>,
    actual_ng: F,
}

impl<F: NextgenFn<AgentDNA>> NextgenFn<AgentDNA> for PlottingNG<F> {
    fn next_gen(&self, mut fitness: Vec<(AgentDNA, f32)>) -> Vec<AgentDNA> {
        // Sorting here repeats work the wrapped nextgen does internally, but it saves
        // this wrapper from having to reimplement it.
        fitness.sort_by(|(_, fa), (_, fb)| fa.partial_cmp(fb).unwrap());

        let l = fitness.len();
        let high = fitness[l - 1].1;
        let median = fitness[l / 2].1;
        let low = fitness[0].1;

        let mut ps = self.performance_stats.lock().unwrap();
        ps.push(PerformanceStats { high, median, low });

        self.actual_ng.next_gen(fitness)
    }
}

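// Fitness summary recorded for one generation.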
struct PerformanceStats {
    high: f32,
    median: f32,
    low: f32,
}

const OUTPUT_FILE_NAME: &str = "fitness-plot.svg";
const GENS: usize = 100;

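// Runs the simulation for `GENS` generations, then charts the recorded fitness statistics with `plotters`.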
fn main() -> Result<(), Box<dyn Error>> {
    #[cfg(not(feature = "rayon"))]
    let mut rng = rand::thread_rng();

    let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS)));
    let ng = PlottingNG {
        performance_stats: performance_stats.clone(),
        actual_ng: division_pruning_nextgen,
    };

    let mut sim = GeneticSim::new(
        #[cfg(not(feature = "rayon"))]
        Vec::gen_random(&mut rng, 100),
        #[cfg(feature = "rayon")]
        Vec::gen_random(100),
        fitness,
        ng,
    );

    println!("Training...");

    for _ in 0..GENS {
        sim.next_generation();
    }

    // Drop the simulation so it releases its clone of `performance_stats`;
    // otherwise `Arc::into_inner` below would return `None`.
    drop(sim);

println!("Training complete, collecting data and building chart...");

let root = SVGBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area();
root.fill(&WHITE)?;

let mut chart = ChartBuilder::on(&root)
.caption(
"agent fitness values per generation",
("sans-serif", 50).into_font(),
)
.margin(5)
.x_label_area_size(30)
.y_label_area_size(30)
.build_cartesian_2d(0usize..100, 0f32..200.0)?;

chart.configure_mesh().draw()?;

let data: Vec<_> = Arc::into_inner(performance_stats)
.unwrap()
.into_inner()
.unwrap()
.into_iter()
.enumerate()
.collect();

let highs = data
.iter()
.map(|(i, PerformanceStats { high, .. })| (*i, *high));

let medians = data
.iter()
.map(|(i, PerformanceStats { median, .. })| (*i, *median));

let lows = data
.iter()
.map(|(i, PerformanceStats { low, .. })| (*i, *low));

chart
.draw_series(LineSeries::new(highs, &GREEN))?
.label("high");

chart
.draw_series(LineSeries::new(medians, &YELLOW))?
.label("median");

chart.draw_series(LineSeries::new(lows, &RED))?.label("low");

chart
.configure_series_labels()
.background_style(&WHITE.mix(0.8))
.border_style(&BLACK)
.draw()?;

root.present()?;

println!("Complete");

Ok(())
}
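
One note on the chart code: the three series get `.label(...)` calls but no `.legend(...)` calls, so the series-label box will list the names without a colored key next to each one. If a key is wanted, the usual plotters pattern is a small addition, sketched below for the `highs` series (the `PathElement` glyph and the 20-pixel key length are illustrative choices, not something taken from this diff):

```rust
chart
    .draw_series(LineSeries::new(highs, &GREEN))?
    .label("high")
    // Draw a short green line segment as the legend key for this series.
    .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &GREEN));
```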
61 changes: 61 additions & 0 deletions src/lib.rs
@@ -23,3 +23,64 @@ pub use topology::*;

#[cfg(feature = "serde")]
pub use nnt_serde::*;

#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[derive(RandomlyMutable, DivisionReproduction, Clone)]
    struct AgentDNA {
        network: NeuralNetworkTopology<2, 1>,
    }

    impl Prunable for AgentDNA {}

    impl GenerateRandom for AgentDNA {
        fn gen_random(rng: &mut impl Rng) -> Self {
            Self {
                network: NeuralNetworkTopology::new(0.01, 3, rng),
            }
        }
    }

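    // Evolve agents on the logarithm task for 100 generations and print the sorted final fitness values.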
    #[test]
    fn basic_test() {
        let fitness = |g: &AgentDNA| {
            let network = NeuralNetwork::from(&g.network);
            let mut fitness = 0.;
            let mut rng = rand::thread_rng();

            for _ in 0..100 {
                let n = rng.gen::<f32>() * 10000.;
                let base = rng.gen::<f32>() * 10.;
                let expected = n.log(base);

                let [answer] = network.predict([n, base]);
                network.flush_state();

                fitness += 5. / (answer - expected).abs();
            }

            fitness
        };

        let mut rng = rand::thread_rng();

        let mut sim = GeneticSim::new(
            Vec::gen_random(&mut rng, 100),
            fitness,
            division_pruning_nextgen,
        );

        for _ in 0..100 {
            sim.next_generation();
        }

        let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect();

        fits.sort_by(|a, b| a.partial_cmp(b).unwrap());

        dbg!(fits);
    }
}