Skip to content

GA-163 | test_digraph #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 57 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
b315606
GA-163 | initial commit
aMahanna Aug 16, 2024
aa99026
Merge branch 'main' into GA-163
aMahanna Aug 16, 2024
26e1a85
unlock adbnx
aMahanna Aug 16, 2024
d6da2a3
fix: `incoming_graph_data`
aMahanna Aug 16, 2024
6561949
fix: incoming_graph_data
aMahanna Aug 17, 2024
469de71
fix: off-by-one IDs
aMahanna Aug 17, 2024
b742027
checkpoint
aMahanna Aug 17, 2024
8b47e4d
checkpoint: `BaseGraphTester` is passing
aMahanna Aug 17, 2024
0483486
checkpoint: BaseGraphAttrTester
aMahanna Aug 18, 2024
1ed111e
cleanup: `aql_fetch_data`, `aql_fetch_data_edge`
aMahanna Aug 18, 2024
f5963a6
use pytest skip for failing tests
aMahanna Aug 18, 2024
eb6717e
checkpoint: optimize `__iter__`
aMahanna Aug 18, 2024
04dc9c1
checkpoint: run `test_graph`
aMahanna Aug 19, 2024
2199ae3
add comment
aMahanna Aug 19, 2024
173f0a7
Merge branch 'main' into GA-163
aMahanna Aug 19, 2024
bc64fe9
checkpoint
aMahanna Aug 19, 2024
0df6c2b
attempt: sleep
aMahanna Aug 19, 2024
aa4b336
fix: lint
aMahanna Aug 19, 2024
5aa3eb2
cleanup: getitem
aMahanna Aug 20, 2024
b03f4cf
cleanup: copy
aMahanna Aug 20, 2024
27adfa3
attempt: shorten sleep
aMahanna Aug 20, 2024
c34898a
Merge branch 'main' into GA-163
aMahanna Aug 20, 2024
0d18563
fix: `__set_adj_elements`
aMahanna Aug 20, 2024
b0434a9
fix: mypy
aMahanna Aug 20, 2024
3f07ae1
attempt: decrease sleep
aMahanna Aug 20, 2024
8b87046
Merge branch 'main' into GA-163
aMahanna Aug 20, 2024
0c48931
GA-163 | `test_digraph`
aMahanna Aug 21, 2024
5907d8e
checkpoint
aMahanna Aug 21, 2024
de17587
fix: set `self.Graph`
aMahanna Aug 21, 2024
3144429
add type ignore
aMahanna Aug 21, 2024
28dd130
Merge branch 'main' into GA-163
aMahanna Aug 21, 2024
443d436
Merge branch 'main' into GA-163
aMahanna Aug 21, 2024
dcb94ff
fix: graph name
aMahanna Aug 21, 2024
0fc2568
Merge branch 'GA-163' into GA-163-digraph
aMahanna Aug 21, 2024
5264178
fix: graph name
aMahanna Aug 21, 2024
b805419
adjust assertions to exclude _rev, set `use_experimental_views`
aMahanna Aug 21, 2024
c75b758
Revert "adjust assertions to exclude _rev, set `use_experimental_views`"
aMahanna Aug 21, 2024
e7339de
fix: `_rev`, `use_experimental_views`
aMahanna Aug 21, 2024
f79fabb
Merge branch 'GA-163' into GA-163-digraph
aMahanna Aug 21, 2024
4c961d0
set `use_experimental_views`
aMahanna Aug 21, 2024
058d687
fix: lint
aMahanna Aug 21, 2024
6e5b504
new: `nbunch_iter` override
aMahanna Aug 21, 2024
dd17bb5
Merge branch 'GA-163' into GA-163-digraph
aMahanna Aug 22, 2024
e64781e
set experimental views to false
aMahanna Aug 22, 2024
f19839e
Merge branch 'GA-163' into GA-163-digraph
aMahanna Aug 22, 2024
6bf6a0a
set experimental views to false
aMahanna Aug 22, 2024
27dd792
cleanup
aMahanna Aug 22, 2024
69489a7
Merge branch 'main' into GA-163-digraph
aMahanna Aug 23, 2024
b3bb830
fix: `function.py`
aMahanna Aug 23, 2024
767301f
cleanup: `graph`, `digraph`
aMahanna Aug 23, 2024
9d30054
fix: `test_data_input`
aMahanna Aug 23, 2024
25b6418
attempt: wait for CircleCI
aMahanna Aug 23, 2024
6cdf2b8
fix: nx graph
aMahanna Aug 23, 2024
3b3b09c
remove sleep
aMahanna Aug 23, 2024
df237dd
new: `override` suffix
aMahanna Aug 23, 2024
6789e6c
enable more tests
aMahanna Aug 26, 2024
34af3fc
fix: lint
aMahanna Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nx_arangodb/classes/dict/adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,7 +1557,7 @@ def keys(self) -> Any:

