Skip to content

Commit c951148

Browse files
authored
Merge pull request #298 from duckdb/jwills_add_retries_for_some_errors
Add support for retrying certain types of exceptions we see when running models with DuckDB
2 parents 36b4ec1 + 6c9dffe commit c951148

File tree

5 files changed

+185
-1
lines changed

5 files changed

+185
-1
lines changed

dbt/adapters/duckdb/credentials.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import time
33
from dataclasses import dataclass
4+
from dataclasses import field
45
from functools import lru_cache
56
from typing import Any
67
from typing import Dict
@@ -61,6 +62,20 @@ class Remote(dbtClassMixin):
6162
password: Optional[str] = None
6263

6364

65+
@dataclass
66+
class Retries(dbtClassMixin):
67+
# The number of times to attempt the initial duckdb.connect call
68+
# (to wait for another process to free the lock on the DB file)
69+
connect_attempts: int = 1
70+
71+
# The number of times to attempt to execute a DuckDB query that throws
72+
# one of the retryable exceptions
73+
query_attempts: Optional[int] = None
74+
75+
# The list of exceptions that we are willing to retry on
76+
retryable_exceptions: List[str] = field(default_factory=lambda: ["IOException"])
77+
78+
6479
@dataclass
6580
class DuckDBCredentials(Credentials):
6681
database: str = "main"
@@ -126,6 +141,11 @@ class DuckDBCredentials(Credentials):
126141
# provide helper functions for dbt Python models.
127142
module_paths: Optional[List[str]] = None
128143

144+
# An optional strategy for allowing retries when certain types of
145+
# exceptions occur on a model run (e.g., IOExceptions that were caused
146+
# by networking issues)
147+
retries: Optional[Retries] = None
148+
129149
@classmethod
130150
def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]:
131151
data = super().__pre_deserialize__(data)

dbt/adapters/duckdb/environments/__init__.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import os
44
import sys
55
import tempfile
6+
import time
67
from typing import Dict
8+
from typing import List
79
from typing import Optional
810

911
import duckdb
@@ -31,6 +33,44 @@ def _ensure_event_loop():
3133
asyncio.set_event_loop(loop)
3234

3335

