Skip to content

changes to allow for compound foreign keys #203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/python-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,17 @@ You can leave off the third item in the tuple to have the referenced column auto
], foreign_keys=[
("author_id", "authors")
])

Compound foreign keys can be created by passing in tuples of columns rather than strings:

.. code-block:: python

foreign_keys=[
(("author_id", "person_id"), "authors", ("id", "person_id"))
]

This means that the ``author_id`` and ``person_id`` columns should be a compound foreign key that references the ``id`` and ``person_id`` columns in the ``authors`` table.


.. _python_api_table_configuration:

Expand Down Expand Up @@ -902,6 +913,20 @@ The ``table.add_foreign_key(column, other_table, other_column)`` method takes th

This method first checks that the specified foreign key references tables and columns that exist and does not clash with an existing foreign key. It will raise a ``sqlite_utils.db.AlterError`` exception if these checks fail.

You can add compound foreign keys by passing a tuple of column names. For example:

.. code-block:: python

db["authors"].insert_all([
{"id": 1, "person_id": 1, "name": "Sally"},
{"id": 2, "person_id": 2, "name": "Asheesh"}
], pk="id")
db["books"].insert_all([
{"title": "Hedgehogs of the world", "author_id": 1, "person_id": 1},
{"title": "How to train your wolf", "author_id": 2, "person_id": 2},
])
db["books"].add_foreign_key(("author_id", "person_id"), "authors", ("id", "person_id"))

To ignore the case where the key already exists, use ``ignore=True``:

.. code-block:: python
Expand Down
200 changes: 130 additions & 70 deletions sqlite_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,46 @@
"least_common",
),
)
ForeignKey = namedtuple(
"ForeignKey", ("table", "column", "other_table", "other_column")
)
Index = namedtuple("Index", ("seq", "name", "unique", "origin", "partial", "columns"))
Trigger = namedtuple("Trigger", ("name", "table", "sql"))


class ForeignKey(ForeignKeyBase):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like ForeignKeyBase is missing - I added this and the tests started passing again:

ForeignKeyBase = namedtuple(
    "ForeignKey", ("table", "column", "other_table", "other_column")
)

def __new__(cls, table, column, other_table, other_column):
# column and other_column should be a tuple
if isinstance(column, (tuple, list)):
column = tuple(column)
else:
column = (column,)
if isinstance(other_column, (tuple, list)):
other_column = tuple(other_column)
else:
other_column = (other_column,)
self = super(ForeignKey, cls).__new__(
cls, table, column, other_table, other_column
)
return self

@property
def column_str(self):
return ",".join(["[{}]".format(c) for c in self.column])

@property
def other_column_str(self):
return ",".join(["[{}]".format(c) for c in self.other_column])

@property
def sql(self):
return (
"FOREIGN KEY({column}) REFERENCES [{other_table}]({other_column})".format(
table=self.table,
column=self.column_str,
other_table=self.other_table,
other_column=self.other_column_str,
)
)


DEFAULT = object()

