Skip to content

Commit 14c86ed

Browse files
committed
Updates to the MGN synthetic data reader
1 parent b0aca32 commit 14c86ed

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

applications/graph/MeshGraphNet/GNN.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import lbann
2-
from .GNNComponents import MLP, GraphProcessor
2+
from GNNComponents import MLP, GraphProcessor
33

44

55
def input_data_splitter(input_layer,
@@ -96,7 +96,7 @@ def LBANN_GNN_Model(num_nodes, num_edges,
9696

9797
calculated_features = node_decoder(node_features)
9898

99-
loss = loss.MeanSquaredError(calculated_features, out_features)
99+
loss = lbann.MeanSquaredError(calculated_features, out_features)
100100

101101
# Define some of the usual callbacks
102102

applications/graph/MeshGraphNet/GNNComponents.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(self,
8686
norm_type=lbann.LayerNorm,
8787
name=None):
8888
super().__init__()
89+
EdgeProcessor.global_count += 1
8990
self.instance = 0
9091
self.name = (name if name
9192
else f'EdgeProcessor_{EdgeProcessor.global_count}')
@@ -139,6 +140,7 @@ def __init__(self,
139140
norm_type=lbann.LayerNorm,
140141
name=None):
141142
super().__init__()
143+
NodeProcessor.global_count += 1
142144
self.instance = 0
143145
self.name = (name if name
144146
else f'NodeProcessor_{NodeProcessor.global_count}')
@@ -166,7 +168,7 @@ def forward(self,
166168
self.instance += 1
167169

168170
edge_feature_sum = lbann.Scatter(edge_features, target_edge_indices,
169-
name="f{self.name}_{self.instance}_scatter",
171+
name=f"{self.name}_{self.instance}_scatter",
170172
dims=[self.num_nodes, self.in_dim_edge],
171173
axis=0)
172174

@@ -175,7 +177,7 @@ def forward(self,
175177
name=f"{self.name}_{self.instance}_concat_features")
176178
x = self.node_mlp(x)
177179

178-
return lbann.Sum(edge_features, x,
180+
return lbann.Sum(node_features, x,
179181
name=f"{self.name}_{self.instance}_residual_sum")
180182

181183

applications/graph/MeshGraphNet/SyntheticData.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import numpy as np
2+
import configparser
3+
4+
5+
DATA_CONFIG = configparser.ConfigParser()
6+
DATA_CONFIG.read("data_config.ini")
7+
NUM_NODES = 100 # int(DATA_CONFIG['DEFAULT']['NUM_NODES'])
8+
NUM_EDGES = 10000 # int(DATA_CONFIG['DEFAULT']['NUM_EDGES'])
9+
NODE_FEATS = 5 # int(DATA_CONFIG['DEFAULT']['NODE_FEATURES'])
10+
EDGE_FEATS = 3 # int(DATA_CONFIG['DEFAULT']['EDGE_FEATURES'])
11+
OUT_FEATS = 3 # int(DATA_CONFIG['DEFAULT']['OUT_FEATURES'])
12+
NUM_SAMPLES = 100
213

314

4-
NUM_SAMPLES = 10000
5-
NUM_NODES = 100
6-
NUM_EDGES = 1000
7-
NODE_FEATS = 5
8-
EDGE_FEATS = 3
9-
OUT_FEATS = 3
1015

1116
NODE_FEATURE_SIZE = NUM_NODES * NODE_FEATS
1217
EDGE_FEATURE_SIZE = NUM_EDGES * EDGE_FEATS
@@ -27,3 +32,7 @@ def sample_dims_func():
2732

2833
size = NODE_FEATURE_SIZE + EDGE_FEATURE_SIZE + OUT_FEATURE_SIZE + 2 * NUM_EDGES
2934
return (size, )
35+
36+
37+
if __name__ == '__main__':
38+
print(NUM_NODES)

applications/graph/MeshGraphNet/Trainer.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import lbann.contrib.launcher
33
import lbann.contrib.args
44
import argparse
5-
5+
import configparser
66
import os.path as osp
7-
from .GNN import LBANN_GNN_Model
7+
from GNN import LBANN_GNN_Model
88

99
data_dir = osp.dirname(osp.realpath(__file__))
1010

@@ -21,8 +21,8 @@
2121
help='number of epochs (deafult: 3)', metavar='NUM')
2222

2323
parser.add_argument(
24-
'--mini-batch-size', action='store', default=256, type=int,
25-
help="mini-batch size (default: 256)", metavar='NUM')
24+
'--mini-batch-size', action='store', default=4, type=int,
25+
help="mini-batch size (default: 4)", metavar='NUM')
2626

2727
parser.add_argument(
2828
'--job-name', action='store', default="MGN", type=str,
@@ -38,12 +38,14 @@
3838
JOB_NAME = args.job_name
3939

4040
# Some synthetic attributes to get the model running
41+
DATA_CONFIG = configparser.ConfigParser()
42+
DATA_CONFIG.read("data_config.ini")
4143

42-
NUM_NODES = 100
43-
NUM_EDGES = 1000
44-
NODE_FEATS = 5
45-
EDGE_FEATS = 3
46-
OUT_FEATS = 3
44+
NUM_NODES = int(DATA_CONFIG['DEFAULT']['NUM_NODES'])
45+
NUM_EDGES = int(DATA_CONFIG['DEFAULT']['NUM_EDGES'])
46+
NODE_FEATS = int(DATA_CONFIG['DEFAULT']['NODE_FEATURES'])
47+
EDGE_FEATS = int(DATA_CONFIG['DEFAULT']['EDGE_FEATURES'])
48+
OUT_FEATS = int(DATA_CONFIG['DEFAULT']['OUT_FEATURES'])
4749

4850
def make_data_reader(classname,
4951
sample='get_sample_func',
@@ -85,4 +87,4 @@ def main():
8587
**kwargs)
8688

8789
if __name__ == '__main__':
88-
main()
90+
main()

0 commit comments

Comments
 (0)