@logger_debug
def clear(self) -> None:
"""g._node.clear()"""
"""g._adj.clear()"""
self.data.clear()
self.FETCHED_ALL_DATA = False
self.FETCHED_ALL_IDS = False
Expand Down
41 changes: 33 additions & 8 deletions nx_arangodb/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import nx_arangodb as nxadb
from nx_arangodb.classes.graph import Graph
from nx_arangodb.logger import logger

from .dict.adj import AdjListOuterDict
from .enum import TraversalDirection
from .function import get_node_id

networkx_api = nxadb.utils.decorators.networkx_class(nx.DiGraph) # type: ignore

Expand Down Expand Up @@ -57,12 +59,9 @@ def __init__(
)

if self.graph_exists_in_db:
assert isinstance(self._succ, AdjListOuterDict)
assert isinstance(self._pred, AdjListOuterDict)
self._succ.mirror = self._pred
self._pred.mirror = self._succ
self._succ.traversal_direction = TraversalDirection.OUTBOUND
self._pred.traversal_direction = TraversalDirection.INBOUND
self.clear_edges = self.clear_edges_override
self.add_node = self.add_node_override
self.remove_node = self.remove_node_override

#######################
# nx.DiGraph Overides #
Expand All @@ -78,7 +77,14 @@ def __init__(
# def out_edges(self):
# pass

def add_node(self, node_for_adding, **attr):
def clear_edges_override(self):
    """Empty the cached predecessor dictionaries, then run the parent's ``clear_edges``.

    An informational note is logged first. Only the values held in
    ``self._pred.data`` are cleared here; everything else is delegated to
    ``super().clear_edges()``.
    """
    logger.info("Note that clearing edges ony erases the edges in the local cache")

    # Walk the predecessor cache and empty each per-node dictionary in place.
    for cached_predecessors in self._pred.data.values():
        cached_predecessors.clear()

    super().clear_edges()

def add_node_override(self, node_for_adding, **attr):
if node_for_adding not in self._succ:
if node_for_adding is None:
raise ValueError("None cannot be a node")
Expand Down Expand Up @@ -109,7 +115,10 @@ def add_node(self, node_for_adding, **attr):

nx._clear_cache(self)

def remove_node(self, n):
def remove_node_override(self, n):
if isinstance(n, (str, int)):
n = get_node_id(str(n), self.default_node_type)

try:

######################
Expand All @@ -136,6 +145,22 @@ def remove_node(self, n):
del self._pred[u][n] # remove all edges n-u in digraph
del self._succ[n] # remove node from succ
for u in nbrs_pred:
######################
# NOTE: Monkey patch #
######################

# Old: Nothing

# New:
if u == n:
continue # skip self loops

# Reason: We need to skip self loops, as they are
# already taken care of in the previous step. This
# avoids getting a KeyError on the next line.

###########################

del self._succ[u][n] # remove all edges n-u in digraph
del self._pred[n] # remove node from pred
nx._clear_cache(self)
113 changes: 65 additions & 48 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
node_attr_dict_factory,
node_dict_factory,
)
from .dict.adj import AdjListOuterDict
from .enum import TraversalDirection
from .function import get_node_id
from .reportviews import CustomEdgeView, CustomNodeView

Expand Down Expand Up @@ -97,6 +99,7 @@ def __init__(
# m = "Must set **graph_name** if passing **incoming_graph_data**"
# raise ValueError(m)

loaded_incoming_graph_data = False
if self._graph_exists_in_db:
if incoming_graph_data is not None:
m = "Cannot pass both **incoming_graph_data** and **graph_name** yet if the already graph exists" # noqa: E501
Expand Down Expand Up @@ -171,29 +174,44 @@ def edge_type_func(u: str, v: str) -> str:
use_async=True,
)