COLUMN_TYPE_MAPPING = {
Expand Down Expand Up @@ -375,12 +408,9 @@ def create_table_sql(
pk = hash_id
# Soundness check foreign_keys point to existing tables
for fk in foreign_keys:
if not any(
c for c in self[fk.other_table].columns if c.name == fk.other_column
):
raise AlterError(
"No such column: {}.{}".format(fk.other_table, fk.other_column)
)
for oc in fk.other_column:
if not any(c for c in self[fk.other_table].columns if c.name == oc):
raise AlterError("No such column: {}.{}".format(fk.other_table, oc))

column_defs = []
# ensure pk is a tuple
Expand All @@ -401,13 +431,6 @@ def create_table_sql(
column_extras.append(
"DEFAULT {}".format(self.escape(defaults[column_name]))
)
if column_name in foreign_keys_by_column:
column_extras.append(
"REFERENCES [{other_table}]([{other_column}])".format(
other_table=foreign_keys_by_column[column_name].other_table,
other_column=foreign_keys_by_column[column_name].other_column,
)
)
column_defs.append(
" [{column_name}] {column_type}{column_extras}".format(
column_name=column_name,
Expand All @@ -422,6 +445,10 @@ def create_table_sql(
extra_pk = ",\n PRIMARY KEY ({pks})".format(
pks=", ".join(["[{}]".format(p) for p in pk])
)
for column_name in foreign_keys_by_column:
extra_pk += ",\n {}".format(
foreign_keys_by_column[column_name].sql,
)
columns_sql = ",\n".join(column_defs)
sql = """CREATE TABLE [{table}] (
{columns_sql}{extra_pk}
Expand Down Expand Up @@ -494,7 +521,7 @@ def m2m_table_candidates(self, table, other_table):
candidates.append(table.name)
return candidates

def add_foreign_keys(self, foreign_keys):
def add_foreign_keys(self, foreign_keys, ignore=True):
# foreign_keys is a list of explicit 4-tuples
assert all(
len(fk) == 4 and isinstance(fk, (list, tuple)) for fk in foreign_keys
Expand All @@ -503,43 +530,55 @@ def add_foreign_keys(self, foreign_keys):
foreign_keys_to_create = []

# Verify that all tables and columns exist
for table, column, other_table, other_column in foreign_keys:
if not self[table].exists():
raise AlterError("No such table: {}".format(table))
if column not in self[table].columns_dict:
raise AlterError("No such column: {} in {}".format(column, table))
if not self[other_table].exists():
raise AlterError("No such other_table: {}".format(other_table))
if (
other_column != "rowid"
and other_column not in self[other_table].columns_dict
):
raise AlterError(
"No such other_column: {} in {}".format(other_column, other_table)
for fk in foreign_keys:
if not isinstance(fk, ForeignKey):
fk = ForeignKey(
table=fk[0],
column=fk[1],
other_table=fk[2],
other_column=fk[3],
)

if not self[fk.table].exists():
raise AlterError("No such table: {}".format(fk.table))
for c in fk.column:
if c not in self[fk.table].columns_dict:
raise AlterError("No such column: {} in {}".format(c, fk.table))
if not self[fk.other_table].exists():
raise AlterError("No such other_table: {}".format(fk.other_table))
for c in fk.other_column:
if c != "rowid" and c not in self[fk.other_table].columns_dict:
raise AlterError(
"No such other_column: {} in {}".format(c, fk.other_table)
)
# We will silently skip foreign keys that exist already
if not any(
fk
for fk in self[table].foreign_keys
if fk.column == column
and fk.other_table == other_table
and fk.other_column == other_column
if any(
existing_fk
for existing_fk in self[fk.table].foreign_keys
if existing_fk.column == fk.column
and existing_fk.other_table == fk.other_table
and existing_fk.other_column == fk.other_column
):
foreign_keys_to_create.append(
(table, column, other_table, other_column)
)
if ignore:
continue
else:
raise AlterError(
"Foreign key already exists for {} => {}.{}".format(
fk.column_str, other_table, fk.other_column_str
)
)
else:
foreign_keys_to_create.append(fk)

# Construct SQL for use with "UPDATE sqlite_master SET sql = ? WHERE name = ?"
table_sql = {}
for table, column, other_table, other_column in foreign_keys_to_create:
old_sql = table_sql.get(table, self[table].schema)
extra_sql = ",\n FOREIGN KEY({column}) REFERENCES {other_table}({other_column})\n".format(
column=column, other_table=other_table, other_column=other_column
)
for fk in foreign_keys_to_create:
old_sql = table_sql.get(fk.table, self[fk.table].schema)
extra_sql = ",\n {}\n".format(fk.sql)
# Stick that bit in at the very end just before the closing ')'
last_paren = old_sql.rindex(")")
new_sql = old_sql[:last_paren].strip() + extra_sql + old_sql[last_paren:]
table_sql[table] = new_sql
table_sql[fk.table] = new_sql

# And execute it all within a single transaction
with self.conn:
Expand All @@ -560,12 +599,10 @@ def add_foreign_keys(self, foreign_keys):
def index_foreign_keys(self):
for table_name in self.table_names():
table = self[table_name]
existing_indexes = {
i.columns[0] for i in table.indexes if len(i.columns) == 1
}
existing_indexes = {tuple(i.columns) for i in table.indexes}
for fk in table.foreign_keys:
if fk.column not in existing_indexes:
table.create_index([fk.column])
table.create_index(fk.column)

def vacuum(self):
self.execute("VACUUM;")
Expand Down Expand Up @@ -701,21 +738,30 @@ def get(self, pk_values):

@property
def foreign_keys(self):
fks = []
fks = {}
for row in self.db.execute(
"PRAGMA foreign_key_list([{}])".format(self.name)
).fetchall():
if row is not None:
id, seq, table_name, from_, to_, on_update, on_delete, match = row
fks.append(
ForeignKey(
table=self.name,
column=from_,
other_table=table_name,
other_column=to_,
)
)
return fks
if id not in fks:
fks[id] = {
"column": [],
"other_column": [],
}
fks[id]["table"] = self.name
fks[id]["column"].append(from_)
fks[id]["other_table"] = table_name
fks[id]["other_column"].append(to_)
return [
ForeignKey(
table=fk["table"],
column=tuple(fk["column"]),
other_table=fk["other_table"],
other_column=tuple(fk["other_column"]),
)
for fk in fks.values()
]

@property
def virtual_table_using(self):
Expand Down Expand Up @@ -899,10 +945,16 @@ def transform_sql(

# foreign_keys
create_table_foreign_keys = []
for table, column, other_table, other_column in self.foreign_keys:
if (drop_foreign_keys is None) or (column not in drop_foreign_keys):
if drop_foreign_keys:
drop_foreign_keys = [
(f,) if not isinstance(f, (tuple, list)) else tuple(f)
for f in drop_foreign_keys
]
for fk in self.foreign_keys:
if (drop_foreign_keys is None) or (fk.column not in drop_foreign_keys):
column_names = tuple(rename.get(c) or c for c in fk.column)
create_table_foreign_keys.append(
(rename.get(column) or column, other_table, other_column)
(column_names, fk.other_table, fk.other_column)
)

if column_order is not None:
Expand Down Expand Up @@ -1102,6 +1154,8 @@ def drop(self):
self.db.execute("DROP TABLE [{}]".format(self.name))

def guess_foreign_table(self, column):
if isinstance(column, (tuple, list)):
column = column[0]
column = column.lower()
possibilities = [column]
if column.endswith("_id"):
Expand Down Expand Up @@ -1134,22 +1188,28 @@ def guess_foreign_column(self, other_table):
def add_foreign_key(
self, column, other_table=None, other_column=None, ignore=False
):
if not isinstance(column, (tuple, list)):
column = (column,)
# Ensure column exists
if column not in self.columns_dict:
raise AlterError("No such column: {}".format(column))
for c in column:
if c not in self.columns_dict:
raise AlterError("No such column: {}".format(c))
# If other_table is not specified, attempt to guess it from the column
if other_table is None:
other_table = self.guess_foreign_table(column)
# If other_column is not specified, detect the primary key on other_table
if other_column is None:
other_column = self.guess_foreign_column(other_table)
if not isinstance(other_column, (tuple, list)):
other_column = (other_column,)

# Soundness check that the other column exists
if (
not [c for c in self.db[other_table].columns if c.name == other_column]
and other_column != "rowid"
):
raise AlterError("No such column: {}.{}".format(other_table, other_column))
for oc in other_column:
if (
not [c for c in self.db[other_table].columns if c.name == oc]
and oc != "rowid"
):
raise AlterError("No such column: {}.{}".format(other_table, oc))
# Check we do not already have an existing foreign key
if any(
fk
Expand All @@ -1163,7 +1223,7 @@ def add_foreign_key(
else:
raise AlterError(
"Foreign key already exists for {} => {}.{}".format(
column, other_table, other_column
",".join(column), other_table, ",".join(other_column)
)
)
self.db.add_foreign_keys([(self.name, column, other_table, other_column)])
Expand Down
Loading