Skip to content

Catches when data prep cluster fails to start #1628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def fetch_DT(

formatted_delta_table_name = format_tablename(delta_table_name)
import grpc
import pyspark.errors.exceptions.connect as spark_errors
try:
fetch(
method,
Expand All @@ -702,8 +703,16 @@ def fetch_DT(
sparkSession,
dbsql,
)
except grpc.RpcError as e:
if e.code(
except (grpc.RpcError, spark_errors.SparkConnectGrpcException) as e:
if isinstance(
e,
spark_errors.SparkConnectGrpcException,
) and 'Cannot start cluster' in str(e):
raise FaultyDataPrepCluster(
message=
f'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive. {e}',
) from e
if isinstance(e, grpc.RpcError) and e.code(
) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details(
):
raise FaultyDataPrepCluster(
Expand Down
53 changes: 53 additions & 0 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import grpc
from pyspark.errors import AnalysisException
from pyspark.errors.exceptions.connect import SparkConnectGrpcException

from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
FaultyDataPrepCluster,
Expand Down Expand Up @@ -584,6 +585,58 @@ def test_fetch_DT_grpc_error_handling(
# Verify that fetch was called
mock_fetch.assert_called_once()

@patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch')
@patch(
    'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
)
def test_fetch_DT_catches_cluster_failed_to_start(
    self,
    mock_validate_cluster_info: MagicMock,
    mock_fetch: MagicMock,
):
    """fetch_DT re-raises a cluster-start failure as FaultyDataPrepCluster.

    When the underlying fetch raises a SparkConnectGrpcException whose
    message says the cluster cannot start, fetch_DT should surface it as
    a FaultyDataPrepCluster with actionable guidance for the user.
    """
    # Pretend cluster validation succeeded on a dbconnect-backed cluster.
    mock_validate_cluster_info.return_value = ('dbconnect', None, None)

    # Make fetch fail the way Spark Connect does when the data prep
    # cluster cannot be started.
    mock_fetch.side_effect = SparkConnectGrpcException(
        message='Cannot start cluster etc...',
    )

    # Act & Assert: the gRPC failure must be translated, not propagated.
    with self.assertRaises(FaultyDataPrepCluster) as context:
        fetch_DT(
            delta_table_name='test_table',
            json_output_folder='/tmp/to/jsonl',
            http_path=None,
            cluster_id=None,
            use_serverless=False,
            DATABRICKS_HOST='https://test-host',
            DATABRICKS_TOKEN='test-token',
        )

    # The wrapped exception carries the user-facing retry guidance.
    self.assertIn(
        'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
        str(context.exception),
    )

    # fetch was attempted exactly once before the failure was wrapped.
    mock_fetch.assert_called_once()

@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.get_total_rows',
)
Expand Down
Loading