36+
class RetryableCursor:
37+
def __init__(self, cursor, retry_attempts: int, retryable_exceptions: List[str]):
38+
self._cursor = cursor
39+
self._retry_attempts = retry_attempts
40+
self._retryable_exceptions = retryable_exceptions
41+
42+
def execute(self, sql: str, bindings=None):
43+
attempt, success, exc = 0, False, None
44+
while not success and attempt < self._retry_attempts:
45+
try:
46+
if bindings is None:
47+
self._cursor.execute(sql)
48+
else:
49+
self._cursor.execute(sql, bindings)
50+
success = True
51+
except Exception as e:
52+
exception_name = type(e).__name__
53+
if exception_name in self._retryable_exceptions:
54+
time.sleep(2**attempt)
55+
exc = e
56+
attempt += 1
57+
else:
58+
print(f"Did not retry exception named '{exception_name}'")
59+
raise e
60+
if not success:
61+
if exc:
62+
raise exc
63+
else:
64+
raise RuntimeError(
65+
"execute call failed, but no exceptions raised- this should be impossible"
66+
)
67+
return self
68+
69+
# forward along all non-execute() methods/attribute look-ups
70+
def __getattr__(self, name):
71+
return getattr(self._cursor, name)
72+
73+
3474
class Environment(abc.ABC):
3575
"""An Environment is an abstraction to describe *where* the code you execute in your dbt-duckdb project
3676
actually runs. This could be the local Python process that runs dbt (which is the default),
@@ -74,7 +114,32 @@ def initialize_db(
74114
cls, creds: DuckDBCredentials, plugins: Optional[Dict[str, BasePlugin]] = None
75115
):
76116
config = creds.config_options or {}
77-
conn = duckdb.connect(creds.path, read_only=False, config=config)
117+
118+
if creds.retries:
119+
success, attempt, exc = False, 0, None
120+
while not success and attempt < creds.retries.connect_attempts:
121+
try:
122+
conn = duckdb.connect(creds.path, read_only=False, config=config)
123+
success = True
124+
except Exception as e:
125+
exception_name = type(e).__name__
126+
if exception_name in creds.retries.retryable_exceptions:
127+
time.sleep(2**attempt)
128+
exc = e
129+
attempt += 1
130+
else:
131+
print(f"Did not retry exception named '{exception_name}'")
132+
raise e
133+
if not success:
134+
if exc:
135+
raise exc
136+
else:
137+
raise RuntimeError(
138+
"connect call failed, but no exceptions raised- this should be impossible"
139+
)
140+
141+
else:
142+
conn = duckdb.connect(creds.path, read_only=False, config=config)
78143

79144
# install any extensions on the connection
80145
if creds.extensions is not None:
@@ -127,6 +192,11 @@ def initialize_cursor(
127192
for df_name, df in registered_df.items():
128193
cursor.register(df_name, df)
129194

195+
if creds.retries and creds.retries.query_attempts:
196+
cursor = RetryableCursor(
197+
cursor, creds.retries.query_attempts, creds.retries.retryable_exceptions
198+
)
199+
130200
return cursor
131201

132202
@classmethod

tests/functional/plugins/test_plugins.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def profiles_config_update(self, dbt_profile_target, sqlite_test_db):
9191
"type": "duckdb",
9292
"path": dbt_profile_target.get("path", ":memory:"),
9393
"plugins": plugins,
94+
"retries": {"query_attempts": 2},
9495
}
9596
},
9697
"target": "dev",

tests/unit/test_retries_connect.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
from unittest.mock import patch
3+
4+
from duckdb.duckdb import IOException
5+
6+
from dbt.adapters.duckdb.credentials import DuckDBCredentials
7+
from dbt.adapters.duckdb.credentials import Retries
8+
from dbt.adapters.duckdb.environments import Environment
9+
10+
class TestConnectRetries:
11+
12+
@pytest.fixture
13+
def creds(self):
14+
# Create a mock credentials object
15+
return DuckDBCredentials(
16+
path="foo.db",
17+
retries=Retries(connect_attempts=2, retryable_exceptions=["IOException", "ArithmeticError"])
18+
)
19+
20+
@pytest.mark.parametrize("exception", [None, IOException, ArithmeticError, ValueError])
21+
def test_initialize_db(self, creds, exception):
22+
# Mocking the duckdb.connect method
23+
with patch('duckdb.connect') as mock_connect:
24+
if exception:
25+
mock_connect.side_effect = [exception, None]
26+
27+
if exception == ValueError:
28+
with pytest.raises(ValueError) as excinfo:
29+
Environment.initialize_db(creds)
30+
else:
31+
# Call the initialize_db method
32+
Environment.initialize_db(creds)
33+
if exception in {IOException, ArithmeticError}:
34+
assert mock_connect.call_count == creds.retries.connect_attempts
35+
else:
36+
mock_connect.assert_called_once_with(creds.path, read_only=False, config={})

tests/unit/test_retries_query.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
from unittest.mock import MagicMock
3+
from unittest.mock import patch
4+
5+
import duckdb
6+
7+
from dbt.adapters.duckdb.credentials import Retries
8+
from dbt.adapters.duckdb.environments import RetryableCursor
9+
10+
class TestRetryableCursor:
11+
12+
@pytest.fixture
13+
def mock_cursor(self):
14+
return MagicMock()
15+
16+
@pytest.fixture
17+
def mock_retries(self):
18+
return Retries(query_attempts=3)
19+
20+
@pytest.fixture
21+
def retry_cursor(self, mock_cursor, mock_retries):
22+
return RetryableCursor(
23+
mock_cursor,
24+
mock_retries.query_attempts,
25+
mock_retries.retryable_exceptions)
26+
27+
def test_successful_execute(self, mock_cursor, retry_cursor):
28+
""" Test that execute successfully runs the SQL query. """
29+
sql_query = "SELECT * FROM table"
30+
retry_cursor.execute(sql_query)
31+
mock_cursor.execute.assert_called_once_with(sql_query)
32+
33+
def test_retry_on_failure(self, mock_cursor, retry_cursor):
34+
""" Test that execute retries the SQL query on failure. """
35+
mock_cursor.execute.side_effect = [duckdb.duckdb.IOException, None]
36+
sql_query = "SELECT * FROM table"
37+
retry_cursor.execute(sql_query)
38+
assert mock_cursor.execute.call_count == 2
39+
40+
def test_no_retry_on_non_retryable_exception(self, mock_cursor, retry_cursor):
41+
""" Test that a non-retryable exception is not retried. """
42+
mock_cursor.execute.side_effect = ValueError
43+
sql_query = "SELECT * FROM table"
44+
with pytest.raises(ValueError):
45+
retry_cursor.execute(sql_query)
46+
mock_cursor.execute.assert_called_once_with(sql_query)
47+
48+
def test_exponential_backoff(self, mock_cursor, retry_cursor):
49+
""" Test that exponential backoff is applied between retries. """
50+
mock_cursor.execute.side_effect = [duckdb.duckdb.IOException, duckdb.duckdb.IOException, None]
51+
sql_query = "SELECT * FROM table"
52+
53+
with patch("time.sleep") as mock_sleep:
54+
retry_cursor.execute(sql_query)
55+
assert mock_sleep.call_count == 2
56+
mock_sleep.assert_any_call(1)
57+
mock_sleep.assert_any_call(2)

0 commit comments

Comments
 (0)