diff --git a/nx_arangodb/classes/dict/adj.py b/nx_arangodb/classes/dict/adj.py index 63c5b5bd..9bb9a5e3 100644 --- a/nx_arangodb/classes/dict/adj.py +++ b/nx_arangodb/classes/dict/adj.py @@ -1557,7 +1557,7 @@ def keys(self) -> Any: @logger_debug def clear(self) -> None: - """g._node.clear()""" + """g._adj.clear()""" self.data.clear() self.FETCHED_ALL_DATA = False self.FETCHED_ALL_IDS = False diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index d1931963..7acdda1c 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -5,9 +5,11 @@ import nx_arangodb as nxadb from nx_arangodb.classes.graph import Graph +from nx_arangodb.logger import logger from .dict.adj import AdjListOuterDict from .enum import TraversalDirection +from .function import get_node_id networkx_api = nxadb.utils.decorators.networkx_class(nx.DiGraph) # type: ignore @@ -57,12 +59,9 @@ def __init__( ) if self.graph_exists_in_db: - assert isinstance(self._succ, AdjListOuterDict) - assert isinstance(self._pred, AdjListOuterDict) - self._succ.mirror = self._pred - self._pred.mirror = self._succ - self._succ.traversal_direction = TraversalDirection.OUTBOUND - self._pred.traversal_direction = TraversalDirection.INBOUND + self.clear_edges = self.clear_edges_override + self.add_node = self.add_node_override + self.remove_node = self.remove_node_override ####################### # nx.DiGraph Overides # @@ -78,7 +77,14 @@ def __init__( # def out_edges(self): # pass - def add_node(self, node_for_adding, **attr): + def clear_edges_override(self): + logger.info("Note that clearing edges ony erases the edges in the local cache") + for predecessor_dict in self._pred.data.values(): + predecessor_dict.clear() + + super().clear_edges() + + def add_node_override(self, node_for_adding, **attr): if node_for_adding not in self._succ: if node_for_adding is None: raise ValueError("None cannot be a node") @@ -109,7 +115,10 @@ def add_node(self, node_for_adding, **attr): nx._clear_cache(self) - def remove_node(self, n): + def remove_node_override(self, n): + if isinstance(n, (str, int)): + n = get_node_id(str(n), self.default_node_type) + try: ###################### @@ -136,6 +145,22 @@ def remove_node(self, n): del self._pred[u][n] # remove all edges n-u in digraph del self._succ[n] # remove node from succ for u in nbrs_pred: + ###################### + # NOTE: Monkey patch # + ###################### + + # Old: Nothing + + # New: + if u == n: + continue # skip self loops + + # Reason: We need to skip self loops, as they are + # already taken care of in the previous step. This + # avoids getting a KeyError on the next line. + + ########################### + del self._succ[u][n] # remove all edges n-u in digraph del self._pred[n] # remove node from pred nx._clear_cache(self) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 23ab36ac..bc4d84d8 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -29,6 +29,8 @@ node_attr_dict_factory, node_dict_factory, ) +from .dict.adj import AdjListOuterDict +from .enum import TraversalDirection from .function import get_node_id from .reportviews import CustomEdgeView, CustomNodeView @@ -97,6 +99,7 @@ def __init__( # m = "Must set **graph_name** if passing **incoming_graph_data**" # raise ValueError(m) + loaded_incoming_graph_data = False if self._graph_exists_in_db: if incoming_graph_data is not None: m = "Cannot pass both **incoming_graph_data** and **graph_name** yet if the already graph exists" # noqa: E501 @@ -171,29 +174,44 @@ def edge_type_func(u: str, v: str) -> str: use_async=True, ) + loaded_incoming_graph_data = True + else: self.adb_graph = self.db.create_graph( self.__name, edge_definitions=edge_definitions, ) - # Let the parent class handle the incoming graph data - # if it is not a networkx.Graph object - kwargs["incoming_graph_data"] = incoming_graph_data - self._set_factory_methods() self._set_arangodb_backend_config() logger.info(f"Graph '{name}' created.") self._graph_exists_in_db = True - else: - kwargs["incoming_graph_data"] = incoming_graph_data - - if name is not None: - kwargs["name"] = name + if self.__name is not None: + kwargs["name"] = self.__name super().__init__(*args, **kwargs) + if self.is_directed() and self.graph_exists_in_db: + assert isinstance(self._succ, AdjListOuterDict) + assert isinstance(self._pred, AdjListOuterDict) + self._succ.mirror = self._pred + self._pred.mirror = self._succ + self._succ.traversal_direction = TraversalDirection.OUTBOUND + self._pred.traversal_direction = TraversalDirection.INBOUND + + if incoming_graph_data is not None and not loaded_incoming_graph_data: + nx.convert.to_networkx_graph(incoming_graph_data, create_using=self) + + if self.graph_exists_in_db: + self.copy = self.copy_override + self.subgraph = self.subgraph_override + self.clear = self.clear_override + self.clear_edges = self.clear_edges_override + self.add_node = self.add_node_override + self.number_of_edges = self.number_of_edges_override + self.nbunch_iter = self.nbunch_iter_override + ####################### # Init helper methods # ####################### @@ -345,6 +363,9 @@ def _set_graph_name(self, graph_name: str | None = None) -> None: # ArangoDB Methods # #################### + def clear_nxcg_cache(self): + self.nxcg_graph = None + def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor: return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) @@ -355,7 +376,7 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs # NOTE: OUT OF SERVICE # def chat(self, prompt: str) -> str: # if self.__qa_chain is None: - # if not self.__graph_exists_in_db: + # if not self.graph_exists_in_db: # return "Could not initialize QA chain: Graph does not exist" # # try: @@ -381,30 +402,6 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs # nx.Graph Overides # ##################### - def copy(self, *args, **kwargs): - logger.warning("Note that copying a graph loses the connection to the database") - G = super().copy(*args, **kwargs) - G.node_dict_factory = nx.Graph.node_dict_factory - G.node_attr_dict_factory = nx.Graph.node_attr_dict_factory - G.edge_attr_dict_factory = nx.Graph.edge_attr_dict_factory - G.adjlist_inner_dict_factory = nx.Graph.adjlist_inner_dict_factory - G.adjlist_outer_dict_factory = nx.Graph.adjlist_outer_dict_factory - return G - - def subgraph(self, nbunch): - raise NotImplementedError("Subgraphing is not yet implemented") - - def clear(self): - logger.info("Note that clearing only erases the local cache") - super().clear() - - def clear_edges(self): - logger.info("Note that clearing edges ony erases the edges in the local cache") - super().clear_edges() - - def clear_nxcg_cache(self): - self.nxcg_graph = None - @cached_property def nodes(self): if self.__use_experimental_views and self.graph_exists_in_db: @@ -437,7 +434,30 @@ def edges(self): return super().edges - def add_node(self, node_for_adding, **attr): + def copy_override(self, *args, **kwargs): + logger.warning("Note that copying a graph loses the connection to the database") + G = super().copy(*args, **kwargs) + G.node_dict_factory = nx.Graph.node_dict_factory + G.node_attr_dict_factory = nx.Graph.node_attr_dict_factory + G.edge_attr_dict_factory = nx.Graph.edge_attr_dict_factory + G.adjlist_inner_dict_factory = nx.Graph.adjlist_inner_dict_factory + G.adjlist_outer_dict_factory = nx.Graph.adjlist_outer_dict_factory + return G + + def subgraph_override(self, nbunch): + raise NotImplementedError("Subgraphing is not yet implemented") + + def clear_override(self): + logger.info("Note that clearing only erases the local cache") + super().clear() + + def clear_edges_override(self): + logger.info("Note that clearing edges ony erases the edges in the local cache") + for nbr_dict in self._adj.data.values(): + nbr_dict.clear() + nx._clear_cache(self) + + def add_node_override(self, node_for_adding, **attr): if node_for_adding not in self._node: if node_for_adding is None: raise ValueError("None cannot be a node") @@ -467,10 +487,7 @@ def add_node(self, node_for_adding, **attr): nx._clear_cache(self) - def number_of_edges(self, u=None, v=None): - if not self.graph_exists_in_db: - return super().number_of_edges(u, v) - + def number_of_edges_override(self, u=None, v=None): if u is not None: return super().number_of_edges(u, v) @@ -494,10 +511,7 @@ def number_of_edges(self, u=None, v=None): # It is more efficient to count the number of edges in the edge collections # compared to relying on the DegreeView. - def nbunch_iter(self, nbunch=None): - if not self._graph_exists_in_db: - return super().nbunch_iter(nbunch) - + def nbunch_iter_override(self, nbunch=None): if nbunch is None: bunch = iter(self._adj) elif nbunch in self: @@ -508,12 +522,13 @@ def nbunch_iter(self, nbunch=None): # Old: Nothing # New: - if isinstance(nbunch, int): + if isinstance(nbunch, (str, int)): nbunch = get_node_id(str(nbunch), self.default_node_type) # Reason: # ArangoDB only uses strings as node IDs. Therefore, we need to convert - # the integer node ID to a string before using it in an iterator. + # the non-prefixed node ID to an ArangoDB ID before + # using it in an iterator. bunch = iter([nbunch]) else: @@ -528,13 +543,15 @@ def bunch_iter(nlist, adj): # Old: Nothing # New: - if isinstance(n, int): + if isinstance(n, (str, int)): n = get_node_id(str(n), self.default_node_type) # Reason: # ArangoDB only uses strings as node IDs. Therefore, - # we need to convert the integer node ID to a - # string before using it in an iterator. + # we need to convert non-prefixed node IDs to an + # ArangoDB ID before using it in an iterator. + + ###################### if n in adj: yield n diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index 4efa205b..a0419143 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -56,6 +56,9 @@ def __init__( **kwargs, ) + if self.graph_exists_in_db: + self.add_edge = self.add_edge_override + ####################### # Init helper methods # ####################### @@ -74,10 +77,7 @@ def _set_factory_methods(self) -> None: # nx.MultiGraph Overides # ########################## - def add_edge(self, u_for_edge, v_for_edge, key=None, **attr): - if not self.graph_exists_in_db: - return super().add_edge(u_for_edge, v_for_edge, key=key, **attr) - + def add_edge_override(self, u_for_edge, v_for_edge, key=None, **attr): if key is not None: m = "ArangoDB MultiGraph does not support custom edge keys yet." logger.warning(m) diff --git a/tests/test.py b/tests/test.py index 1b37f654..c364ee76 100644 --- a/tests/test.py +++ b/tests/test.py @@ -18,6 +18,9 @@ from .conftest import create_line_graph, db G_NX = nx.karate_club_graph() +G_NX_digraph = nx.DiGraph(G_NX) +G_NX_multigraph = nx.MultiGraph(G_NX) +G_NX_multidigraph = nx.MultiDiGraph(G_NX) def assert_remote_dict(G: nxadb.Graph) -> None: @@ -1777,3 +1780,58 @@ def test_incoming_graph_data_not_nx_graph( ) assert has_club == ("club" in G.nodes["0"]) assert has_weight == ("weight" in G.adj["0"]["1"]) + + +@pytest.mark.parametrize( + "data_type, incoming_graph_data, has_club, has_weight", + [ + ("dict of dicts", G_NX_digraph._adj, False, True), + ( + "dict of lists", + {k: list(v) for k, v in G_NX_digraph._adj.items()}, + False, + False, + ), + ("container of edges", list(G_NX_digraph.edges), False, False), + ("iterator of edges", iter(G_NX_digraph.edges), False, False), + ("generator of edges", (e for e in G_NX_digraph.edges), False, False), + ("2D numpy array", nx.to_numpy_array(G_NX_digraph), False, True), + ( + "scipy sparse array", + nx.to_scipy_sparse_array(G_NX_digraph), + False, + True, + ), + ("Pandas EdgeList", nx.to_pandas_edgelist(G_NX_digraph), False, True), + ("Pandas Adjacency", nx.to_pandas_adjacency(G_NX_digraph), False, True), + ], +) +def test_incoming_graph_data_not_nx_graph_digraph( + data_type: str, incoming_graph_data: Any, has_club: bool, has_weight: bool +) -> None: + # See nx.convert.to_networkx_graph for the official supported types + name = "KarateGraph" + db.delete_graph(name, drop_collections=True, ignore_missing=True) + + G = nxadb.DiGraph(incoming_graph_data=incoming_graph_data, name=name) + + assert ( + len(G.adj) + == len(G_NX_digraph.adj) + == db.collection(G.default_node_type).count() + ) + assert ( + len(G.nodes) + == len(G_NX_digraph.nodes) + == db.collection(G.default_node_type).count() + == G.number_of_nodes() + ) + edge_col = G.edge_type_func(G.default_node_type, G.default_node_type) + assert ( + len(G.edges) + == len(G_NX_digraph.edges) + == db.collection(edge_col).count() + == G.number_of_edges() + ) + assert has_club == ("club" in G.nodes["0"]) + assert has_weight == ("weight" in G.adj["0"]["1"]) diff --git a/tests/test_digraph.py b/tests/test_digraph.py new file mode 100644 index 00000000..70326218 --- /dev/null +++ b/tests/test_digraph.py @@ -0,0 +1,515 @@ +# type: ignore + +import time + +import networkx as nx +import pytest +from networkx.utils import nodes_equal + +import nx_arangodb as nxadb +from nx_arangodb.classes.dict.adj import ( + AdjListInnerDict, + AdjListOuterDict, + EdgeAttrDict, +) +from nx_arangodb.classes.dict.graph import GraphDict +from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict + +from .conftest import db + +# from .test_graph import TestEdgeSubgraph as _TestGraphEdgeSubgraph +from .test_graph import BaseAttrGraphTester, BaseGraphTester +from .test_graph import TestGraph as _TestGraph +from .test_graph import get_doc + +GRAPH_NAME = "test_graph" + + +class BaseDiGraphTester(BaseGraphTester): + def test_has_successor(self): + G = self.K3Graph() + assert G.has_successor(0, 1) + assert not G.has_successor(0, -1) + + def test_successors(self): + G = self.K3Graph() + assert sorted(G.successors(0)) == ["test_graph_node/1", "test_graph_node/2"] + with pytest.raises(nx.NetworkXError): + G.successors(-1) + + def test_has_predecessor(self): + G = self.K3Graph() + assert G.has_predecessor(0, 1) + assert not G.has_predecessor(0, -1) + + def test_predecessors(self): + G = self.K3Graph() + assert sorted(G.predecessors(0)) == [ + "test_graph_node/1", + "test_graph_node/2", + ] + with pytest.raises(nx.NetworkXError): + G.predecessors(-1) + + def test_edges(self): + G = self.K3Graph() + assert sorted(G.edges()) == [ + ("test_graph_node/0", "test_graph_node/1"), + ("test_graph_node/0", "test_graph_node/2"), + ("test_graph_node/1", "test_graph_node/0"), + ("test_graph_node/1", "test_graph_node/2"), + ("test_graph_node/2", "test_graph_node/0"), + ("test_graph_node/2", "test_graph_node/1"), + ] + assert sorted(G.edges(0)) == [ + ("test_graph_node/0", "test_graph_node/1"), + ("test_graph_node/0", "test_graph_node/2"), + ] + assert sorted(G.edges([0, 1])) == [ + ("test_graph_node/0", "test_graph_node/1"), + ("test_graph_node/0", "test_graph_node/2"), + ("test_graph_node/1", "test_graph_node/0"), + ("test_graph_node/1", "test_graph_node/2"), + ] + with pytest.raises(nx.NetworkXError): + G.edges(-1) + + def test_out_edges(self): + G = self.K3Graph() + assert sorted(G.out_edges()) == [ + ("test_graph_node/0", "test_graph_node/1"), + ("test_graph_node/0", "test_graph_node/2"), + ("test_graph_node/1", "test_graph_node/0"), + ("test_graph_node/1", "test_graph_node/2"), + ("test_graph_node/2", "test_graph_node/0"), + ("test_graph_node/2", "test_graph_node/1"), + ] + + assert sorted(G.out_edges(0)) == [ + ("test_graph_node/0", "test_graph_node/1"), + ("test_graph_node/0", "test_graph_node/2"), + ] + with pytest.raises(nx.NetworkXError): + G.out_edges(-1) + + def test_out_edges_dir(self): + G = self.P3Graph() + assert sorted(G.out_edges()) == [ + ("test_graph_node/0", "test_graph_node/1"), + ("test_graph_node/1", "test_graph_node/2"), + ] + assert sorted(G.out_edges("test_graph_node/0")) == [ + ("test_graph_node/0", "test_graph_node/1") + ] + assert sorted(G.out_edges(2)) == [] + + def test_out_edges_data(self): + G = self.EmptyGraph(incoming_graph_data=[(0, 1, {"data": 0}), (1, 0, {})]) + edge_0_1 = get_doc(G[0][1]["_id"]) + edge_1_0 = get_doc(G[1][0]["_id"]) + assert "data" in edge_0_1 + assert edge_0_1["data"] == 0 + assert "data" not in edge_1_0 + assert sorted(G.out_edges(data=True)) == [ + ("test_graph_node/0", "test_graph_node/1", edge_0_1), + ("test_graph_node/1", "test_graph_node/0", edge_1_0), + ] + assert sorted(G.out_edges(0, data=True)) == [ + ("test_graph_node/0", "test_graph_node/1", edge_0_1) + ] + assert sorted(G.out_edges(data="data")) == [ + ("test_graph_node/0", "test_graph_node/1", 0), + ("test_graph_node/1", "test_graph_node/0", None), + ] + assert sorted(G.out_edges(0, data="data")) == [ + ("test_graph_node/0", "test_graph_node/1", 0) + ] + + def test_in_edges_dir(self): + G = self.P3Graph() + assert sorted(G.in_edges()) == [ + ("test_graph_node/1", "test_graph_node/0"), + ("test_graph_node/2", "test_graph_node/1"), + ] + assert sorted(G.in_edges(0)) == [("test_graph_node/1", "test_graph_node/0")] + assert sorted(G.in_edges(2)) == [] + + def test_in_edges_data(self): + G = nx.DiGraph([(0, 1, {"data": 0}), (1, 0, {})]) + assert sorted(G.in_edges(data=True)) == [(0, 1, {"data": 0}), (1, 0, {})] + assert sorted(G.in_edges(1, data=True)) == [(0, 1, {"data": 0})] + assert sorted(G.in_edges(data="data")) == [(0, 1, 0), (1, 0, None)] + assert sorted(G.in_edges(1, data="data")) == [(0, 1, 0)] + + def test_degree(self): + G = self.K3Graph() + assert sorted(G.degree()) == [ + ("test_graph_node/0", 4), + ("test_graph_node/1", 4), + ("test_graph_node/2", 4), + ] + assert dict(G.degree()) == { + "test_graph_node/0": 4, + "test_graph_node/1": 4, + "test_graph_node/2": 4, + } + assert G.degree(0) == 4 + assert list(G.degree(iter([0]))) == [ + ("test_graph_node/0", 4) + ] # run through iterator + + def test_in_degree(self): + G = self.K3Graph() + assert sorted(G.in_degree()) == [ + ("test_graph_node/0", 2), + ("test_graph_node/1", 2), + ("test_graph_node/2", 2), + ] + assert dict(G.in_degree()) == { + "test_graph_node/0": 2, + "test_graph_node/1": 2, + "test_graph_node/2": 2, + } + assert G.in_degree(0) == 2 + assert list(G.in_degree(iter([0]))) == [ + ("test_graph_node/0", 2) + ] # run through iterator + + def test_out_degree(self): + G = self.K3Graph() + assert sorted(G.out_degree()) == [ + ("test_graph_node/0", 2), + ("test_graph_node/1", 2), + ("test_graph_node/2", 2), + ] + assert dict(G.out_degree()) == { + "test_graph_node/0": 2, + "test_graph_node/1": 2, + "test_graph_node/2": 2, + } + assert G.out_degree(0) == 2 + assert list(G.out_degree(iter([0]))) == [("test_graph_node/0", 2)] + + def test_size(self): + G = self.K3Graph() + assert G.size() == 6 + assert G.number_of_edges() == 6 + + def test_to_undirected_reciprocal(self): + G = self.EmptyGraph() + G.add_edge(1, 2) + assert G.to_undirected().has_edge("test_graph_node/1", "test_graph_node/2") + assert not G.to_undirected(reciprocal=True).has_edge(1, 2) + G.add_edge(2, 1) + assert G.to_undirected(reciprocal=True).has_edge( + "test_graph_node/1", "test_graph_node/2" + ) + + def test_reverse_copy(self): + G = self.EmptyGraph(incoming_graph_data=[(0, 1), (1, 2)]) + R = G.reverse() + assert sorted(R.edges()) == [ + ("test_graph_node/1", "test_graph_node/0"), + ("test_graph_node/2", "test_graph_node/1"), + ] + R.remove_edge("test_graph_node/1", "test_graph_node/0") + assert sorted(R.edges()) == [("test_graph_node/2", "test_graph_node/1")] + assert sorted(G.edges()) == [ + ("test_graph_node/0", "test_graph_node/1"), + ("test_graph_node/1", "test_graph_node/2"), + ] + + def test_reverse_nocopy(self): + G = self.EmptyGraph(incoming_graph_data=[(0, 1), (1, 2)]) + R = G.reverse(copy=False) + assert R[1][0] + assert R[2][1] + assert R._pred[0][1] + assert R._pred[1][2] + with pytest.raises(KeyError): + R[0][1] + with pytest.raises(KeyError): + R[1][2] + with pytest.raises(KeyError): + R._pred[1][0] + with pytest.raises(KeyError): + R._pred[2][1] + with pytest.raises(nx.NetworkXError): + R.remove_edge(1, 0) + + def test_reverse_hashable(self): + pytest.skip("Class-based nodes are not supported in ArangoDB.") + + class Foo: + pass + + x = Foo() + y = Foo() + G = self.EmptyGraph() + G.add_edge(x, y) + assert nodes_equal(G.nodes(), G.reverse().nodes()) + assert [(y, x)] == list(G.reverse().edges()) + + def test_di_cache_reset(self): + G = self.K3Graph().copy() + old_succ = G.succ + assert id(G.succ) == id(old_succ) + old_adj = G.adj + assert id(G.adj) == id(old_adj) + + G._succ = {} + assert id(G.succ) != id(old_succ) + assert id(G.adj) != id(old_adj) + + old_pred = G.pred + assert id(G.pred) == id(old_pred) + G._pred = {} + assert id(G.pred) != id(old_pred) + + def test_di_attributes_cached(self): + G = self.K3Graph() + assert id(G.in_edges) == id(G.in_edges) + assert id(G.out_edges) == id(G.out_edges) + assert id(G.in_degree) == id(G.in_degree) + assert id(G.out_degree) == id(G.out_degree) + assert id(G.succ) == id(G.succ) + assert id(G.pred) == id(G.pred) + + +class BaseAttrDiGraphTester(BaseDiGraphTester, BaseAttrGraphTester): + def test_edges_data(self): + G = self.K3 + all_edges = [ + (0, 1, {}), + (0, 2, {}), + (1, 0, {}), + (1, 2, {}), + (2, 0, {}), + (2, 1, {}), + ] + assert sorted(G.edges(data=True)) == all_edges + assert sorted(G.edges(0, data=True)) == all_edges[:2] + assert sorted(G.edges([0, 1], data=True)) == all_edges[:4] + with pytest.raises(nx.NetworkXError): + G.edges(-1, True) + + def test_in_degree_weighted(self): + G = self.K3.copy() + G.add_edge(0, 1, weight=0.3, other=1.2) + assert sorted(G.in_degree(weight="weight")) == [(0, 2), (1, 1.3), (2, 2)] + assert dict(G.in_degree(weight="weight")) == {0: 2, 1: 1.3, 2: 2} + assert G.in_degree(1, weight="weight") == 1.3 + assert sorted(G.in_degree(weight="other")) == [(0, 2), (1, 2.2), (2, 2)] + assert dict(G.in_degree(weight="other")) == {0: 2, 1: 2.2, 2: 2} + assert G.in_degree(1, weight="other") == 2.2 + assert list(G.in_degree(iter([1]), weight="other")) == [(1, 2.2)] + + def test_out_degree_weighted(self): + G = self.K3.copy() + G.add_edge(0, 1, weight=0.3, other=1.2) + assert sorted(G.out_degree(weight="weight")) == [(0, 1.3), (1, 2), (2, 2)] + assert dict(G.out_degree(weight="weight")) == {0: 1.3, 1: 2, 2: 2} + assert G.out_degree(0, weight="weight") == 1.3 + assert sorted(G.out_degree(weight="other")) == [(0, 2.2), (1, 2), (2, 2)] + assert dict(G.out_degree(weight="other")) == {0: 2.2, 1: 2, 2: 2} + assert G.out_degree(0, weight="other") == 2.2 + assert list(G.out_degree(iter([0]), weight="other")) == [(0, 2.2)] + + +class TestDiGraph(BaseAttrDiGraphTester, _TestGraph): + """Tests specific to dict-of-dict-of-dict digraph data structure""" + + def setup_method(self): + self.Graph = nx.DiGraph + # build dict-of-dict-of-dict K3 + ed1, ed2, ed3, ed4, ed5, ed6 = ({}, {}, {}, {}, {}, {}) + self.k3adj = {0: {1: ed1, 2: ed2}, 1: {0: ed3, 2: ed4}, 2: {0: ed5, 1: ed6}} + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = ["test_graph_node/0", "test_graph_node/1", "test_graph_node/2"] + self.K3 = self.Graph() + self.K3._succ = self.k3adj # K3._adj is synced with K3._succ + self.K3._pred = {0: {1: ed3, 2: ed5}, 1: {0: ed1, 2: ed6}, 2: {0: ed2, 1: ed4}} + self.K3._node = {} + self.K3._node[0] = {} + self.K3._node[1] = {} + self.K3._node[2] = {} + + ed1, ed2 = ({}, {}) + self.P3 = self.Graph() + self.P3._succ = {0: {1: ed1}, 1: {2: ed2}, 2: {}} + self.P3._pred = {0: {}, 1: {0: ed1}, 2: {1: ed2}} + # P3._adj is synced with P3._succ + self.P3._node = {} + self.P3._node[0] = {} + self.P3._node[1] = {} + self.P3._node[2] = {} + + def nxadb_graph_constructor(*args, **kwargs) -> nxadb.DiGraph: + db.delete_graph(GRAPH_NAME, drop_collections=True, ignore_missing=True) + G = nxadb.DiGraph(*args, **kwargs, name=GRAPH_NAME) + # Experimenting with a delay to see if it helps with CircleCI... + time.sleep(0.10) + return G + + self.K3Graph = lambda *args, **kwargs: nxadb_graph_constructor( + *args, **kwargs, incoming_graph_data=self.K3 + ) + self.Graph = self.K3Graph + self.P3Graph = lambda *args, **kwargs: nxadb_graph_constructor( + *args, **kwargs, incoming_graph_data=self.P3 + ) + self.EmptyGraph = lambda *args, **kwargs: nxadb_graph_constructor( + *args, **kwargs + ) + + def test_data_input(self): + # NOTE: Creating a DiGraph from incoming_graph_data + # that is *not* a NetworkX Graph will *symmetrize* the data! + # i.e symmetrize_edges = True (no way around it AFAIK) + G = self.EmptyGraph(incoming_graph_data={1: [2], 2: [1]}) + + assert G._succ[1][2]["_id"] != G._succ[2][1]["_id"] + assert G._pred[1][2]["_id"] != G._pred[2][1]["_id"] + assert G._succ[1][2]["_id"] == G._pred[2][1]["_id"] + assert G._succ[2][1]["_id"] == G._pred[1][2]["_id"] + + succ = { + "test_graph_node/1": {"test_graph_node/2": G._succ[1][2]}, + "test_graph_node/2": {"test_graph_node/1": G._succ[2][1]}, + } + pred = { + "test_graph_node/1": {"test_graph_node/2": G._pred[1][2]}, + "test_graph_node/2": {"test_graph_node/1": G._pred[2][1]}, + } + + assert dict(G.adj.items()) == succ + assert dict(G.succ.items()) == succ + assert dict(G.pred.items()) == pred + + def test_add_edge(self): + G = self.EmptyGraph() + assert G.number_of_edges() == 0 + G.add_edge(0, 1) + assert G.number_of_edges() == 1 + assert G.adj[0][1] + assert 1 in G.adj + assert len(G.adj[1]) == 0 + assert G.pred[1][0] + assert len(G.pred[0]) == 0 + assert 1 not in G.pred[0] + G = self.EmptyGraph() + G.add_edge(*(0, 1)) + assert G.number_of_edges() == 1 + assert G.adj[0][1] + assert 1 in G.adj + assert len(G.adj[1]) == 0 + assert G.pred[1][0] + assert len(G.pred[0]) == 0 + with pytest.raises(ValueError, match="Key cannot be None"): + G.add_edge(None, 3) + + def test_add_edges_from(self): + G = self.EmptyGraph() + G.add_edges_from([(0, 1), (0, 2, {"data": 3})], data=2) + assert "data" in G.adj[0][1] + assert G.adj[0][1]["data"] == 2 + assert G.succ[0][1]["data"] == 2 + assert G.pred[1][0]["data"] == 2 + + assert "data" in G.adj[0][2] + assert G.adj[0][2]["data"] == 3 + assert G.succ[0][2]["data"] == 3 + assert G.pred[2][0]["data"] == 3 + + with pytest.raises(nx.NetworkXError): + G.add_edges_from([(0,)]) # too few in tuple + with pytest.raises(nx.NetworkXError): + G.add_edges_from([(0, 1, 2, 3)]) # too many in tuple + with pytest.raises(TypeError): + G.add_edges_from([0]) # not a tuple + with pytest.raises(ValueError, match="Key cannot be None"): + G.add_edges_from([(None, 3), (3, 2)]) + + def test_remove_edge(self): + G = self.K3Graph() + assert G.number_of_edges() == 6 + G.remove_edge(0, 1) + assert G.number_of_edges() == 5 + assert "test_graph_node/1" not in G._succ["test_graph_node/0"].data + assert 1 not in G[0] + assert G[1][0] + with pytest.raises(nx.NetworkXError): + G.remove_edge(-1, 0) + + def test_remove_edges_from(self): + G = self.K3Graph() + assert G.number_of_edges() == 6 + G.remove_edges_from([(0, 1)]) + assert G.number_of_edges() == 5 + assert "test_graph_node/1" not in G._succ["test_graph_node/0"].data + assert 1 not in G[0] + assert G[1][0] + G.remove_edges_from([(0, 0)]) # silent fail + + def test_clear(self): + G = self.K3Graph() + G.graph["name"] = "K3" + G.clear() + assert G._node.data == {} + assert G._succ.data == {} + assert G._pred.data == {} + assert G.graph.data == {} + assert list(G.nodes) != [] + + def test_clear_edges(self): + G = self.K3Graph() + G.graph["name"] = "K3" + nodes = list(G.nodes) + + G._adj._fetch_all() + G.clear_edges() + + assert list(G.nodes) == nodes + + for node, adj_inner_dict in G._succ.data.items(): + assert node in G._pred.data + assert adj_inner_dict.data == {} + + for node, adj_inner_dict in G._pred.data.items(): + assert node in G._succ.data + assert adj_inner_dict.data == {} + + assert list(G.edges) != [] + assert G.graph["name"] == "K3" + + +# TODO: Implement this test +# class TestEdgeSubgraph(_TestGraphEdgeSubgraph): +# """Unit tests for the :meth:`DiGraph.edge_subgraph` method.""" + +# def setup_method(self): +# # Create a doubly-linked path graph on five nodes. +# G = nx.DiGraph(nx.path_graph(5)) +# # Add some node, edge, and graph attributes. +# for i in range(5): +# G.nodes[i]["name"] = f"node{i}" +# G.edges[0, 1]["name"] = "edge01" +# G.edges[3, 4]["name"] = "edge34" +# G.graph["name"] = "graph" +# # Get the subgraph induced by the first and last edges. +# self.G = G +# self.H = G.edge_subgraph([(0, 1), (3, 4)]) + +# def test_pred_succ(self): +# """Test that nodes are added to predecessors and successors. + +# For more information, see GitHub issue #2370. + +# """ +# G = nx.DiGraph() +# G.add_edge(0, 1) +# H = G.edge_subgraph([(0, 1)]) +# assert list(H.predecessors(0)) == [] +# assert list(H.successors(0)) == [1] +# assert list(H.predecessors(1)) == [0] +# assert list(H.successors(1)) == [] diff --git a/tests/test_graph.py b/tests/test_graph.py index d5d5c746..d2a13e51 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -779,21 +779,33 @@ def test_data_input(self): def test_adjacency(self): G = self.Graph() edge_0_1 = get_doc(G.adj[0][1]["_id"]) + edge_1_0 = get_doc(G.adj[1][0]["_id"]) edge_0_2 = get_doc(G.adj[0][2]["_id"]) - edge_1_2 = get_doc(G.adj[1][2]["_id"]) edge_2_0 = get_doc(G.adj[2][0]["_id"]) + edge_1_2 = get_doc(G.adj[1][2]["_id"]) + edge_2_1 = get_doc(G.adj[2][1]["_id"]) + + if G.is_directed(): + assert edge_0_1 != edge_1_0 + assert edge_0_2 != edge_2_0 + assert edge_1_2 != edge_2_1 + else: + assert edge_0_1 == edge_1_0 + assert edge_0_2 == edge_2_0 + assert edge_1_2 == edge_2_1 + assert dict(G.adjacency()) == { "test_graph_node/0": { "test_graph_node/1": edge_0_1, "test_graph_node/2": edge_0_2, }, "test_graph_node/1": { - "test_graph_node/0": edge_0_1, + "test_graph_node/0": edge_1_0, "test_graph_node/2": edge_1_2, }, "test_graph_node/2": { "test_graph_node/0": edge_2_0, - "test_graph_node/1": edge_1_2, + "test_graph_node/1": edge_2_1, }, }