Skip to content

Commit c786def

Browse files
KuuCiv-chen_data
andauthored
Add proper user error for accessing schema (#1548)
Co-authored-by: v-chen_data <[email protected]>
1 parent 722526d commit c786def

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

llmfoundry/command_utils/data_prep/convert_delta_to_json.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,27 @@ def run_query(
233233
elif method == 'dbconnect':
234234
if spark == None:
235235
raise ValueError(f'sparkSession is required for dbconnect')
236-
df = spark.sql(query)
236+
237+
try:
238+
df = spark.sql(query)
239+
except Exception as e:
240+
from pyspark.errors import AnalysisException
241+
if isinstance(e, AnalysisException):
242+
if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore
243+
match = re.search(
244+
r"Schema\s+'([^']+)'",
245+
e.message, # pyright: ignore
246+
)
247+
if match:
248+
schema_name = match.group(1)
249+
action = f'using the schema {schema_name}'
250+
else:
251+
action = 'using the schema'
252+
raise InsufficientPermissionsError(action=action,) from e
253+
raise RuntimeError(
254+
f'Error in querying into schema. Restart sparkSession and try again',
255+
) from e
256+
237257
if collect:
238258
return df.collect()
239259
return df
@@ -461,6 +481,8 @@ def fetch(
461481
raise InsufficientPermissionsError(
462482
action=f'reading from {tablename}',
463483
) from e
484+
if isinstance(e, InsufficientPermissionsError):
485+
raise e
464486
raise RuntimeError(
465487
f'Error in get rows from {tablename}. Restart sparkSession and try again',
466488
) from e

tests/a_scripts/data_prep/test_convert_delta_to_json.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# Copyright 2022 MosaicML LLM Foundry authors
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import sys
45
import unittest
56
from argparse import Namespace
67
from typing import Any
78
from unittest.mock import MagicMock, mock_open, patch
89

910
from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
11+
InsufficientPermissionsError,
1012
download,
1113
fetch_DT,
1214
format_tablename,
@@ -17,6 +19,39 @@
1719

1820
class TestConvertDeltaToJsonl(unittest.TestCase):
1921

22+
def test_run_query_dbconnect_insufficient_permissions(self):
23+
error_message = (
24+
'[INSUFFICIENT_PERMISSIONS] Insufficient privileges: User does not have USE SCHEMA '
25+
"on Schema 'main.oogabooga'. SQLSTATE: 42501"
26+
)
27+
28+
class MockAnalysisException(Exception):
29+
30+
def __init__(self, message: str):
31+
self.message = message
32+
33+
with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}):
34+
sys.modules[
35+
'pyspark.errors'
36+
].AnalysisException = MockAnalysisException # pyright: ignore
37+
38+
mock_spark = MagicMock()
39+
mock_spark.sql.side_effect = MockAnalysisException(error_message)
40+
41+
with self.assertRaises(InsufficientPermissionsError) as context:
42+
run_query(
43+
'SELECT * FROM table',
44+
method='dbconnect',
45+
cursor=None,
46+
spark=mock_spark,
47+
)
48+
49+
self.assertIn(
50+
'using the schema main.oogabooga',
51+
str(context.exception),
52+
)
53+
mock_spark.sql.assert_called_once_with('SELECT * FROM table')
54+
2055
@patch(
2156
'databricks.sql.connect',
2257
)

0 commit comments

Comments
 (0)