Skip to content

Defend against attempts to bypass JVM serial proxy #522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
7 changes: 7 additions & 0 deletions core/jvm/src/LocalDate.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ public actual class LocalDate internal constructor(
@Suppress("FunctionName")
public actual fun Format(block: DateTimeFormatBuilder.WithDate.() -> Unit): DateTimeFormat<LocalDate> =
LocalDateFormat.build(block)

// Even though this class uses writeReplace (so serialVersionUID is not needed for a stable serialized form), a
// stable serialVersionUID is useful for testing, see MaliciousJvmSerializationTest.
private const val serialVersionUID: Long = 0L
}

public actual object Formats {
Expand Down Expand Up @@ -107,6 +111,9 @@ public actual class LocalDate internal constructor(
@JvmName("toEpochDays")
internal fun toEpochDaysJvm(): Int = value.toEpochDay().clampToInt()

private fun readObject(ois: java.io.ObjectInputStream): Unit =
throw java.io.InvalidObjectException("kotlinx.datetime.LocalDate must be deserialized via kotlinx.datetime.Ser")

private fun writeReplace(): Any = Ser(Ser.DATE_TAG, this)
}

Expand Down
8 changes: 8 additions & 0 deletions core/jvm/src/LocalDateTimeJvm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,20 @@ public actual class LocalDateTime internal constructor(
@Suppress("FunctionName")
public actual fun Format(builder: DateTimeFormatBuilder.WithDateTime.() -> Unit): DateTimeFormat<LocalDateTime> =
LocalDateTimeFormat.build(builder)

// Even though this class uses writeReplace (so serialVersionUID is not needed for a stable serialized form), a
// stable serialVersionUID is useful for testing, see MaliciousJvmSerializationTest.
private const val serialVersionUID: Long = 0L
}

public actual object Formats {
public actual val ISO: DateTimeFormat<LocalDateTime> = ISO_DATETIME
}

private fun readObject(ois: java.io.ObjectInputStream): Unit = throw java.io.InvalidObjectException(
"kotlinx.datetime.LocalDateTime must be deserialized via kotlinx.datetime.Ser"
)

private fun writeReplace(): Any = Ser(Ser.DATE_TIME_TAG, this)
}

Expand Down
7 changes: 7 additions & 0 deletions core/jvm/src/LocalTimeJvm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,20 @@ public actual class LocalTime internal constructor(
@Suppress("FunctionName")
public actual fun Format(builder: DateTimeFormatBuilder.WithTime.() -> Unit): DateTimeFormat<LocalTime> =
LocalTimeFormat.build(builder)

// Even though this class uses writeReplace (so serialVersionUID is not needed for a stable serialized form), a
// stable serialVersionUID is useful for testing, see MaliciousJvmSerializationTest.
private const val serialVersionUID: Long = 0L
}

public actual object Formats {
public actual val ISO: DateTimeFormat<LocalTime> get() = ISO_TIME

}

private fun readObject(ois: java.io.ObjectInputStream): Unit =
throw java.io.InvalidObjectException("kotlinx.datetime.LocalTime must be deserialized via kotlinx.datetime.Ser")

private fun writeReplace(): Any = Ser(Ser.TIME_TAG, this)
}

Expand Down
7 changes: 7 additions & 0 deletions core/jvm/src/UtcOffsetJvm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ public actual class UtcOffset(
@Suppress("FunctionName")
public actual fun Format(block: DateTimeFormatBuilder.WithUtcOffset.() -> Unit): DateTimeFormat<UtcOffset> =
UtcOffsetFormat.build(block)

// Even though this class uses writeReplace (so serialVersionUID is not needed for a stable serialized form), a
// stable serialVersionUID is useful for testing, see MaliciousJvmSerializationTest.
private const val serialVersionUID: Long = 0L
}

public actual object Formats {
Expand All @@ -48,6 +52,9 @@ public actual class UtcOffset(
public actual val FOUR_DIGITS: DateTimeFormat<UtcOffset> get() = FOUR_DIGIT_OFFSET
}

private fun readObject(ois: java.io.ObjectInputStream): Unit =
throw java.io.InvalidObjectException("kotlinx.datetime.UtcOffset must be deserialized via kotlinx.datetime.Ser")

private fun writeReplace(): Any = Ser(Ser.UTC_OFFSET_TAG, this)
}

Expand Down
193 changes: 193 additions & 0 deletions core/jvm/test/MaliciousJvmSerializationTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Copyright 2019-2025 JetBrains s.r.o. and contributors.
* Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file.
*/

package kotlinx.datetime.test

import kotlinx.datetime.test.MaliciousJvmSerializationTest.TestCase.Streams
import java.io.ByteArrayInputStream
import java.io.ObjectInputStream
import java.io.ObjectStreamClass
import java.io.Serializable
import kotlin.reflect.KClass
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.fail