loaded_incoming_graph_data = True

else:
self.adb_graph = self.db.create_graph(
self.__name,
edge_definitions=edge_definitions,
)

# Let the parent class handle the incoming graph data
# if it is not a networkx.Graph object
kwargs["incoming_graph_data"] = incoming_graph_data

self._set_factory_methods()
self._set_arangodb_backend_config()
logger.info(f"Graph '{name}' created.")
self._graph_exists_in_db = True

else:
kwargs["incoming_graph_data"] = incoming_graph_data

if name is not None:
kwargs["name"] = name
if self.__name is not None:
kwargs["name"] = self.__name

super().__init__(*args, **kwargs)

if self.is_directed() and self.graph_exists_in_db:
assert isinstance(self._succ, AdjListOuterDict)
assert isinstance(self._pred, AdjListOuterDict)
self._succ.mirror = self._pred
self._pred.mirror = self._succ
self._succ.traversal_direction = TraversalDirection.OUTBOUND
self._pred.traversal_direction = TraversalDirection.INBOUND

if incoming_graph_data is not None and not loaded_incoming_graph_data:
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)

if self.graph_exists_in_db:
self.copy = self.copy_override
self.subgraph = self.subgraph_override
self.clear = self.clear_override
self.clear_edges = self.clear_edges_override
self.add_node = self.add_node_override
self.number_of_edges = self.number_of_edges_override
self.nbunch_iter = self.nbunch_iter_override

#######################
# Init helper methods #
#######################
Expand Down Expand Up @@ -345,6 +363,9 @@ def _set_graph_name(self, graph_name: str | None = None) -> None:
# ArangoDB Methods #
####################

def clear_nxcg_cache(self) -> None:
    """Discard the cached ``nxcg_graph`` by resetting the attribute to ``None``."""
    self.nxcg_graph = None

def aql(
    self, query: str, bind_vars: dict[str, Any] | None = None, **kwargs: Any
) -> Cursor:
    """Run an AQL query through ``nxadb.classes.function.aql`` using ``self.db``.

    :param query: The AQL query string, forwarded unchanged.
    :param bind_vars: Bind variables for the query. ``None`` (the default)
        is forwarded as a new, empty dict.
    :param kwargs: Additional keyword arguments forwarded to the helper.
    :return: The value returned by the helper (annotated as a ``Cursor``).
    """
    # Build a fresh dict per call: a literal ``{}`` default is a single object
    # shared by every invocation, so any downstream mutation would leak into
    # later calls.
    if bind_vars is None:
        bind_vars = {}

    return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)

Expand All @@ -355,7 +376,7 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs
# NOTE: OUT OF SERVICE
# def chat(self, prompt: str) -> str:
# if self.__qa_chain is None:
# if not self.__graph_exists_in_db:
# if not self.graph_exists_in_db:
# return "Could not initialize QA chain: Graph does not exist"

# # try:
Expand All @@ -381,30 +402,6 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs
# nx.Graph Overides #
#####################

def copy(self, *args, **kwargs):
logger.warning("Note that copying a graph loses the connection to the database")
G = super().copy(*args, **kwargs)
G.node_dict_factory = nx.Graph.node_dict_factory
G.node_attr_dict_factory = nx.Graph.node_attr_dict_factory
G.edge_attr_dict_factory = nx.Graph.edge_attr_dict_factory
G.adjlist_inner_dict_factory = nx.Graph.adjlist_inner_dict_factory
G.adjlist_outer_dict_factory = nx.Graph.adjlist_outer_dict_factory
return G

def subgraph(self, nbunch):
raise NotImplementedError("Subgraphing is not yet implemented")

def clear(self):
logger.info("Note that clearing only erases the local cache")
super().clear()

def clear_edges(self):
logger.info("Note that clearing edges ony erases the edges in the local cache")
super().clear_edges()

def clear_nxcg_cache(self):
self.nxcg_graph = None

@cached_property
def nodes(self):
if self.__use_experimental_views and self.graph_exists_in_db:
Expand Down Expand Up @@ -437,7 +434,30 @@ def edges(self):

