|
| 1 | +package closuretab |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "database/sql" |
| 6 | + "fmt" |
| 7 | + "strings" |
| 8 | +) |
| 9 | + |
| 10 | +type AttrType int |
| 11 | + |
| 12 | +const ( |
| 13 | + Child AttrType = iota |
| 14 | + Parent |
| 15 | + Depth |
| 16 | +) |
| 17 | + |
| 18 | +type AttrMapping = map[AttrType]string |
| 19 | + |
| 20 | +type Querier interface { |
| 21 | + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row |
| 22 | + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) |
| 23 | + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) |
| 24 | +} |
| 25 | + |
| 26 | +type Node struct { |
| 27 | + ID int64 |
| 28 | + ParentID int64 |
| 29 | + Depth int |
| 30 | +} |
| 31 | + |
| 32 | +type ClosureRelation struct { |
| 33 | + table string |
| 34 | + attrs map[AttrType]string |
| 35 | +} |
| 36 | + |
| 37 | +func InitClosureRelation(tableName string, attrs AttrMapping) *ClosureRelation { |
| 38 | + return &ClosureRelation{table: tableName, attrs: attrs} |
| 39 | +} |
| 40 | + |
| 41 | +func (r *ClosureRelation) GetChildren(ctx context.Context, q Querier, parentID int64) ([]Node, error) { |
| 42 | + rows, err := q.QueryContext( |
| 43 | + ctx, |
| 44 | + fmt.Sprintf( |
| 45 | + "SELECT %s, %s, %s FROM %s WHERE %s = ? ORDER BY %s ASC", |
| 46 | + r.attrs[Child], r.attrs[Parent], r.attrs[Depth], r.table, |
| 47 | + r.attrs[Parent], r.attrs[Depth], |
| 48 | + ), |
| 49 | + parentID, |
| 50 | + ) |
| 51 | + if err != nil { |
| 52 | + return nil, fmt.Errorf("get child nodes for parent ID %d: %w", parentID, err) |
| 53 | + } |
| 54 | + return scanNodes(rows) |
| 55 | +} |
| 56 | + |
| 57 | +func (r *ClosureRelation) GetParents(ctx context.Context, q Querier, nodeID int64) ([]Node, error) { |
| 58 | + rows, err := q.QueryContext( |
| 59 | + ctx, |
| 60 | + fmt.Sprintf( |
| 61 | + "SELECT %s, %s, %s FROM %s WHERE %s = ? AND %s != ? ORDER BY %s DESC", |
| 62 | + r.attrs[Parent], r.attrs[Parent], r.attrs[Depth], r.table, |
| 63 | + r.attrs[Child], r.attrs[Parent], |
| 64 | + r.attrs[Depth], |
| 65 | + ), |
| 66 | + nodeID, nodeID, |
| 67 | + ) |
| 68 | + if err != nil { |
| 69 | + return nil, fmt.Errorf("get parent nodes for node ID %d: %w", nodeID, err) |
| 70 | + } |
| 71 | + return scanNodes(rows) |
| 72 | +} |
| 73 | + |
| 74 | +func (r *ClosureRelation) Insert(ctx context.Context, q Querier, parentID, nodeID int64) (Node, error) { |
| 75 | + _, err := q.ExecContext( |
| 76 | + ctx, |
| 77 | + fmt.Sprintf( |
| 78 | + "INSERT INTO %s (%s, %s, %s) "+ |
| 79 | + "SELECT ?, %s, %s + 1 FROM %s WHERE %s = ?", |
| 80 | + r.table, r.attrs[Child], r.attrs[Parent], r.attrs[Depth], |
| 81 | + r.attrs[Parent], r.attrs[Depth], r.table, r.attrs[Child], |
| 82 | + ), |
| 83 | + nodeID, parentID, |
| 84 | + ) |
| 85 | + if err != nil { |
| 86 | + return Node{}, fmt.Errorf("insert hierarchy references: %w", err) |
| 87 | + } |
| 88 | + |
| 89 | + _, err = q.ExecContext( |
| 90 | + ctx, |
| 91 | + fmt.Sprintf( |
| 92 | + "INSERT INTO %s (%s, %s, %s) VALUES (?, ?, ?)", |
| 93 | + r.table, r.attrs[Child], r.attrs[Parent], r.attrs[Depth], |
| 94 | + ), |
| 95 | + nodeID, nodeID, 0, |
| 96 | + ) |
| 97 | + if err != nil { |
| 98 | + return Node{}, fmt.Errorf("insert self-reference: %w", err) |
| 99 | + } |
| 100 | + |
| 101 | + return Node{}, nil |
| 102 | +} |
| 103 | + |
| 104 | +func (r *ClosureRelation) Delete(ctx context.Context, q Querier, nodeID int64) error { |
| 105 | + if _, err := q.ExecContext( |
| 106 | + ctx, |
| 107 | + fmt.Sprintf( |
| 108 | + "DELETE FROM %s WHERE %s IN (SELECT %s FROM %s WHERE %s = ?)", |
| 109 | + r.table, r.attrs[Child], r.attrs[Child], r.table, r.attrs[Parent], |
| 110 | + ), |
| 111 | + nodeID, |
| 112 | + ); err != nil { |
| 113 | + return fmt.Errorf("remove node ID %d: %w", nodeID, err) |
| 114 | + } |
| 115 | + |
| 116 | + if _, err := q.ExecContext( |
| 117 | + ctx, |
| 118 | + fmt.Sprintf( |
| 119 | + "DELETE FROM %s WHERE %s = ? OR %s = ?", |
| 120 | + r.table, r.attrs[Child], r.attrs[Parent], |
| 121 | + ), |
| 122 | + nodeID, nodeID, |
| 123 | + ); err != nil { |
| 124 | + return fmt.Errorf("remove node ID %d: %w", nodeID, err) |
| 125 | + } |
| 126 | + return nil |
| 127 | +} |
| 128 | + |
| 129 | +func (r *ClosureRelation) Move(ctx context.Context, q Querier, nodeID, newParentID int64) error { |
| 130 | + if _, deleteErr := q.ExecContext( |
| 131 | + ctx, |
| 132 | + fmt.Sprintf( |
| 133 | + "DELETE FROM %s "+ |
| 134 | + "WHERE %s IN "+ |
| 135 | + "(SELECT %s FROM %s WHERE %s = ?) "+ |
| 136 | + "AND %s IN "+ |
| 137 | + "(SELECT %s FROM %s WHERE %s = ? AND %s != %s) ", |
| 138 | + r.table, |
| 139 | + r.attrs[Child], |
| 140 | + r.attrs[Child], r.table, r.attrs[Parent], |
| 141 | + r.attrs[Parent], |
| 142 | + r.attrs[Parent], r.table, r.attrs[Child], r.attrs[Parent], r.attrs[Child], |
| 143 | + ), |
| 144 | + nodeID, nodeID, |
| 145 | + ); deleteErr != nil { |
| 146 | + return fmt.Errorf("remove node ID %d: %w", nodeID, deleteErr) |
| 147 | + } |
| 148 | + |
| 149 | + parents, err := r.GetParents(ctx, q, newParentID) |
| 150 | + if err != nil { |
| 151 | + return fmt.Errorf("get new parents for moved nodes: %w", err) |
| 152 | + } |
| 153 | + parentIDs := NodeIDs(parents) |
| 154 | + parentIDs = append(parentIDs, newParentID) |
| 155 | + parentIDsPlaceholders := makePlaceholders("?", len(parentIDs)) |
| 156 | + |
| 157 | + children, err := r.GetChildren(ctx, q, nodeID) |
| 158 | + if err != nil { |
| 159 | + return fmt.Errorf("get all nodes being moved: %w", err) |
| 160 | + } |
| 161 | + childrenIDs := NodeIDs(children) |
| 162 | + childrenIDsPlaceholders := makePlaceholders("?", len(childrenIDs)) |
| 163 | + |
| 164 | + args := make([]interface{}, len(parentIDs)+len(childrenIDs)) |
| 165 | + for i := 0; i < len(args); i++ { |
| 166 | + if i < len(parentIDs) { |
| 167 | + args[i] = parentIDs[i] |
| 168 | + } else { |
| 169 | + args[i] = childrenIDs[i-len(parentIDs)] |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + query := fmt.Sprintf( |
| 174 | + `INSERT INTO %s (%s, %s, %s) |
| 175 | + SELECT supertree.%s, subtree.%s, MAX(supertree.%s + subtree.%s + 1) |
| 176 | + FROM %s AS supertree, %s AS subtree |
| 177 | + WHERE |
| 178 | + supertree.%s IN %s |
| 179 | + AND subtree.%s IN %s |
| 180 | + GROUP BY supertree.%s, subtree.%s`, |
| 181 | + r.table, r.attrs[Parent], r.attrs[Child], r.attrs[Depth], |
| 182 | + r.attrs[Parent], r.attrs[Child], r.attrs[Depth], r.attrs[Depth], |
| 183 | + r.table, r.table, |
| 184 | + r.attrs[Parent], parentIDsPlaceholders, |
| 185 | + r.attrs[Child], childrenIDsPlaceholders, |
| 186 | + r.attrs[Parent], r.attrs[Child], |
| 187 | + ) |
| 188 | + |
| 189 | + if _, insertErr := q.ExecContext( |
| 190 | + ctx, |
| 191 | + query, |
| 192 | + args..., |
| 193 | + ); insertErr != nil { |
| 194 | + return fmt.Errorf("insert nodes under new parent: %w", insertErr) |
| 195 | + } |
| 196 | + |
| 197 | + return nil |
| 198 | +} |
| 199 | + |
| 200 | +func (r *ClosureRelation) Empty(ctx context.Context, q Querier) (bool, error) { |
| 201 | + row := q.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM %s", r.table)) |
| 202 | + var cnt int |
| 203 | + if err := row.Scan(&cnt); err != nil { |
| 204 | + return false, fmt.Errorf("count closure table rows: %w", err) |
| 205 | + } |
| 206 | + return cnt == 0, nil |
| 207 | +} |
| 208 | + |
| 209 | +func NodeIDs(nodes []Node) []int64 { |
| 210 | + res := make([]int64, 0, len(nodes)) |
| 211 | + for i := range nodes { |
| 212 | + res = append(res, nodes[i].ID) |
| 213 | + } |
| 214 | + return res |
| 215 | +} |
| 216 | + |
| 217 | +type scanner interface { |
| 218 | + Scan(dest ...interface{}) error |
| 219 | +} |
| 220 | + |
| 221 | +func scanNodes(rows *sql.Rows) ([]Node, error) { |
| 222 | + result := make([]Node, 0) |
| 223 | + if _, scanErr := scanEachRow(rows, func(s scanner) error { |
| 224 | + n := Node{} |
| 225 | + if rowErr := s.Scan(&n.ID, &n.ParentID, &n.Depth); rowErr != nil { |
| 226 | + return rowErr |
| 227 | + } |
| 228 | + result = append(result, n) |
| 229 | + return nil |
| 230 | + }); scanErr != nil { |
| 231 | + return nil, fmt.Errorf("scan nodes: %w", scanErr) |
| 232 | + } |
| 233 | + return result, nil |
| 234 | +} |
| 235 | + |
| 236 | +func scanEachRow(rows *sql.Rows, scanRow func(s scanner) error) (rowsProcessed int, err error) { |
| 237 | + defer func() { _ = rows.Close() }() |
| 238 | + count := 0 |
| 239 | + for rows.Next() { |
| 240 | + err = scanRow(rows) |
| 241 | + if err != nil { |
| 242 | + return 0, fmt.Errorf("scan row: %w", err) |
| 243 | + } |
| 244 | + count++ |
| 245 | + } |
| 246 | + if err = rows.Err(); err != nil { |
| 247 | + if err == sql.ErrNoRows { |
| 248 | + return 0, nil |
| 249 | + } |
| 250 | + return 0, fmt.Errorf("rows scan: %w", err) |
| 251 | + } |
| 252 | + return count, nil |
| 253 | +} |
| 254 | + |
| 255 | +func makePlaceholders(pHolder string, argLen int) string { |
| 256 | + pHolders := make([]string, argLen) |
| 257 | + for i := 0; i < argLen; i++ { |
| 258 | + pHolders[i] = pHolder |
| 259 | + } |
| 260 | + return fmt.Sprintf("(%s)", strings.Join(pHolders, ", ")) |
| 261 | +} |
0 commit comments