
Commit 42c8371

[Spark] Add invariant checks to DML commands (#4486)
#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

This PR adds invariant checks to DELETE, UPDATE, and MERGE to detect bugs in Spark and Delta, and to prevent these DML statements from committing if they were affected by such a bug. The invariants come in two flavors:

- Unreliable checks that combine the commit stats (i.e., the number of rows in the files added and removed) with the SQL metrics. These checks are disabled by default, because Spark can overcount metrics when there are task or stage retries.
- Reliable checks based purely on the commit stats. These checks are enabled by default, but cannot detect every occurrence of a bug. (A minimal sketch of this flavor follows the description below.)

## How was this patch tested?

- Manually ran the existing DML tests with all invariant checks enabled by default, to confirm that the checks do not cause any issues.
- Manually ran the existing DML tests with all invariant checks enabled by default and a bug deliberately introduced, to confirm that the checks trigger.
- Added `DeltaCommandInvariantsSuite`.

## Does this PR introduce _any_ user-facing changes?

No
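A rough sketch of the reliable flavor for DELETE, reduced to its core (an editorial illustration, not the actual implementation; the real check lives in `DeleteCommand.validateNumRecords` below and gets its counts from `NumRecordsStats`):

```scala
// Minimal sketch, assuming both counts come from per-file statistics in the commit
// (AddFile/RemoveFile record counts). A DELETE must never end up with more rows than
// it started with; anything else indicates a bug somewhere in the plan execution.
def reliableDeleteCheck(numAddedRecords: Long, numRemovedRecords: Long): Unit = {
  require(
    numAddedRecords <= numRemovedRecords,
    s"DELETE added $numAddedRecords records but only removed $numRemovedRecords")
}
```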
1 parent b7ff92a commit 42c8371

File tree: 12 files changed (+796, -18 lines)

#### spark/src/main/resources/error/delta-error-classes.json

Lines changed: 16 additions & 0 deletions

```diff
@@ -546,6 +546,14 @@
     ],
     "sqlState" : "2200G"
   },
+  "DELTA_COMMAND_INVARIANT_VIOLATION" : {
+    "message" : [
+      "A command internal invariant was violated in '<operation>'.",
+      "Please retry the command.",
+      "Exception reference: <uuid>."
+    ],
+    "sqlState" : "XXKDS"
+  },
   "DELTA_COMMIT_INTERMEDIATE_REDIRECT_STATE" : {
     "message" : [
       "Cannot handle commit of table within redirect table state '<state>'."
@@ -1994,6 +2002,14 @@
     ],
     "sqlState" : "42P18"
   },
+  "DELTA_NUM_RECORDS_MISMATCH" : {
+    "message" : [
+      "Failed to validate the number of records in <operation>.",
+      "Added <numAddedRecords> records and removed <numRemovedRecords> records.",
+      "This is a bug."
+    ],
+    "sqlState" : "XXKDS"
+  },
   "DELTA_ONEOF_IN_TIMETRAVEL" : {
     "message" : [
       "Please either provide 'timestampAsOf' or 'versionAsOf' for time travel."
```

#### spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala

Lines changed: 20 additions & 1 deletion

```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql.delta
 // scalastyle:off import.ordering.noEmptyLine
 import java.io.{FileNotFoundException, IOException}
 import java.nio.file.FileAlreadyExistsException
-import java.util.ConcurrentModificationException
+import java.util.{ConcurrentModificationException, UUID}
 
 import scala.collection.JavaConverters._
 
@@ -3728,6 +3728,25 @@ trait DeltaErrorsBase
     errorClass = "DELTA_UNSUPPORTED_CATALOG_OWNED_TABLE_CREATION",
     messageParameters = Array.empty)
   }
+
+  def numRecordsMismatch(
+      operation: String,
+      numAddedRecords: Long,
+      numRemovedRecords: Long): Throwable = {
+    new DeltaIllegalStateException(
+      errorClass = "DELTA_NUM_RECORDS_MISMATCH",
+      messageParameters = Array(operation, numAddedRecords.toString, numRemovedRecords.toString)
+    )
+  }
+
+  def commandInvariantViolationException(
+      operation: String,
+      id: UUID): Throwable = {
+    new DeltaIllegalStateException(
+      errorClass = "DELTA_COMMAND_INVARIANT_VIOLATION",
+      messageParameters = Array(operation, id.toString)
+    )
+  }
 }
 
 object DeltaErrors extends DeltaErrorsBase
```
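For orientation, a hedged usage sketch of the two new helpers (all parameter values are invented; callers in this PR supply their own operation names and exception references):

```scala
import java.util.UUID

// Hypothetical call sites, for illustration only (values invented).
val mismatch = DeltaErrors.numRecordsMismatch(
  operation = "DELETE", numAddedRecords = 12L, numRemovedRecords = 10L)
// message: "Failed to validate the number of records in DELETE.
//           Added 12 records and removed 10 records. This is a bug."

val invariantViolation = DeltaErrors.commandInvariantViolationException(
  operation = "DELETE", id = UUID.randomUUID())
// message: "A command internal invariant was violated in 'DELETE'.
//           Please retry the command. Exception reference: <uuid>."
```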

#### spark/src/main/scala/org/apache/spark/sql/delta/NumRecordsStats.scala

Lines changed: 19 additions & 2 deletions

```diff
@@ -30,7 +30,8 @@ case class NumRecordsStats (
     numDeletionVectorRecordsAdded: Long,
     numDeletionVectorRecordsRemoved: Long,
     numFilesAddedWithoutNumRecords: Long,
-    numFilesRemovedWithoutNumRecords: Long) {
+    numFilesRemovedWithoutNumRecords: Long,
+    numLogicalRecordsAddedInFilesWithDeletionVectorsPartial: Long) {
 
   def allFilesHaveNumRecords: Boolean =
     numFilesAddedWithoutNumRecords == 0 && numFilesRemovedWithoutNumRecords == 0
@@ -48,6 +49,14 @@
    */
   def numLogicalRecordsRemoved: Option[Long] = Option.when(numFilesRemovedWithoutNumRecords == 0)(
     numLogicalRecordsRemovedPartial)
+
+  /**
+   * The number of logical records in all AddFile actions that have a deletion vector, or None
+   * if any file does not contain statistics.
+   */
+  def numLogicalRecordsAddedInFilesWithDeletionVectors: Option[Long] =
+    Option.when(numFilesAddedWithoutNumRecords == 0)(
+      numLogicalRecordsAddedInFilesWithDeletionVectorsPartial)
 }
 
 object NumRecordsStats {
@@ -60,6 +69,7 @@ object NumRecordsStats {
     var numLogicalRecordsRemovedPartial: Long = 0L
     var numDeletionVectorRecordsAdded = 0L
    var numDeletionVectorRecordsRemoved = 0L
+    var numLogicalRecordsAddedInFilesWithDeletionVectorsPartial = 0L
 
     actions.foreach {
       case a: AddFile =>
@@ -69,6 +79,10 @@ object NumRecordsStats {
           0L
         }
         numDeletionVectorRecordsAdded += a.numDeletedRecords
+        if (a.deletionVector != null) {
+          numLogicalRecordsAddedInFilesWithDeletionVectorsPartial +=
+            a.numLogicalRecords.getOrElse(0L)
+        }
       case r: RemoveFile =>
         numFilesRemoved += 1
         numLogicalRecordsRemovedPartial += r.numLogicalRecords.getOrElse {
@@ -85,6 +99,9 @@ object NumRecordsStats {
       numDeletionVectorRecordsAdded = numDeletionVectorRecordsAdded,
       numDeletionVectorRecordsRemoved = numDeletionVectorRecordsRemoved,
       numFilesAddedWithoutNumRecords = numFilesAddedWithoutNumRecords,
-      numFilesRemovedWithoutNumRecords = numFilesRemovedWithoutNumRecords)
+      numFilesRemovedWithoutNumRecords = numFilesRemovedWithoutNumRecords,
+      numLogicalRecordsAddedInFilesWithDeletionVectorsPartial =
+        numLogicalRecordsAddedInFilesWithDeletionVectorsPartial
+    )
   }
 }
```
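The `*Partial` naming is worth a note: the counters sum record counts over only the files that carry statistics, and the public accessors expose a total only when no file was missing `numRecords`. A standalone sketch of that collapse-to-`None` behavior, using simplified stand-ins rather than the real Delta action classes:

```scala
// Simplified stand-in for an AddFile: a record count from stats (possibly absent)
// plus whether the file carries a deletion vector. Not the real Delta types.
case class FileStats(numLogicalRecords: Option[Long], hasDeletionVector: Boolean)

// Sum logical records over DV-carrying files, but only trust the total if every
// file (DV or not) actually had statistics, mirroring
// Option.when(numFilesAddedWithoutNumRecords == 0)(partialSum) in NumRecordsStats.
def dvRecordsAdded(files: Seq[FileStats]): Option[Long] = {
  val filesWithoutNumRecords = files.count(_.numLogicalRecords.isEmpty)
  val partialSum = files
    .filter(_.hasDeletionVector)
    .flatMap(_.numLogicalRecords)
    .sum
  Option.when(filesWithoutNumRecords == 0)(partialSum)
}

// dvRecordsAdded(Seq(FileStats(Some(10L), true), FileStats(Some(5L), false))) == Some(10L)
// dvRecordsAdded(Seq(FileStats(Some(10L), true), FileStats(None, false)))     == None
```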

#### spark/src/main/scala/org/apache/spark/sql/delta/commands/DeleteCommand.scala

Lines changed: 112 additions & 1 deletion

```diff
@@ -18,6 +18,8 @@ package org.apache.spark.sql.delta.commands
 
 import java.util.concurrent.TimeUnit
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.delta.metric.IncrementMetric
 import org.apache.spark.sql.delta._
 import org.apache.spark.sql.delta.ClassicColumnConversions._
@@ -36,6 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{DeltaDelete, LogicalPlan}
+import org.apache.spark.sql.delta.DeltaOperations.Operation
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.metric.SQLMetrics.{createMetric, createTimingMetric}
@@ -127,9 +130,12 @@ case class DeleteCommand(
     }
 
     val (deleteActions, deleteMetrics) = performDelete(sparkSession, deltaLog, txn)
+    val numRecordsStats = NumRecordsStats.fromActions(deleteActions)
+    val operation = DeltaOperations.Delete(condition.toSeq)
+    validateNumRecords(deleteActions, numRecordsStats, operation)
     val commitVersion = txn.commitIfNeeded(
       actions = deleteActions,
-      op = DeltaOperations.Delete(condition.toSeq),
+      op = operation,
       tags = RowTracking.addPreservedRowTrackingTagIfNotSet(txn.snapshot))
     recordDeltaEvent(
       deltaLog,
@@ -472,6 +478,111 @@ case class DeleteCommand(
     spark.conf.get(DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS) &&
       DeletionVectorUtils.deletionVectorsWritable(txn.snapshot)
   }
+
+  /**
+   * Validates that the number of records does not increase.
+   *
+   * Note: ideally we would also compare the number of added/removed rows in the statistics
+   * with the number of deleted/copied rows in the SQL metrics, but unfortunately this is not
+   * possible, as SQL metrics are not reliable when there are task or stage retries.
+   */
+  private def validateNumRecords(
+      actions: Seq[Action],
+      numRecordsStats: NumRecordsStats,
+      op: Operation): Unit = {
+    (numRecordsStats.numLogicalRecordsAdded,
+        numRecordsStats.numLogicalRecordsRemoved,
+        numRecordsStats.numLogicalRecordsAddedInFilesWithDeletionVectors) match {
+      case (
+          Some(numAddedRecords),
+          Some(numRemovedRecords),
+          Some(numRecordsNotCopied)) =>
+        if (numAddedRecords > numRemovedRecords) {
+          logNumRecordsMismatch(deltaLog, actions, numRecordsStats, op)
+          if (conf.getConf(DeltaSQLConf.NUM_RECORDS_VALIDATION_ENABLED)) {
+            throw DeltaErrors.numRecordsMismatch(
+              operation = "DELETE",
+              numAddedRecords,
+              numRemovedRecords
+            )
+          }
+        }
+
+        if (conf.getConf(DeltaSQLConf.COMMAND_INVARIANT_CHECKS_USE_UNRELIABLE)) {
+          // Additionally validate using the regular (unreliable) SQL metrics.
+          validateMetricBasedCommandInvariants(
+            numAddedRecords, numRemovedRecords, numRecordsNotCopied, op, deltaLog)
+        }
+
+      case _ =>
+        recordDeltaEvent(deltaLog, opType = "delta.assertions.statsNotPresentForNumRecordsCheck")
+        logWarning(log"Could not validate number of records due to missing statistics.")
+    }
+  }
+
+  private def validateMetricBasedCommandInvariants(
+      numAddedRecords: Long,
+      numRemovedRecords: Long,
+      numRecordsNotCopied: Long,
+      op: Operation,
+      deltaLog: DeltaLog): Unit = try {
+
+    val numRowsDeleted = CommandInvariantMetricValueFromSingle(metrics("numDeletedRows"))
+    val numRowsCopied = CommandInvariantMetricValueFromSingle(metrics("numCopiedRows"))
+
+    val recordMetricsFromMetadata = conf.getConf(DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA)
+    if (numRowsDeleted.getOrDummy == 0 && !recordMetricsFromMetadata) {
+      // If we don't record metrics we can't use them to perform invariant checks.
+      return
+    }
+
+    checkCommandInvariant(
+      invariant = () =>
+        numRowsDeleted.getOrThrow + numRowsCopied.getOrThrow + numRecordsNotCopied
+          == numRemovedRecords,
+      label = "numRowsDeleted + numRowsCopied + numRecordsNotCopied + " +
+        "numRowsRemovedByMetadataOnlyDelete == numRemovedRecords",
+      op = op,
+      deltaLog = deltaLog,
+      parameters = Map(
+        "numRowsDeleted" -> numRowsDeleted.getOrDummy,
+        "numRowsCopied" -> numRowsCopied.getOrDummy,
+        "numRemovedRecords" -> numRemovedRecords,
+        "numRecordsNotCopied" -> numRecordsNotCopied
+      ),
+      additionalInfo = Map(
+        DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA.key -> recordMetricsFromMetadata.toString
+      )
+    )
+
+    checkCommandInvariant(
+      invariant = () => numRowsCopied.getOrThrow + numRecordsNotCopied == numAddedRecords,
+      label = "numRowsCopied + numRecordsNotCopied == numAddedRecords",
+      op = op,
+      deltaLog = deltaLog,
+      parameters = Map(
+        "numRowsCopied" -> numRowsCopied.getOrDummy,
+        "numAddedRecords" -> numAddedRecords,
+        "numRecordsNotCopied" -> numRecordsNotCopied
+      ),
+      additionalInfo = Map(
+        DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA.key -> recordMetricsFromMetadata.toString
+      )
+    )
+  } catch {
+    // Immediately re-throw actual command invariant violations, so we don't re-wrap them below.
+    case e: DeltaIllegalStateException if e.getErrorClass == "DELTA_COMMAND_INVARIANT_VIOLATION" =>
+      throw e
+    case NonFatal(e) =>
+      logWarning(log"Unexpected error in validateMetricBasedCommandInvariants", e)
+      checkCommandInvariant(
+        invariant = () => false,
+        label = "Unexpected error in validateMetricBasedCommandInvariants",
+        op = op,
+        deltaLog = deltaLog,
+        parameters = Map.empty
+      )
+  }
 }
 
 object DeleteCommand {
```
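To make the two metric-based invariants concrete, here is a worked example with invented numbers for a DELETE that rewrites a single 10-row file, deleting 3 rows without deletion vectors:

```scala
// Invented numbers, illustration only.
val numRemovedRecords   = 10L // rows in the RemoveFile, from commit stats
val numAddedRecords     = 7L  // rows in the rewritten AddFile, from commit stats
val numRowsDeleted      = 3L  // SQL metric "numDeletedRows"
val numRowsCopied       = 7L  // SQL metric "numCopiedRows"
val numRecordsNotCopied = 0L  // no AddFile carries a deletion vector here

// First invariant: every removed row is accounted for as deleted, copied, or kept in place.
assert(numRowsDeleted + numRowsCopied + numRecordsNotCopied == numRemovedRecords)
// Second invariant: every added row was either copied or kept behind a deletion vector.
assert(numRowsCopied + numRecordsNotCopied == numAddedRecords)
// The stats-only "reliable" check from validateNumRecords.
assert(numAddedRecords <= numRemovedRecords)
```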

#### spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaCommand.scala

Lines changed: 50 additions & 2 deletions

```diff
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaErrors, DeltaLog, DeltaOptions, DeltaTableIdentifier, DeltaTableUtils, OptimisticTransaction, ResolvedPathBasedNonDeltaTable}
+import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaErrors, DeltaLog, DeltaOptions, DeltaTableIdentifier, DeltaTableUtils, NumRecordsStats, OptimisticTransaction, ResolvedPathBasedNonDeltaTable}
 import org.apache.spark.sql.delta.actions._
 import org.apache.spark.sql.delta.catalog.{DeltaTableV2, IcebergTablePlaceHolder}
 import org.apache.spark.sql.delta.files.TahoeBatchFileIndex
@@ -40,6 +40,8 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.V1Table
+import org.apache.spark.sql.delta.DeltaOperations.Operation
+import org.apache.spark.sql.delta.sources.DeltaSQLConf.DELTA_COLLECT_STATS
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelationWithTable}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -48,7 +50,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 /**
  * Helper trait for all delta commands.
  */
-trait DeltaCommand extends DeltaLogging {
+trait DeltaCommand extends DeltaLogging with DeltaCommandInvariants {
   /**
    * Converts string predicates into [[Expression]]s relative to a transaction.
    *
@@ -440,4 +442,50 @@ trait DeltaCommand extends DeltaLogging {
     (txnVersion, txnAppId, fromSessionConf)
   }
 
+  protected def logNumRecordsMismatch(
+      deltaLog: DeltaLog,
+      actions: Seq[Action],
+      stats: NumRecordsStats,
+      op: Operation): Unit = {
+
+    var numRemove = 0
+    var numAdd = 0
+    actions.foreach {
+      case _: AddFile =>
+        numAdd += 1
+      case _: RemoveFile =>
+        numRemove += 1
+      case _ =>
+    }
+
+    val info = NumRecordsCheckInfo(
+      operation = op.name,
+      numAdd = numAdd,
+      numRemove = numRemove,
+      numRecordsRemoved = stats.numLogicalRecordsRemovedPartial,
+      numRecordsAdded = stats.numLogicalRecordsAddedPartial,
+      numDeletionVectorRecordsRemoved = stats.numDeletionVectorRecordsRemoved,
+      numDeletionVectorRecords = stats.numDeletionVectorRecordsAdded,
+      operationParameters = op.jsonEncodedValues,
+      statsCollectionEnabled = SparkSession.getActiveSession.get.conf.get(DELTA_COLLECT_STATS)
+    )
+    recordDeltaEvent(deltaLog, opType = "delta.assertions.numRecordsChanged", data = info)
+    logWarning(log"Number of records validation failed. Number of added records" +
+      log" (${MDC(DeltaLogKeys.NUM_RECORDS, stats.numLogicalRecordsAddedPartial)})" +
+      log" does not match removed records" +
+      log" (${MDC(DeltaLogKeys.NUM_RECORDS2, stats.numLogicalRecordsRemovedPartial)})")
+  }
 }
+
+// Recorded when the number-of-records check for unchanged data fails.
+case class NumRecordsCheckInfo(
+    operation: String,
+    numAdd: Int,
+    numRemove: Int,
+    numRecordsAdded: Long,
+    numRecordsRemoved: Long,
+    numDeletionVectorRecordsRemoved: Long = 0, // number of DV records removed by the RemoveFiles
+    numDeletionVectorRecords: Long = 0, // number of DV records present in all AddFiles
+    operationParameters: Map[String, String],
+    statsCollectionEnabled: Boolean
+)
```
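For reference, a hedged example of the payload recorded under `delta.assertions.numRecordsChanged` when a DELETE somehow adds more rows than it removes (all values invented, including the hypothetical predicate):

```scala
// Invented values, illustration only; operationParameters mirrors the operation's
// JSON-encoded parameters (here, a hypothetical DELETE predicate).
val info = NumRecordsCheckInfo(
  operation = "DELETE",
  numAdd = 2,
  numRemove = 1,
  numRecordsAdded = 12L,
  numRecordsRemoved = 10L,
  operationParameters = Map("predicate" -> """["id > 5"]"""),
  statsCollectionEnabled = true)
```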