return super().edges

def add_node(self, node_for_adding, **attr):
def copy_override(self, *args, **kwargs):
    """Copy the graph and give the copy the stock ``nx.Graph`` dict factories.

    A warning is logged first. The copy itself is produced by
    ``super().copy`` with the arguments passed through untouched; the five
    factory attributes on the result are then overwritten with the values
    found on ``nx.Graph``.
    """
    logger.warning("Note that copying a graph loses the connection to the database")

    duplicate = super().copy(*args, **kwargs)

    # Same attributes, same order as assigning them one by one.
    for factory_name in (
        "node_dict_factory",
        "node_attr_dict_factory",
        "edge_attr_dict_factory",
        "adjlist_inner_dict_factory",
        "adjlist_outer_dict_factory",
    ):
        setattr(duplicate, factory_name, getattr(nx.Graph, factory_name))

    return duplicate

def subgraph_override(self, nbunch):
    """Reject every subgraph request.

    :param nbunch: Accepted for signature compatibility; never inspected.
    :raises NotImplementedError: Unconditionally.
    """
    message = "Subgraphing is not yet implemented"
    raise NotImplementedError(message)

def clear_override(self):
    """Log an informational note, then delegate to the parent's ``clear``."""
    notice = "Note that clearing only erases the local cache"
    logger.info(notice)
    super().clear()

def clear_edges_override(self):
    """Empty every cached neighbour dictionary and reset the networkx cache.

    An informational note is logged first. Each value stored in
    ``self._adj.data`` is cleared in place, after which
    ``nx._clear_cache`` is called on this graph.
    """
    logger.info("Note that clearing edges ony erases the edges in the local cache")

    # Clear the per-node neighbour dictionaries held in the adjacency cache.
    for cached_neighbours in self._adj.data.values():
        cached_neighbours.clear()

    nx._clear_cache(self)

def add_node_override(self, node_for_adding, **attr):
if node_for_adding not in self._node:
if node_for_adding is None:
raise ValueError("None cannot be a node")
Expand Down Expand Up @@ -467,10 +487,7 @@ def add_node(self, node_for_adding, **attr):

nx._clear_cache(self)

def number_of_edges(self, u=None, v=None):
if not self.graph_exists_in_db:
return super().number_of_edges(u, v)

def number_of_edges_override(self, u=None, v=None):
if u is not None:
return super().number_of_edges(u, v)

Expand All @@ -494,10 +511,7 @@ def number_of_edges(self, u=None, v=None):
# It is more efficient to count the number of edges in the edge collections
# compared to relying on the DegreeView.

def nbunch_iter(self, nbunch=None):
if not self._graph_exists_in_db:
return super().nbunch_iter(nbunch)

def nbunch_iter_override(self, nbunch=None):
if nbunch is None:
bunch = iter(self._adj)
elif nbunch in self:
Expand All @@ -508,12 +522,13 @@ def nbunch_iter(self, nbunch=None):
# Old: Nothing

# New:
if isinstance(nbunch, int):
if isinstance(nbunch, (str, int)):
nbunch = get_node_id(str(nbunch), self.default_node_type)

# Reason:
# ArangoDB only uses strings as node IDs. Therefore, we need to convert
# the integer node ID to a string before using it in an iterator.
# the non-prefixed node ID to an ArangoDB ID before
# using it in an iterator.

bunch = iter([nbunch])
else:
Expand All @@ -528,13 +543,15 @@ def bunch_iter(nlist, adj):
# Old: Nothing

# New:
if isinstance(n, int):
if isinstance(n, (str, int)):
n = get_node_id(str(n), self.default_node_type)

# Reason:
# ArangoDB only uses strings as node IDs. Therefore,
# we need to convert the integer node ID to a
# string before using it in an iterator.
# we need to convert non-prefixed node IDs to an
# ArangoDB ID before using it in an iterator.

######################

if n in adj:
yield n
Expand Down
8 changes: 4 additions & 4 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __init__(
**kwargs,
)

if self.graph_exists_in_db:
self.add_edge = self.add_edge_override

