Skip to content

Commit a172772

Browse files
GLE-10323 fix(algo): fix tg_maxflow algorithm (#179) (#181)
Co-authored-by: David Fan <[email protected]>
1 parent e2664f6 commit a172772

File tree

1 file changed

+106
-159
lines changed

1 file changed

+106
-159
lines changed

algorithms/Path/maxflow/tg_maxflow.gsql

Lines changed: 106 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ CREATE QUERY tg_maxflow(VERTEX source, VERTEX sink, Set<STRING> v_type, SET<STRI
33
BOOL print_results = TRUE, BOOL display_edges = TRUE, STRING file_path = "") SYNTAX V1 {
44

55
/*
6-
First Author: <First Author Name>
7-
First Commit Date: <First Commit Date>
6+
First Author: David Zelong Fan
7+
First Commit Date: Jun 6, 2025
88

99
Recent Author: <Recent Commit Author Name>
1010
Recent Commit Date: <Recent Commit Date>
@@ -50,178 +50,125 @@ CREATE QUERY tg_maxflow(VERTEX source, VERTEX sink, Set<STRING> v_type, SET<STRI
5050
file to write CSV output to
5151
*/
5252

53-
TYPEDEF TUPLE<INT prev_flow, BOOL is_forward, VERTEX prev> tb_node;
54-
GroupByAccum<VERTEX source, VERTEX targ, SumAccum<FLOAT> flow> @@group_by_flow_accum;
55-
SetAccum<VERTEX> @@curr_set;
53+
// variables used in max flow computation
54+
GroupByAccum<VERTEX v_from, VERTEX v_to, SumAccum<INT> flow> @@flow_gb;
55+
GroupByAccum<VERTEX v_from, VERTEX v_to, SumAccum<INT> flow> @@residual_gb;
56+
MinAccum<INT> @@path_flow;
57+
58+
// variables used in bfs
59+
ListAccum<VERTEX> @path_list;
60+
OrAccum @end_point;
61+
ListAccum<ListAccum<VERTEX>> @@total_path_list;
62+
ListAccum<VERTEX> @@shortest_path;
63+
64+
// variables used for printing
65+
FILE f(file_path);
5666
SetAccum<EDGE> @@edges_set;
57-
HeapAccum<tb_node>(1, prev_flow DESC) @trace_back_heap;
58-
59-
MaxAccum<FLOAT> @@max_cap_threshold;
60-
SumAccum<FLOAT> @@sum_max_flow = 0;
61-
MinAccum<FLOAT> @@min_flow;
62-
OrAccum @or_is_visited, @@or_is_found;
63-
BOOL minimum_reached = FALSE;
64-
FILE f(file_path);
65-
@@max_cap_threshold = min_flow_threshhold;
66-
67-
IF cap_type NOT IN ("UINT", "INT", "FLOAT", "DOUBLE") THEN
68-
PRINT "weight_type must be UINT, INT, FLOAT, or DOUBLE" AS errMsg;
69-
RETURN;
70-
END;
71-
72-
##### Initialize #####
73-
init = {v_type};
74-
init = SELECT s
75-
FROM init:s - (e_type_set:e) - v_type:t
76-
ACCUM
77-
FLOAT fl = 0,
78-
CASE cap_type
79-
WHEN "UINT" THEN
80-
fl = e.getAttr(cap_attr, "UINT")
81-
WHEN "INT" THEN
82-
fl = e.getAttr(cap_attr, "INT")
83-
WHEN "FLOAT" THEN
84-
fl = e.getAttr(cap_attr, "FLOAT")
85-
WHEN "DOUBLE" THEN
86-
fl = e.getAttr(cap_attr, "DOUBLE")
87-
END,
88-
@@group_by_flow_accum += (s, t -> 0),
89-
IF s == source THEN
90-
@@max_cap_threshold += fl
91-
END;
92-
93-
//used for determining minimum flow of path, s.t. minimum flow > cap_threshold
94-
@@max_cap_threshold = pow(3, float_to_int(log(@@max_cap_threshold)/log(3)));
95-
96-
##### Push one flow at a time until there is residudal graph is disconnected #####
67+
SumAccum<INT> @@sum_max_flow = 0;
68+
69+
// initialize flow and residual graph
70+
all_vertices = {v_type};
71+
all_vertices = SELECT s
72+
FROM all_vertices:s - (e_type_set:e) - v_type:t
73+
ACCUM
74+
FLOAT cap = 0,
75+
CASE cap_type
76+
WHEN "UINT" THEN
77+
cap = e.getAttr(cap_attr, "UINT")
78+
WHEN "INT" THEN
79+
cap = e.getAttr(cap_attr, "INT")
80+
WHEN "FLOAT" THEN
81+
cap = e.getAttr(cap_attr, "FLOAT")
82+
WHEN "DOUBLE" THEN
83+
cap = e.getAttr(cap_attr, "DOUBLE")
84+
END,
85+
@@flow_gb += (s, t -> 0),
86+
@@residual_gb += (s, t -> cap)
87+
;
88+
89+
// mark the target node as true
90+
endset = {sink};
91+
endset = SELECT s
92+
From endset:s
93+
ACCUM s.@end_point = true;
94+
9795
WHILE TRUE DO
98-
//initilize
99-
100-
init = SELECT s
101-
FROM init:s
102-
POST-ACCUM s.@or_is_visited = FALSE,
103-
s.@trace_back_heap = tb_node(GSQL_INT_MIN, FALSE, source);
104-
105-
start = {source};
106-
start = SELECT s
107-
FROM start:s
108-
POST-ACCUM s.@or_is_visited = TRUE;
109-
110-
@@or_is_found = False;
111-
112-
//BFS to find feasible path from source -> sink
113-
WHILE NOT @@or_is_found AND start.size() > 0 DO
114-
forwd = SELECT t
115-
FROM start:s - (e_type_set:e) - v_type:t
116-
WHERE NOT t.@or_is_visited
117-
ACCUM
118-
FLOAT fl = 0,
119-
CASE cap_type
120-
WHEN "UINT" THEN
121-
fl = e.getAttr(cap_attr, "UINT")
122-
WHEN "INT" THEN
123-
fl = e.getAttr(cap_attr, "INT")
124-
WHEN "FLOAT" THEN
125-
fl = e.getAttr(cap_attr, "FLOAT")
126-
WHEN "DOUBLE" THEN
127-
fl = e.getAttr(cap_attr, "DOUBLE")
128-
END,
129-
IF fl - @@group_by_flow_accum.get(s, t).flow >= @@max_cap_threshold THEN
130-
t.@trace_back_heap += tb_node(fl - @@group_by_flow_accum.get(s, t).flow, TRUE, s),
131-
t.@or_is_visited += TRUE,
132-
@@or_is_found += t == sink
133-
END
134-
HAVING t.@or_is_visited;
135-
136-
bacwd = SELECT t
137-
FROM start:s - (reverse_e_type_set) - v_type:t
138-
WHERE NOT t.@or_is_visited
139-
ACCUM
140-
IF @@group_by_flow_accum.get(t, s).flow >= @@max_cap_threshold THEN
141-
t.@trace_back_heap += tb_node(@@group_by_flow_accum.get(t, s).flow, FALSE, s),
142-
t.@or_is_visited += TRUE,
143-
@@or_is_found += t == sink
144-
END
145-
HAVING t.@or_is_visited;
146-
147-
start = forwd UNION bacwd;
148-
END;
149-
150-
//done when residual graph is disconnected
151-
IF NOT @@or_is_found AND minimum_reached THEN
152-
BREAK;
153-
END;
154-
155-
//reduce cap_threshold to look for more path options
156-
IF NOT @@or_is_found THEN
157-
@@max_cap_threshold = float_to_int(@@max_cap_threshold/3);
158-
IF @@max_cap_threshold < min_flow_threshhold THEN
159-
@@max_cap_threshold = min_flow_threshhold;
160-
minimum_reached = TRUE;
161-
END;
162-
163-
CONTINUE;
164-
END;
165-
166-
//find bottleneck
167-
@@curr_set.clear();
168-
@@curr_set += sink;
169-
@@min_flow = GSQL_INT_MAX;
96+
// Run BFS, starting from the initial node
97+
SourceSet = {source};
98+
SourceSet = SELECT s
99+
FROM SourceSet:s
100+
ACCUM s.@path_list = [s];
170101

171-
WHILE NOT @@curr_set.contains(source) DO
172-
start = @@curr_set;
173-
@@curr_set.clear();
174-
start = SELECT s
175-
FROM start:s
176-
POST-ACCUM @@min_flow += s.@trace_back_heap.top().prev_flow,
177-
@@curr_set += s.@trace_back_heap.top().prev;
178-
179-
END;
180-
181-
@@sum_max_flow += @@min_flow;
182-
183-
//traceback to source and update flow vertices
184-
@@curr_set.clear();
185-
@@curr_set += sink;
186-
WHILE NOT @@curr_set.contains(source) DO
187-
start = @@curr_set;
188-
@@curr_set.clear();
189-
start = SELECT s
190-
FROM start:s
191-
POST-ACCUM
192-
@@curr_set += s.@trace_back_heap.top().prev,
193-
CASE
194-
WHEN s.@trace_back_heap.top().is_forward THEN
195-
@@group_by_flow_accum += (s.@trace_back_heap.top().prev, s -> @@min_flow)
196-
ELSE
197-
@@group_by_flow_accum += (s, s.@trace_back_heap.top().prev -> -@@min_flow)
198-
END;
199-
END;
200-
END;
102+
WHILE SourceSet.size() > 0 and @@total_path_list.size() == 0 DO
103+
SourceSet = SELECT t
104+
FROM SourceSet:s -((reverse_e_type_set|e_type_set):e)- :t
105+
WHERE @@residual_gb.get(s, t).flow > min_flow_threshhold AND t.@path_list.size() == 0
106+
ACCUM
107+
// choose any path for tie-breaking
108+
IF s.@path_list.size() > 0 THEN
109+
IF t.@end_point == true THEN
110+
@@total_path_list += s.@path_list + [t]
111+
ELSE
112+
t.@path_list = s.@path_list + [t]
113+
END
114+
END
115+
;
116+
END; // end of BFS
117+
118+
// if no augmenting path is found, break because we've reached max flow
119+
if @@total_path_list.size() == 0 THEN
120+
break;
121+
END;
122+
123+
@@shortest_path = @@total_path_list.get(0);
124+
125+
@@path_flow = GSQL_INT_MAX;
201126

127+
// see how much flow we can send
128+
FOREACH i IN RANGE[0, @@shortest_path.size() - 2] DO
129+
@@path_flow += @@residual_gb.get(@@shortest_path.get(i), @@shortest_path.get(i + 1)).flow;
130+
END;
131+
132+
// send the flow and update the residual graph
133+
FOREACH i IN RANGE[0, @@shortest_path.size() - 2] DO
134+
@@flow_gb += (@@shortest_path.get(i), @@shortest_path.get(i + 1) -> @@path_flow);
135+
@@residual_gb += (@@shortest_path.get(i + 1), @@shortest_path.get(i) -> @@path_flow);
136+
@@residual_gb += (@@shortest_path.get(i), @@shortest_path.get(i + 1) -> -@@path_flow);
137+
END;
138+
@@sum_max_flow += @@path_flow;
139+
140+
// reset and clear
141+
@@total_path_list.clear();
142+
all_vertices = SELECT s
143+
FROM all_vertices:s
144+
POST-ACCUM s.@path_list.clear();
145+
146+
END;
147+
202148
##### Output #####
203149
IF file_path != "" THEN
204150
f.println("Maxflow: " + to_string(@@sum_max_flow));
205151
f.println("From","To","Flow");
206-
END;
207-
start = {source};
208-
WHILE start.size() != 0 DO
209-
start = SELECT t
210-
FROM start:s - (e_type_set:e) - v_type:t
211-
WHERE @@group_by_flow_accum.get(s,t).flow >= min_flow_threshhold
152+
END;
153+
SourceSet = {source};
154+
WHILE SourceSet.size() != 0 DO
155+
SourceSet = SELECT t
156+
FROM SourceSet:s - (e_type_set:e) - v_type:t
157+
WHERE @@flow_gb.get(s,t).flow > min_flow_threshhold
212158
ACCUM
213159
IF print_results THEN
214-
@@edges_set += e
215-
END,
160+
@@edges_set += e
161+
END,
216162
IF file_path != "" THEN
217-
f.println(s, t, @@group_by_flow_accum.get(s,t).flow)
218-
END;
163+
f.println(s, t, @@flow_gb.get(s,t).flow)
164+
END
165+
;
219166
END;
220-
167+
221168
IF print_results THEN
222169
PRINT @@sum_max_flow;
223170
IF display_edges THEN
224171
PRINT @@edges_set;
225172
END;
226173
END;
227-
}
174+
}

0 commit comments

Comments
 (0)