class MaliciousJvmSerializationTest {

/**
* This data was generated by running the following Java code (`X` was replaced with [clazz]`.simpleName`, `Y` with
* [delegate]`::class.qualifiedName` and `z` with [delegateFieldName]):
* ```java
* package kotlinx.datetime;
*
* import java.io.*;
* import java.util.*;
*
* public class X implements Serializable {
* private final Y z = ...;
*
* @Serial
* private static final long serialVersionUID = ...;
*
* public static void main(String[] args) throws IOException {
* var bos = new ByteArrayOutputStream();
* try (var oos = new ObjectOutputStream(bos)) {
* oos.writeObject(new X());
* }
* System.out.println(HexFormat.of().formatHex(bos.toByteArray()));
* }
* }
* ```
*/
private class TestCase(
val clazz: KClass<out Serializable>,
val delegateFieldName: String,
val delegate: Serializable,
/** `serialVersionUID` was set to the correct value (`0L`) in the Java code. */
val withCorrectSVUID: Streams,
/** `serialVersionUID` was set to an incorrect value (`42L`) in the Java code. */
val withIncorrectSVUID: Streams,
) {
class Streams(
/** `z` was set to `null` in the Java code. */
val delegateNull: String,
/** `z` was set to [delegate] in the Java code. */
val delegateValid: String,
)
}

private val testCases = listOf(
TestCase(
kotlinx.datetime.LocalDate::class,
delegateFieldName = "value",
delegate = java.time.LocalDate.of(2025, 4, 26),
withCorrectSVUID = Streams(
delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c4461746500000000000000000200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770703000007e9041a78",
delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c4461746500000000000000000200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b787070",
),
withIncorrectSVUID = Streams(
delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c44617465000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770703000007e9041a78",
delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c44617465000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b787070",
),
),
TestCase(
kotlinx.datetime.LocalDateTime::class,
delegateFieldName = "value",
delegate = java.time.LocalDateTime.of(2025, 4, 26, 11, 18),
withCorrectSVUID = Streams(
delegateValid = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d6500000000000000000200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770905000007e9041a0bed78",
delegateNull = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d6500000000000000000200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b787070",
),
withIncorrectSVUID = Streams(
delegateValid = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d65000000000000002a0200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770905000007e9041a0bed78",
delegateNull = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d65000000000000002a0200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b787070",
),
),
TestCase(
kotlinx.datetime.LocalTime::class,
delegateFieldName = "value",
delegate = java.time.LocalTime.of(11, 18),
withCorrectSVUID = Streams(
delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d6500000000000000000200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707703040bed78",
delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d6500000000000000000200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b787070",
),
withIncorrectSVUID = Streams(
delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d65000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707703040bed78",
delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d65000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b787070",
),
),
TestCase(
kotlinx.datetime.UtcOffset::class,
delegateFieldName = "zoneOffset",
delegate = java.time.ZoneOffset.UTC,
withCorrectSVUID = Streams(
delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f666673657400000000000000000200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707702080078",
delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f666673657400000000000000000200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b787070",
),
withIncorrectSVUID = Streams(
delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f6666736574000000000000002a0200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707702080078",
delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f6666736574000000000000002a0200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b787070",
),
),
)

@OptIn(ExperimentalStdlibApi::class)
private fun deserialize(stream: String): Any? {
val bis = ByteArrayInputStream(stream.hexToByteArray())
return ObjectInputStream(bis).use { ois ->
ois.readObject()
}
}

@Test
fun deserializeMaliciousStreams() {
for (testCase in testCases) {
testCase.ensureAssumptionsHold()
val className = testCase.clazz.qualifiedName!!
testStreamsWithCorrectSVUID(className, testCase.withCorrectSVUID)
testStreamsWithIncorrectSVUID(className, testCase.withIncorrectSVUID)
}
}

private fun TestCase.ensureAssumptionsHold() {
val className = clazz.qualifiedName!!
val objectStreamClass = ObjectStreamClass.lookup(clazz.java)

val actualSerialVersionUID = objectStreamClass.serialVersionUID
if (actualSerialVersionUID != 0L) {
fail("This test assumes that the serialVersionUID of $className is 0, but it was $actualSerialVersionUID.")
}

val field = objectStreamClass.fields.singleOrNull()
if (field == null || field.name != delegateFieldName || field.type != delegate.javaClass) {
fail(
"This test assumes that $className has a single serializable field named '$delegateFieldName' of " +
"type ${delegate::class.qualifiedName}. The test case for $className should be updated with new " +
"malicious serial streams that represent the changes to $className."
)
}
}

private fun testStreamsWithCorrectSVUID(className: String, streams: Streams) {
val testFailureMessage = "Deserialization of a serial stream that tries to bypass kotlinx.datetime.Ser and " +
"has the correct serialVersionUID for $className should fail"

val expectedIOEMessage = "$className must be deserialized via kotlinx.datetime.Ser"

// this would actually create a valid instance, but serialization should always go through the proxy
val ioe1 = assertFailsWith<java.io.InvalidObjectException>(testFailureMessage) {
deserialize(streams.delegateValid)
}
assertEquals(expectedIOEMessage, ioe1.message)

// this would create an instance that has null in a non-nullable field (e.g., the field
// kotlinx.datetime.LocalDate.value)
// see https://github.com/Kotlin/kotlinx-datetime/pull/373#discussion_r2008922681
val ioe2 = assertFailsWith<java.io.InvalidObjectException>(testFailureMessage) {
deserialize(streams.delegateNull)
}
assertEquals(expectedIOEMessage, ioe2.message)
}

private fun testStreamsWithIncorrectSVUID(className: String, streams: Streams) {
val testFailureMessage = "Deserialization of a serial stream that tries to bypass kotlinx.datetime.Ser but " +
"has a wrong serialVersionUID for $className should fail"

val expectedICEMessage = "$className; local class incompatible: stream classdesc serialVersionUID = 42, " +
"local class serialVersionUID = 0"

val ice1 = assertFailsWith<java.io.InvalidClassException>(testFailureMessage) {
deserialize(streams.delegateValid)
}
assertEquals(expectedICEMessage, ice1.message)

val ice2 = assertFailsWith<java.io.InvalidClassException>(testFailureMessage) {
deserialize(streams.delegateNull)
}
assertEquals(expectedICEMessage, ice2.message)
}
}