#######################
# Init helper methods #
#######################
Expand All @@ -74,10 +77,7 @@ def _set_factory_methods(self) -> None:
# nx.MultiGraph Overides #
##########################

def add_edge(self, u_for_edge, v_for_edge, key=None, **attr):
if not self.graph_exists_in_db:
return super().add_edge(u_for_edge, v_for_edge, key=key, **attr)

def add_edge_override(self, u_for_edge, v_for_edge, key=None, **attr):
if key is not None:
m = "ArangoDB MultiGraph does not support custom edge keys yet."
logger.warning(m)
Expand Down
58 changes: 58 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from .conftest import create_line_graph, db

G_NX = nx.karate_club_graph()
G_NX_digraph = nx.DiGraph(G_NX)
G_NX_multigraph = nx.MultiGraph(G_NX)
G_NX_multidigraph = nx.MultiDiGraph(G_NX)


def assert_remote_dict(G: nxadb.Graph) -> None:
Expand Down Expand Up @@ -1777,3 +1780,58 @@ def test_incoming_graph_data_not_nx_graph(
)
assert has_club == ("club" in G.nodes["0"])
assert has_weight == ("weight" in G.adj["0"]["1"])


@pytest.mark.parametrize(
    "data_type, incoming_graph_data, has_club, has_weight",
    [
        ("dict of dicts", G_NX_digraph._adj, False, True),
        (
            "dict of lists",
            {k: list(v) for k, v in G_NX_digraph._adj.items()},
            False,
            False,
        ),
        ("container of edges", list(G_NX_digraph.edges), False, False),
        ("iterator of edges", iter(G_NX_digraph.edges), False, False),
        ("generator of edges", (e for e in G_NX_digraph.edges), False, False),
        ("2D numpy array", nx.to_numpy_array(G_NX_digraph), False, True),
        (
            "scipy sparse array",
            nx.to_scipy_sparse_array(G_NX_digraph),
            False,
            True,
        ),
        ("Pandas EdgeList", nx.to_pandas_edgelist(G_NX_digraph), False, True),
        ("Pandas Adjacency", nx.to_pandas_adjacency(G_NX_digraph), False, True),
    ],
)
def test_incoming_graph_data_not_nx_graph_digraph(
    data_type: str, incoming_graph_data: Any, has_club: bool, has_weight: bool
) -> None:
    """Build an ``nxadb.DiGraph`` from non-graph input and compare its sizes.

    For each parametrized input, the graph named "KarateGraph" is deleted
    and rebuilt from ``incoming_graph_data``. The adjacency, node and edge
    counts must then agree between the new graph, ``G_NX_digraph`` and the
    database collections. ``has_club`` / ``has_weight`` state whether the
    "club" node attribute and "weight" edge attribute are expected.
    ``data_type`` is only a label for the parametrized case; the body does
    not read it.
    """
    # See nx.convert.to_networkx_graph for the official supported types
    graph_name = "KarateGraph"
    db.delete_graph(graph_name, drop_collections=True, ignore_missing=True)

    digraph = nxadb.DiGraph(incoming_graph_data=incoming_graph_data, name=graph_name)

    # Adjacency size: local graph vs. reference graph vs. stored documents.
    adj_local = len(digraph.adj)
    adj_reference = len(G_NX_digraph.adj)
    adj_stored = db.collection(digraph.default_node_type).count()
    assert adj_local == adj_reference == adj_stored

    # Node count, checked through every available accessor.
    nodes_local = len(digraph.nodes)
    nodes_reference = len(G_NX_digraph.nodes)
    nodes_stored = db.collection(digraph.default_node_type).count()
    nodes_reported = digraph.number_of_nodes()
    assert nodes_local == nodes_reference == nodes_stored == nodes_reported

    # Edge count, using the edge collection derived from the default node type.
    edge_col = digraph.edge_type_func(
        digraph.default_node_type, digraph.default_node_type
    )
    edges_local = len(digraph.edges)
    edges_reference = len(G_NX_digraph.edges)
    edges_stored = db.collection(edge_col).count()
    edges_reported = digraph.number_of_edges()
    assert edges_local == edges_reference == edges_stored == edges_reported

    assert ("club" in digraph.nodes["0"]) == has_club
    assert ("weight" in digraph.adj["0"]["1"]) == has_weight
Loading