diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua new file mode 100644 index 00000000000..d21321e4bc3 --- /dev/null +++ b/test/tool/net/lcrypto_test.lua @@ -0,0 +1,756 @@ +-- Test RSA key pair generation +local function test_rsa_keypair_generation() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") +end + +-- Test ECDSA key pair generation +local function test_ecdsa_keypair_generation() + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") +end + +-- Test RSA encryption and decryption +local function test_rsa_encryption_decryption() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") + + local plaintext = "Hello, RSA!" + local ciphertext = crypto.encrypt("rsa", pub_key, plaintext) + assert(type(ciphertext) == "string", "Ciphertext type") + + local decrypted_plaintext = crypto.decrypt("rsa", priv_key, ciphertext) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") +end + +-- Test RSA signing and verification +local function test_rsa_signing_verification() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") + + local message = "Sign this message" + local signature = crypto.sign("rsa", priv_key, message, "sha256") + assert(type(signature) == "string", "Signature type") + + local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256") + assert(is_valid == true, "Signature verification") +end + +-- Test RSA-PSS signing and verification +local function test_rsapss_signing_verification() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") + + Log(kLogVerbose," - Testing RSA-PSS signing") + local message = "Sign this message with RSA-PSS" + local signature = crypto.sign("rsapss", priv_key, message, "sha256") + assert(type(signature) == "string", "Signature type") + + Log(kLogVerbose," - Testing RSA-PSS verification") + local is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha256") + assert(is_valid == true, "RSA-PSS Signature verification") + + -- Test with different hash algorithm + Log(kLogVerbose," - Testing RSA-PSS with different hash algorithms") + signature = crypto.sign("rsapss", priv_key, message, "sha384") + assert(type(signature) == "string", "SHA-384 Signature type") + is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha384") + assert(is_valid == true, "RSA-PSS SHA-384 Signature verification") + + Log(kLogVerbose," - Testing RSA-PSS with SHA-512") + signature = crypto.sign("rsapss", priv_key, message, "sha512") + assert(type(signature) == "string", "SHA-512 Signature type") + is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha512") + assert(is_valid == true, "RSA-PSS SHA-512 Signature verification") +end + +-- Test ECDSA signing and verification +local function test_ecdsa_signing_verification() + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") + + local message = "Sign this message with ECDSA" + local signature = crypto.sign("ecdsa", priv_key, message, "sha256") + assert(type(signature) == "string", "Signature type") + + local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") + assert(is_valid == true, "Signature verification") +end + +-- Test AES key generation +local function test_aes_key_generation() + local key = crypto.generateKeyPair('aes', 256) -- 256-bit key + assert(type(key) == "string", "Key type") + assert(#key == 32, "Key length (256 bits)") +end + +-- Test AES encryption and decryption (CBC mode) +local function test_aes_encryption_decryption_cbc() + local key = crypto.generateKeyPair('aes', 256) -- 256-bit key + local plaintext = "Hello, AES CBC!" + + -- Encrypt without providing IV (should auto-generate IV) + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, nil) + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") + + -- Decrypt + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") + + -- Encrypt with explicit IV + local iv2 = GetRandomBytes(16) + local ciphertext2, iv_used = crypto.encrypt("aes", key, plaintext, { mode = "cbc", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") + + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "cbc", iv = iv2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") +end + +-- Test AES encryption and decryption (CTR mode) +local function test_aes_encryption_decryption_ctr() + local key = crypto.generateKeyPair('aes', 256) + local plaintext = "Hello, AES CTR!" + + -- Encrypt without providing IV (should auto-generate IV) + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") + + -- Decrypt + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "ctr", iv = iv }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") + + -- Encrypt with explicit IV + local iv2 = GetRandomBytes(16) + local ciphertext2, iv_used = crypto.encrypt("aes", key, plaintext, { mode = "ctr", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") + + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "ctr", iv = iv2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") +end + +-- Test AES encryption and decryption (GCM mode) +local function test_aes_encryption_decryption_gcm() + local key = crypto.generateKeyPair('aes', 256) + assert(type(key) == "string", "key type") + local plaintext = "Hello, AES GCM!" + + -- Encrypt without providing IV (should auto-generate IV) + local ciphertext, iv, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + assert(#plaintext == #ciphertext, "Ciphertext length matches plaintext") + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") + assert(type(tag) == "string", "Tag type") + + -- Decrypt + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "gcm", iv = iv, tag = tag }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") + + -- Encrypt with explicit IV + local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard + local ciphertext2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, { mode = "gcm", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") + assert(type(tag2) == "string", "Tag type") + + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "gcm", iv = iv2, tag = tag2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") +end + +-- Test PemToJwk conversion +local function test_pem_to_jwk() + local priv_key, pub_key = crypto.generateKeyPair() + local priv_jwk = crypto.convertPemToJwk(priv_key) + assert(type(priv_jwk) == "table", "JWK type") + assert(priv_jwk.kty == "RSA", "kty is correct") + + local pub_jwk = crypto.convertPemToJwk(pub_key) + assert(type(pub_jwk) == "table", "JWK type") + assert(pub_jwk.kty == "RSA", "kty is correct") + + -- Test ECDSA keys + priv_key, pub_key = crypto.generateKeyPair('ecdsa') + priv_jwk = crypto.convertPemToJwk(priv_key) + assert(type(priv_jwk) == "table", "JWK type") + assert(priv_jwk.kty == "EC", "kty is correct") + + pub_jwk = crypto.convertPemToJwk(pub_key) + assert(type(pub_jwk) == "table", "JWK type") + assert(pub_jwk.kty == "EC", "kty is correct") +end + +-- Test JwkToPem conversion +local function test_jwk_to_pem() + local priv_key, pub_key = crypto.generateKeyPair() + local priv_jwk = crypto.convertPemToJwk(priv_key) + local pub_jwk = crypto.convertPemToJwk(pub_key) + + local priv_pem = crypto.convertJwkToPem(priv_jwk) + local pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(priv_pem) == "string", "Private PEM type") + + -- Roundtrip + assert(priv_key == priv_pem, "Private PEM matches original RSA key") + assert(pub_key == pub_pem, "Public PEM matches original RSA key") + + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(pub_pem) == "string", "Public PEM type") + + -- Test ECDSA keys + priv_key, pub_key = crypto.generateKeyPair('ecdsa') + priv_jwk = crypto.convertPemToJwk(priv_key) + pub_jwk = crypto.convertPemToJwk(pub_key) + + priv_pem = crypto.convertJwkToPem(priv_jwk) + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(priv_pem) == "string", "Private PEM type for ECDSA") + + -- Roundtrip + assert(priv_key == priv_pem, "Private PEM matches original ECDSA key") + assert(pub_key == pub_pem, "Public PEM matches original ECDSA key") + + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(pub_pem) == "string", "Public PEM type for ECDSA") +end + +-- Test CSR generation +local function test_csr_generation() + local priv_key, _ = crypto.generateKeyPair() + local subject_name = "CN=example.com,O=Example Org,C=US" + local san = "DNS:example.com, DNS:www.example.com, IP:192.168.1.1" + assert(type(priv_key) == "string", "Private key type") + + local csr = crypto.generateCsr(priv_key, subject_name) + assert(type(csr) == "string", "CSR generation with subject name") + + csr = crypto.generateCsr(priv_key, subject_name, san) + assert(type(csr) == "string", "CSR generation with subject name and san") + + csr = crypto.generateCsr(priv_key, nil, san) + assert(type(csr) == "string", "CSR generation with nil subject name and san") + + csr = crypto.generateCsr(priv_key, '', san) + assert(type(csr) == "string", "CSR generation with empty subject name and san") + + -- These should fail + csr = crypto.generateCsr(priv_key, '') + assert(type(csr) ~= "string", "CSR generation with empty subject name and no san is rejected") + + csr = crypto.generateCsr(priv_key) + assert(type(csr) ~= "string", "CSR generation with nil subject name and no san is rejected") +end + +-- Test various hash algorithms +local function test_hash_algorithms() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local message = "Test message for hash algorithms" + + -- Test different hash algorithms for RSA signatures + local hash_algorithms = { "sha256", "sha384", "sha512" } + for _, hash in ipairs(hash_algorithms) do + local signature = crypto.sign("rsa", priv_key, message, hash) + assert(type(signature) == "string", "RSA signature with " .. hash) + local is_valid = crypto.verify("rsa", pub_key, message, signature, hash) + assert(is_valid == true, "RSA verification with " .. hash) + + -- Test with RSA-PSS + local signature_pss = crypto.sign("rsapss", priv_key, message, hash) + assert(type(signature_pss) == "string", "RSA-PSS signature with " .. hash) + local is_valid_pss = crypto.verify("rsapss", pub_key, message, signature_pss, hash) + assert(is_valid_pss == true, "RSA-PSS verification with " .. hash) + end + + -- Test ECDSA with different hash algorithms + local ec_priv_key, ec_pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") + for _, hash in ipairs(hash_algorithms) do + local signature = crypto.sign("ecdsa", ec_priv_key, message, hash) + assert(type(signature) == "string", "ECDSA signature with " .. hash) + + local is_valid = crypto.verify("ecdsa", ec_pub_key, message, signature, hash) + assert(is_valid == true, "ECDSA verification with " .. hash) + end +end + +-- Test negative cases for hash algorithms +local function test_negative_hash_algorithms() + local priv_key, pub_key = crypto.generateKeyPair() + local message = "Test message for hash algorithms" + + -- Test with invalid hash algorithm + local ok = pcall(function() return crypto.sign("rsa", priv_key, message, "invalid-hash") end) + assert(ok == false, "Sign with invalid hash should fail") + + -- Test with nil hash algorithm (should default to SHA-256) + local signature = crypto.sign("rsa", priv_key, message) + assert(type(signature) == "string", "Sign with nil hash") + + local is_valid = crypto.verify("rsa", pub_key, message, signature) + assert(is_valid == true, "Verify with nil hash") +end + +-- Negative tests for crypto functions +local function test_negative_keypair_generation() + -- Invalid algorithm + local ok = pcall(function() return crypto.generateKeyPair('invalidalg', 2048) end) + assert(ok == false, "generatekeypair with invalid algorithm should fail") + -- Invalid RSA key size + local pk, _ = crypto.generateKeyPair('rsa', 123) + assert(pk == nil, "generatekeypair with invalid RSA size should fail") + -- Invalid ECDSA curve + pk, _ = crypto.generateKeyPair('ecdsa', 'invalidcurve') + assert(pk == nil, "generatekeypair with invalid ECDSA curve should fail") +end + +local function test_negative_encrypt_decrypt() + local priv_key, pub_key = crypto.generateKeyPair('rsa', 2048) + -- Encrypt with invalid algorithm + local ok, ciphertext = pcall(function() return crypto.encrypt('invalidalg', pub_key, 'data') end) + assert(ok == false, "RSA encrypt with invalid algorithm should fail") + + -- Decrypt with invalid algorithm + ok, _ = pcall(function() return crypto.decrypt('invalidalg', priv_key, 'data') end) + assert(ok == false, "RSA decrypt with invalid algorithm should fail") + + -- Encrypt with invalid key + ciphertext = crypto.encrypt('rsa', 'notakey', 'data') + assert(ciphertext == nil, "RSA encrypt with invalid key should fail") + + -- Decrypt with invalid key + local retval = crypto.decrypt('rsa', 'notakey', 'data') + assert(retval == nil, "RSA decrypt with invalid key should fail") + + -- AES: invalid IV length + local key = crypto.generateKeyPair('aes', 256) + ciphertext = crypto.encrypt('aes', key, 'data', { mode = "cbc", iv = "tooShortIV" }) + assert(ciphertext == nil, "AES encrypt with short IV should fail") + + retval = crypto.decrypt('aes', key, 'data', { mode = "cbc", iv = "tooShortIV" }) + assert(retval == nil, "AES decrypt with short IV should fail") +end + +local function test_negative_sign_verify() + local priv_key, pub_key = crypto.generateKeyPair('rsa', 2048) + -- Sign with invalid algorithm + local ok = pcall(function() return crypto.sign('invalidalg', priv_key, 'msg', 'sha256') end) + assert(ok == false, "RSA sign with invalid algorithm should fail") + + -- Verify with invalid algorithm + ok = pcall(function() return crypto.verify('invalidalg', pub_key, 'msg', 'sig', 'sha256') end) + assert(ok == false, "RSA verify with invalid algorithm should fail") + + -- Sign with invalid key + ok = pcall(function() return crypto.sign('rsa', 'notakey', 'msg', 'sha256') end) + assert(ok == false, "RSA sign with invalid key should fail") + + -- Verify with invalid key + local verified = crypto.verify('rsa', 'notakey', 'msg', 'sig', 'sha256') + assert(verified == false, "verify with invalid key should fail") + + -- Verify with wrong signature (should return false, not error) + local badsig = 'thisisnotavalidsignature' + local result = crypto.verify('rsa', pub_key, 'msg', badsig, 'sha256') + assert(result == false, "RSA verify with wrong signature should return false") +end + +local function test_negative_pem_jwk_conversion() + -- Invalid PEM + local ok = pcall(function() return crypto.convertPemToJwk('notapem') end) + assert(ok == false, "convertPemToJwk with invalid PEM should fail") + + -- Invalid JWK (wrong type, but still a table) + local pem = crypto.convertJwkToPem({ kty = 'INVALID' }) + assert(pem == nil, "convertJwkToPem with invalid JWK should fail") + + -- Missing kty in JWK + pem = crypto.convertJwkToPem({}) + assert(pem == nil, "convertJwkToPem with missing kty should fail") +end + +local function test_negative_csr_generation() + -- Invalid key + local csr = crypto.generateCsr('notakey', 'CN=bad') + assert(csr == nil, "generateCsr with invalid key should fail") +end + +-- Add additional tests for edge cases in crypto functions + +-- Test RSA key size variations +local function test_rsa_key_sizes() + -- Test 2048-bit keys + local priv_key_2048, pub_key_2048 = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key_2048) == "string", "2048-bit private key type") + assert(type(pub_key_2048) == "string", "2048-bit public key type") + + -- Test 4096-bit keys + local priv_key_4096, pub_key_4096 = crypto.generateKeyPair("rsa", 4096) + assert(type(priv_key_4096) == "string", "4096-bit private key type") + assert(type(pub_key_4096) == "string", "4096-bit public key type") + + -- Test signing and verification with different key sizes + local message = "Test message for RSA key sizes" + + local signature_2048 = crypto.sign("rsa", priv_key_2048, message, "sha256") + assert(type(signature_2048) == "string", "2048-bit key signature") + + local is_valid_2048 = crypto.verify("rsa", pub_key_2048, message, signature_2048, "sha256") + assert(is_valid_2048 == true, "2048-bit key verification") + + local signature_4096 = crypto.sign("rsa", priv_key_4096, message, "sha256") + assert(type(signature_4096) == "string", "4096-bit key signature") + + local is_valid_4096 = crypto.verify("rsa", pub_key_4096, message, signature_4096, "sha256") + assert(is_valid_4096 == true, "4096-bit key verification") +end + +-- Test ECDSA curves +local function test_ecdsa_curves() + local curves = { "secp256r1", "secp384r1", "secp521r1" } + local message = "Test message for ECDSA curves" + + for _, curve in ipairs(curves) do + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", curve) + assert(type(priv_key) == "string", curve .. " private key type") + assert(type(pub_key) == "string", curve .. " public key type") + + local signature = crypto.sign("ecdsa", priv_key, message, "sha256") + assert(type(signature) == "string", curve .. " signature") + local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") + assert(is_valid == true, curve .. " verification") + end +end + +-- Test AES key sizes +local function test_aes_key_sizes() + local key_sizes = { 128, 192, 256 } + local plaintext = "Test message for AES key sizes" + + for _, size in ipairs(key_sizes) do + local key = crypto.generateKeyPair("aes", size) + assert(type(key) == "string", size .. "-bit AES key type") + assert(#key == size / 8, size .. "-bit AES key length") + + -- Test CBC mode + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + assert(type(ciphertext) == "string", size .. "-bit AES CBC encryption") + local decrypted_plaintext_cbc = crypto.decrypt("aes", key, ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_plaintext_cbc == plaintext, size .. "-bit AES CBC decryption") + + -- Test CTR mode + local ciphertext_ctr, iv_ctr = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + assert(type(ciphertext_ctr) == "string", size .. "-bit AES CTR encryption") + local decrypted_plaintext_ctr = crypto.decrypt("aes", key, ciphertext_ctr, { mode = "ctr", iv = iv_ctr }) + assert(decrypted_plaintext_ctr == plaintext, size .. "-bit AES CTR decryption") + + -- Test GCM mode + local ciphertext_gcm, iv_gcm, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + assert(type(ciphertext_gcm) == "string", size .. "-bit AES GCM encryption") + local decrypted_plaintext_gcm = crypto.decrypt("aes", key, ciphertext_gcm, { mode = "gcm", iv = iv_gcm, tag = tag }) + assert(decrypted_plaintext_gcm == plaintext, size .. "-bit AES GCM decryption") + end +end + +-- Test AES decryption with corrupted ciphertext and tag +local function test_aes_corruption_handling() + local key = crypto.generateKeyPair('aes', 256) + local plaintext = "Sensitive data for corruption test" + -- CBC mode + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + + -- Corrupt ciphertext + Log(kLogVerbose," - CBC decryption with corrupted ciphertext should fail") + local corrupted = ciphertext:sub(1, #ciphertext - 1) .. string.char((ciphertext:byte(-1) ~ 0xFF) % 256) + local plaintext_cbc = crypto.decrypt("aes", key, corrupted, { mode = "cbc", iv = iv }) + assert(plaintext_cbc == nil, "CBC decryption with corrupted ciphertext should fail") + + -- CTR mode (should not error, but output will be wrong) + Log(kLogVerbose," - CTR decryption with corrupted ciphertext should not match original") + local ciphertext_ctr, iv_ctr = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + local corrupted_ctr = ciphertext_ctr:sub(1, #ciphertext_ctr - 1) .. + string.char((ciphertext_ctr:byte(-1) ~ 0xFF) % 256) + local plaintext_ctr = crypto.decrypt("aes", key, corrupted_ctr, { mode = "ctr", iv = iv_ctr }) + assert(plaintext_ctr ~= plaintext, "CTR decryption with corrupted ciphertext should not match original") + + -- GCM mode (should fail authentication) + Log(kLogVerbose,"GCM decryption with corrupted ciphertext should fail") + local ciphertext_gcm, iv_gcm, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + local corrupted_gcm = ciphertext_gcm:sub(1, #ciphertext_gcm - 1) .. + string.char((ciphertext_gcm:byte(-1) ~ 0xFF) % 256) + local plaintext_gcm = crypto.decrypt("aes", key, corrupted_gcm, { mode = "gcm", iv = iv_gcm, tag = tag }) + assert(plaintext_gcm == nil, "GCM decryption with corrupted ciphertext should fail") + + -- GCM mode with corrupted tag + Log(kLogVerbose,"GCM decryption with corrupted tag should fail") + local badtag = tag:sub(1, #tag - 1) .. string.char((tag:byte(-1) ~ 0xFF) % 256) + local plaintext_gcm2 = crypto.decrypt("aes", key, ciphertext_gcm, { mode = "gcm", iv = iv_gcm, tag = badtag }) + assert(plaintext_gcm2 ~= plaintext, "GCM decryption with corrupted tag should fail") +end + +-- Test AES encryption/decryption with empty plaintext +local function test_aes_empty_plaintext() + local key = crypto.generateKeyPair('aes', 256) + local empty = "" + for _, mode in ipairs({ "cbc", "ctr", "gcm" }) do + local ciphertext, iv, tag = crypto.encrypt("aes", key, empty, { mode = mode }) + assert(type(ciphertext) == "string", "AES " .. mode .. " encrypt empty string") + + local opts = { mode = mode, iv = iv, tag = tag } + if mode ~= "gcm" then opts.tag = nil end + + local plaintext = crypto.decrypt("aes", key, ciphertext, opts) + assert(plaintext == empty, "AES " .. mode .. " decrypt empty string") + end +end + +-- Test sign/verify with empty message +local function test_sign_verify_empty_message() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local signature = crypto.sign("rsa", priv_key, "", "sha256") + assert(type(signature) == "string", "RSA sign empty message") + + local is_valid = crypto.verify("rsa", pub_key, "", signature, "sha256") + assert(is_valid == true, "RSA verify empty message") + + local ec_priv, ec_pub = crypto.generateKeyPair("ecdsa", "secp256r1") + local ec_sig = crypto.sign("ecdsa", ec_priv, "", "sha256") + assert(type(ec_sig) == "string", "ECDSA sign empty message") + + local ec_valid = crypto.verify("ecdsa", ec_pub, "", ec_sig, "sha256") + assert(ec_valid == true, "ECDSA verify empty message") +end + +-- Test JWK to PEM with minimal valid JWKs and missing fields +local function test_jwk_to_pem_minimal() + -- Minimal valid RSA public JWK + local _, pub_key = crypto.generateKeyPair("rsa", 2048) + local pub_jwk = crypto.convertPemToJwk(pub_key) + local minimal_jwk = { kty = pub_jwk.kty, n = pub_jwk.n, e = pub_jwk.e } + Log(kLogVerbose," - Testing minimal JWK to PEM conversion") + local pem = crypto.convertJwkToPem(minimal_jwk) + assert(type(pem) == "string", "Minimal RSA JWK to PEM") + + -- Missing 'n' field + Log(kLogVerbose," - Testing missing 'n' field in JWK to PEM conversion") + local bad_jwk = { kty = "RSA", e = pub_jwk.e } + local pem2 = crypto.convertJwkToPem(bad_jwk) + assert(pem2 == nil, "JWK to PEM with missing n should fail") + + -- Minimal EC public JWK + Log(kLogVerbose," - Testing minimal EC JWK to PEM conversion") + local _, ec_pub = crypto.generateKeyPair("ecdsa", "secp256r1") + local ec_jwk = crypto.convertPemToJwk(ec_pub) + local minimal_ec_jwk = { kty = ec_jwk.kty, crv = ec_jwk.crv, x = ec_jwk.x, y = ec_jwk.y } + local ec_pem = crypto.convertJwkToPem(minimal_ec_jwk) + assert(type(ec_pem) == "string", "Minimal EC JWK to PEM") + + -- Missing 'x' field + Log(kLogVerbose," - Testing missing 'x' field in EC JWK to PEM conversion") + local bad_ec_jwk = { kty = "EC", crv = ec_jwk.crv, y = ec_jwk.y } + local ec_pem2 = crypto.convertJwkToPem(bad_ec_jwk) + assert(ec_pem2 == nil, "EC JWK to PEM with missing x should fail") +end + +-- Test PEM to JWK with corrupted PEM +local function test_pem_to_jwk_corrupted() + local badpem = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA7\n-----END PUBLIC KEY-----" + local ok = pcall(function() return crypto.convertPemToJwk(badpem) end) + assert(ok == false, "PEM to JWK with corrupted PEM should fail") +end + +-- Test CSR generation with missing/invalid subject/SAN +local function test_csr_generation_edge_cases() + local priv_key, _ = crypto.generateKeyPair() + -- Missing subject and SAN + local csr = crypto.generateCsr(priv_key) + assert(csr == nil, "CSR with missing subject and SAN should fail") + -- Invalid SAN type (not validated yet) + -- local csr2, err2 = crypto.generateCsr(priv_key, "CN=foo", 12345) + -- assert(csr2 == nil, "CSR with invalid SAN type should fail") +end + +-- Test unsupported AES mode +local function test_unsupported_aes_mode() + Log(kLogVerbose," - AES decrypt with unsupported mode should fail") + local key = crypto.generateKeyPair('aes', 256) + local ciphertext = crypto.encrypt('aes', key, 'data', { mode = 'ofb' }) + assert(ciphertext == nil, "AES encrypt with unsupported mode should fail") + + local plaintext = crypto.decrypt('aes', key, 'data', { mode = 'ofb', iv = string.rep('A', 16) }) + assert(plaintext == nil, "AES decrypt with unsupported mode should fail") +end + +-- Test encrypting and signing very large messages +local function test_large_message_handling() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local large_message = string.rep("A", 1024 * 1024) -- 1MB + + -- RSA encryption (should fail or be limited by key size) + Log(kLogVerbose," - RSA encrypt large message should fail or be limited") + local ciphertext = crypto.encrypt("rsa", pub_key, large_message) + assert(ciphertext == nil, "RSA encrypt large message should fail or be limited") + + -- AES encryption (should succeed) + Log(kLogVerbose," - AES encrypt large message") + local key = crypto.generateKeyPair('aes', 256) + local aes_ciphertext, iv = crypto.encrypt("aes", key, large_message, { mode = "cbc" }) + assert(type(aes_ciphertext) == "string", "AES encrypt large message") + + local decrypted_large_message = crypto.decrypt("aes", key, aes_ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_large_message == large_message, "AES decrypt large message") + + -- RSA sign large message + Log(kLogVerbose," - RSA verify large message") + local signature = crypto.sign("rsa", priv_key, large_message, "sha256") + assert(type(signature) == "string", "RSA sign large message") + + local is_valid = crypto.verify("rsa", pub_key, large_message, signature, "sha256") + assert(is_valid == true, "RSA verify large message") +end + +-- Test passing non-string values as keys/messages/options +local function test_invalid_types() + local priv_key, _ = crypto.generateKeyPair("rsa", 2048) + local key = crypto.generateKeyPair('aes', 256) + + -- Non-string message + Log(kLogVerbose," - RSA sign with integer message should fail") + local signature = crypto.sign("rsa", priv_key, 12345, "sha256") + assert(signature == nil, "RSA sign with non-string message should fail") + + Log(kLogVerbose," - AES encrypt with boolean message should fail") + ciphertext = crypto.encrypt("aes", key, true, { mode = "cbc" }) + assert(ciphertext == nil, "AES encrypt with boolean message should fail") + + -- Non-string key + Log(kLogVerbose," - RSA sign with table as key should fail") + signature = crypto.sign("rsa", {}, "msg", "sha256") + assert(signature == nil, "RSA sign with table as key should fail") + + -- Non-table options + Log(kLogVerbose," - AES encrypt with number as options should fail") + ciphertext = crypto.encrypt("aes", key, "msg", 123) + assert(ciphertext == nil, "AES encrypt with number as options should fail") +end + +-- Test encrypting with one mode and decrypting with another +local function test_mixed_mode_operations() + local key = crypto.generateKeyPair('aes', 256) + local plaintext = "Mixed mode test" + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "ctr", iv = iv }) + assert(decrypted_plaintext ~= plaintext, "Decrypt CBC ciphertext with CTR mode should not match") + + local ciphertext2, iv2 = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "cbc", iv = iv2 }) + assert(decrypted_plaintext2 ~= plaintext, "Decrypt CTR ciphertext with CBC mode should not match") +end + +-- Test signing/verifying/converting with nil or empty parameters +local function test_nil_empty_parameters() + Log(kLogVerbose," - RSA sign with nil message should fail") + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local signature = crypto.sign("rsa", priv_key, nil, "sha256") + assert(signature == nil, "Sign with nil message should fail") + + Log(kLogVerbose," - RSA verify with nil message should fail") + local is_valid = crypto.verify("rsa", pub_key, nil, "sig", "sha256") + assert(is_valid == nil, "Verify with nil message should fail") + + Log(kLogVerbose," - JWK to PEM with nil should fail") + local ok = pcall(function() return crypto.convertJwkToPem(nil) end) + assert(ok == false, "convertJwkToPem with nil should fail") + + Log(kLogVerbose," - JWK to PEM with empty string should fail") + ok = pcall(function() return crypto.convertJwkToPem("") end) + assert(ok == false, "convertJwkToPem with empty string should fail") +end + +-- Run all tests +local function run_tests() + Log(kLogVerbose,"Testing RSA keypair generation...") + test_rsa_keypair_generation() + Log(kLogVerbose,"Testing RSA key size variations...") + test_rsa_key_sizes() + Log(kLogVerbose,"Testing RSA signing and verification...") + test_rsa_signing_verification() + Log(kLogVerbose,"Testing RSA encryption and decryption...") + test_rsa_encryption_decryption() + Log(kLogVerbose,"Testing RSA-PSS signing and verification...") + test_rsapss_signing_verification() + + Log(kLogVerbose,"Testing ECDSA keypair generation...") + test_ecdsa_keypair_generation() + Log(kLogVerbose,"Testing ECDSA signing and verification...") + test_ecdsa_signing_verification() + Log(kLogVerbose,"Testing ECDSA curves...") + test_ecdsa_curves() + + Log(kLogVerbose,"Testing AES key generation...") + test_aes_key_generation() + Log(kLogVerbose,"Testing AES encryption and decryption (CBC mode)...") + test_aes_encryption_decryption_cbc() + Log(kLogVerbose,"Testing AES encryption and decryption (CTR mode)...") + test_aes_encryption_decryption_ctr() + Log(kLogVerbose,"Testing AES encryption and decryption (GCM mode)...") + test_aes_encryption_decryption_gcm() + Log(kLogVerbose,"Testing unsupported AES mode...") + test_unsupported_aes_mode() + Log(kLogVerbose,"Testing AES key sizes...") + test_aes_key_sizes() + Log(kLogVerbose,"Testing AES decryption with corrupted ciphertext and tag...") + test_aes_corruption_handling() + Log(kLogVerbose,"Testing AES encryption/decryption with empty plaintext...") + test_aes_empty_plaintext() + Log(kLogVerbose,"Testing large message encryption and signing...") + test_large_message_handling() + + Log(kLogVerbose,"Testing various hash algorithms...") + test_hash_algorithms() + Log(kLogVerbose,"Testing negative cases for hash algorithms...") + test_negative_hash_algorithms() + Log(kLogVerbose,"Testing sign/verify with empty message...") + test_sign_verify_empty_message() + Log(kLogVerbose,"Testing invalid input types...") + test_invalid_types() + Log(kLogVerbose,"Testing mixed mode encryption/decryption...") + test_mixed_mode_operations() + Log(kLogVerbose,"Testing nil/empty parameters...") + test_nil_empty_parameters() + Log(kLogVerbose,"Testing negative keypair generation...") + test_negative_keypair_generation() + Log(kLogVerbose,"Testing negative encrypt/decrypt...") + test_negative_encrypt_decrypt() + Log(kLogVerbose,"Testing negative sign/verify...") + test_negative_sign_verify() + + Log(kLogVerbose,"Testing PEM to JWK conversion...") + test_pem_to_jwk() + Log(kLogVerbose,"Testing PEM to JWK with corrupted PEM...") + test_pem_to_jwk_corrupted() + Log(kLogVerbose,"Testing JWK to PEM conversion...") + test_jwk_to_pem() + Log(kLogVerbose,"Testing negative PEM/JWK conversion...") + test_negative_pem_jwk_conversion() + Log(kLogVerbose,"Testing JWK to PEM with minimal valid JWKs and missing fields...") + test_jwk_to_pem_minimal() + Log(kLogVerbose,"Testing CSR generation...") + test_csr_generation() + Log(kLogVerbose,"Testing CSR generation with missing/invalid subject/SAN...") + test_csr_generation_edge_cases() + Log(kLogVerbose,"Testing negative CSR generation...") + test_negative_csr_generation() + EXIT = 0 + return EXIT +end + +EXIT = 70 + +os.exit(run_tests()) diff --git a/third_party/mbedtls/config.h b/third_party/mbedtls/config.h index 24f2c227b26..1a6fc18c9af 100644 --- a/third_party/mbedtls/config.h +++ b/third_party/mbedtls/config.h @@ -38,11 +38,11 @@ /* block modes */ #define MBEDTLS_GCM_C -#ifndef TINY #define MBEDTLS_CIPHER_MODE_CBC +#define MBEDTLS_CIPHER_MODE_CTR +#ifndef TINY /*#define MBEDTLS_CCM_C*/ /*#define MBEDTLS_CIPHER_MODE_CFB*/ -/*#define MBEDTLS_CIPHER_MODE_CTR*/ /*#define MBEDTLS_CIPHER_MODE_OFB*/ /*#define MBEDTLS_CIPHER_MODE_XTS*/ #endif @@ -71,10 +71,10 @@ /* eliptic curves */ #define MBEDTLS_ECP_DP_SECP256R1_ENABLED #define MBEDTLS_ECP_DP_SECP384R1_ENABLED +#define MBEDTLS_ECP_DP_SECP521R1_ENABLED #define MBEDTLS_ECP_DP_CURVE25519_ENABLED #ifndef TINY #define MBEDTLS_ECP_DP_CURVE448_ENABLED -/*#define MBEDTLS_ECP_DP_SECP521R1_ENABLED*/ /*#define MBEDTLS_ECP_DP_BP384R1_ENABLED*/ /*#define MBEDTLS_ECP_DP_SECP192R1_ENABLED*/ /*#define MBEDTLS_ECP_DP_SECP224R1_ENABLED*/ @@ -395,7 +395,9 @@ * * This enables support for RSAES-OAEP and RSASSA-PSS operations. */ -/*#define MBEDTLS_PKCS1_V21*/ +#ifndef TINY +#define MBEDTLS_PKCS1_V21 +#endif /** * \def MBEDTLS_RSA_NO_CRT diff --git a/tool/net/BUILD.mk b/tool/net/BUILD.mk index 06a80d5f8dc..6528b88541a 100644 --- a/tool/net/BUILD.mk +++ b/tool/net/BUILD.mk @@ -100,6 +100,7 @@ TOOL_NET_REDBEAN_LUA_MODULES = \ o/$(MODE)/tool/net/lmaxmind.o \ o/$(MODE)/tool/net/lsqlite3.o \ o/$(MODE)/tool/net/largon2.o \ + o/$(MODE)/tool/net/lcrypto.o \ o/$(MODE)/tool/net/launch.o o/$(MODE)/tool/net/redbean.dbg: \ diff --git a/tool/net/definitions.lua b/tool/net/definitions.lua index 3732416b017..11e545a6062 100644 --- a/tool/net/definitions.lua +++ b/tool/net/definitions.lua @@ -746,7 +746,7 @@ function EscapeHtml(str) end ---@param path string? function LaunchBrowser(path) end ----@param ip uint32 +---@param ip integer|string ---@return string # a string describing the IP address. This is currently Class A granular. It can tell you if traffic originated from private networks, ARIN, APNIC, DOD, etc. ---@nodiscard function CategorizeIp(ip) end @@ -1142,10 +1142,10 @@ function FormatHttpDateTime(seconds) end --- Turns integer like `0x01020304` into a string like `"1.2.3.4"`. See also --- `ParseIp` for the inverse operation. ----@param uint32 integer +---@param ip integer ---@return string ---@nodiscard -function FormatIp(uint32) end +function FormatIp(ip) end --- Returns client ip4 address and port, e.g. `0x01020304`,`31337` would represent --- `1.2.3.4:31337`. This is the same as `GetClientAddr` except it will use the @@ -1363,25 +1363,25 @@ function HidePath(prefix) end ---@nodiscard function IsHiddenPath(path) end ----@param uint32 integer +---@param ip integer|string|string ---@return boolean # `true` if IP address is not a private network (`10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`) and is not localhost (`127.0.0.0/8`). --- Note: we intentionally regard TEST-NET IPs as public. ---@nodiscard -function IsPublicIp(uint32) end +function IsPublicIp(ip) end ----@param uint32 integer +---@param ip integer|string|string ---@return boolean # `true` if IP address is part of a private network (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16). ---@nodiscard -function IsPrivateIp(uint32) end +function IsPrivateIp(ip) end ---@return boolean # `true` if the client IP address (returned by GetRemoteAddr) is part of the localhost network (127.0.0.0/8). ---@nodiscard function IsLoopbackClient() end ----@param uint32 integer +---@param ip integer|string|string ---@return boolean # true if IP address is part of the localhost network (127.0.0.0/8). ---@nodiscard -function IsLoopbackIp(uint32) end +function IsLoopbackIp(ip) end ---@param path string ---@return boolean # `true` if ZIP artifact at path is stored on disk using DEFLATE compression. @@ -1615,7 +1615,7 @@ function GetCryptoHash(name, payload, key) end --- to the system-configured DNS resolution service. Please note that in MODE=tiny --- the HOSTS.TXT and DNS resolution isn't included, and therefore an IP must be --- provided. ----@param ip integer +---@param ip integer|string|string ---@overload fun(host:string) function ProgramAddr(ip) end @@ -1669,8 +1669,8 @@ function ProgramTimeout(milliseconds) end --- Hard-codes the port number on which to listen, which can be any number in the --- range `1..65535`, or alternatively `0` to ask the operating system to choose a --- port, which may be revealed later on by `GetServerAddr` or the `-z` flag to stdout. ----@param uint16 integer -function ProgramPort(uint16) end +---@param port integer +function ProgramPort(port) end --- Sets the maximum HTTP message payload size in bytes. The --- default is very conservatively set to 65536 so this is @@ -2169,7 +2169,7 @@ function bin(int) end --- unspecified format describing the error. Calls to this function may be wrapped --- in `assert()` if an exception is desired. ---@param hostname string ----@return uint32 ip uint32 +---@return string ---@nodiscard ---@overload fun(hostname: string): nil, error: string function ResolveIp(hostname) end @@ -2183,7 +2183,7 @@ function ResolveIp(hostname) end --- The network interface addresses used by the host machine are always --- considered trustworthy, e.g. 127.0.0.1. This may change soon, if we --- decide to export a `GetHostIps()` API which queries your NIC devices. ----@param ip integer +---@param ip integer|string ---@return boolean function IsTrustedIp(ip) end @@ -2213,7 +2213,7 @@ function IsTrustedIp(ip) end --- --- Although you might want consider trusting redbean's open source --- freedom embracing solution to DDOS protection instead! ----@param ip integer +---@param ip integer|string ---@param cidr integer? function ProgramTrustedIp(ip, cidr) end @@ -8048,6 +8048,75 @@ kUrlPlus = nil ---@type integer to transcode ISO-8859-1 input into UTF-8. See `ParseUrl`. kUrlLatin1 = nil + +--- This module provides cryptographic operations. + +--- The crypto module for cryptographic operations +crypto = {} + +--- Converts a PEM-encoded key to JWK format +---@param pem string PEM-encoded key +---@return table?, string? JWK table or nil on error +---@return string? error message +function crypto.convertPemToJwk(pem) end + +--- Generates a Certificate Signing Request (CSR) +---@param key_pem string PEM-encoded private key +---@param subject_name string? X.509 subject name +---@param san_list string? Subject Alternative Names +---@return string?, string? CSR in PEM format or nil on error and error message +function crypto.generateCsr(key_pem, subject_name, san_list) end + +--- Signs data using a private key +---@param key_type string "rsa" or "ecdsa" +---@param private_key string PEM-encoded private key +---@param message string Data to sign +---@param hash_algo string? Hash algorithm (default: SHA-256) +---@return string?, string? Signature or nil on error and error message +function crypto.sign(key_type, private_key, message, hash_algo) end + +--- Verifies a signature +---@param key_type string "rsa" or "ecdsa" +---@param public_key string PEM-encoded public key +---@param message string Original message +---@param signature string Signature to verify +---@param hash_algo string? Hash algorithm (default: SHA-256) +---@return boolean?, string? True if valid or nil on error and error message +function crypto.verify(key_type, public_key, message, signature, hash_algo) end + +--- Encrypts data +---@param cipher_type string "rsa" or "aes" +---@param key string Public key or symmetric key +---@param plaintext string Data to encrypt +---@param options table Table with optional parameters: +--- options.mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") +--- options.iv string? Initialization Vector for AES +--- options.aad string? Additional data for AES-GCM +---@return string? Encrypted data or nil on error +---@return string? IV or error message +---@return string? Authentication tag for GCM mode +function crypto.encrypt(cipher_type, key, plaintext, options) end + +--- Decrypts data +---@param cipher_type string "rsa" or "aes" +---@param key string Private key or symmetric key +---@param ciphertext string Data to decrypt +---@param options table Table with optional parameters: +--- options.iv string? Initialization Vector for AES +--- options.mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") +--- options.tag string? Authentication tag for AES-GCM +--- options.aad string? Additional data for AES-GCM +---@return string?, string? Decrypted data or nil on error and error message +function crypto.decrypt(cipher_type, key, ciphertext, options) end + +--- Generates cryptographic keys +---@param key_type string? "rsa", "ecdsa", or "aes" +---@param key_size_or_curve number|string? Key size or curve name +---@return string? Private key or nil on error +---@return string? Public key (nil for AES) or error message +function crypto.generatekeypair(key_type, key_size_or_curve) end + + --[[ ──────────────────────────────────────────────────────────────────────────────── LEGAL diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c new file mode 100644 index 00000000000..f9792089361 --- /dev/null +++ b/tool/net/lcrypto.c @@ -0,0 +1,2557 @@ +/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│ +│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │ +╞══════════════════════════════════════════════════════════════════════════════╡ +│ Copyright 2025 Miguel Angel Terron │ +│ │ +│ Permission to use, copy, modify, and/or distribute this software for │ +│ any purpose with or without fee is hereby granted, provided that the │ +│ above copyright notice and this permission notice appear in all copies. │ +│ │ +│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │ +│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │ +│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │ +│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │ +│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │ +│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │ +│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │ +│ PERFORMANCE OF THIS SOFTWARE. │ +╚─────────────────────────────────────────────────────────────────────────────*/ + +#include "tool/net/luacheck.h" +// mbedTLS +#include "third_party/mbedtls/aes.h" +#include "third_party/mbedtls/base64.h" +#include "third_party/mbedtls/ctr_drbg.h" +#include "third_party/mbedtls/ecdsa.h" +#include "third_party/mbedtls/entropy.h" +#include "third_party/mbedtls/error.h" +#include "third_party/mbedtls/gcm.h" +#include "third_party/mbedtls/md.h" +#include "third_party/mbedtls/oid.h" +#include "third_party/mbedtls/pk.h" +#include "third_party/mbedtls/rsa.h" +#include "third_party/mbedtls/x509_csr.h" + +// Elliptic curves +// Supported curves mapping +typedef struct { + const char *name; + mbedtls_ecp_group_id id; +} curve_map_t; + +static const curve_map_t supported_curves[] = { + {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, // + {"P256", MBEDTLS_ECP_DP_SECP256R1}, // + {"P-256", MBEDTLS_ECP_DP_SECP256R1}, // + {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, // + {"P384", MBEDTLS_ECP_DP_SECP384R1}, // + {"P-384", MBEDTLS_ECP_DP_SECP384R1}, // + {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, // + {"P521", MBEDTLS_ECP_DP_SECP521R1}, // + {"P-521", MBEDTLS_ECP_DP_SECP521R1}, // + {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, // + {"curve448", MBEDTLS_ECP_DP_CURVE448}, // + {NULL, 0}}; + +// List available curves +static int LuaListCurves(lua_State *L) { + const curve_map_t *curve = supported_curves; + int i = 1; + lua_newtable(L); + + while (curve->name != NULL) { + lua_pushstring(L, curve->name); + lua_rawseti(L, -2, i++); + curve++; + } + return 1; +} + +// Find curve ID by name +static mbedtls_ecp_group_id find_curve_by_name(const char *name) { + const curve_map_t *curve = supported_curves; + + while (curve->name != NULL) { + if (strcasecmp(curve->name, name) == 0) { + return curve->id; + } + curve++; + } + return MBEDTLS_ECP_DP_NONE; +} + +// Message digests +// Supported digests mapping +typedef struct { + const char *name; + mbedtls_md_type_t id; +} digest_map_t; + +static const digest_map_t supported_digests[] = { + {"MD5", MBEDTLS_MD_MD5}, // + {"SHA1", MBEDTLS_MD_SHA1}, // + {"SHA-1", MBEDTLS_MD_SHA1}, // + {"SHA224", MBEDTLS_MD_SHA224}, // + {"SHA-224", MBEDTLS_MD_SHA224}, // + {"SHA256", MBEDTLS_MD_SHA256}, // + {"SHA-256", MBEDTLS_MD_SHA256}, // + {"SHA384", MBEDTLS_MD_SHA384}, // + {"SHA-384", MBEDTLS_MD_SHA384}, // + {"SHA512", MBEDTLS_MD_SHA512}, // + {"SHA-512", MBEDTLS_MD_SHA512}, // + {NULL, 0}}; + +// List available digests +static int LuaListDigests(lua_State *L) { + const digest_map_t *digest = supported_digests; + int i = 1; + lua_newtable(L); + + while (digest->name != NULL) { + lua_pushstring(L, digest->name); + lua_rawseti(L, -2, i++); + digest++; + } + return 1; +} + +// Find digest ID by name +static mbedtls_md_type_t find_digest_by_name(const char *name) { + const digest_map_t *digest = supported_digests; + + while (digest->name != NULL) { + if (strcasecmp(digest->name, name) == 0) { + return digest->id; + } + digest++; + } + return MBEDTLS_MD_NONE; +} + +// Get the size of the hash output based on the mbedtls_md_type_t +static size_t get_hash_size_from_md_type(mbedtls_md_type_t md_type) { + switch (md_type) { + case MBEDTLS_MD_SHA256: + return 32; + case MBEDTLS_MD_SHA384: + return 48; + case MBEDTLS_MD_SHA512: + return 64; + case MBEDTLS_MD_SHA1: + return 20; + case MBEDTLS_MD_MD5: + return 16; + default: + return 32; + } +} + +// Compute hash using mbedtls +static int compute_hash(mbedtls_md_type_t md_type, const unsigned char *input, + size_t input_len, unsigned char *output, + size_t output_size) { + mbedtls_md_context_t md_ctx; + const mbedtls_md_info_t *md_info; + int ret; + + md_info = mbedtls_md_info_from_type(md_type); + if (md_info == NULL) { + WARNF("(crypto) Unsupported hash algorithm"); + return -1; + } + + if (output_size < mbedtls_md_get_size(md_info)) { + WARNF("(crypto) Output buffer too small for hash"); + return -1; + } + + mbedtls_md_init(&md_ctx); + + ret = mbedtls_md_setup(&md_ctx, md_info, 0); // 0 = non-HMAC + if (ret != 0) { + WARNF("(crypto) Failed to set up hash context: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_starts(&md_ctx); + if (ret != 0) { + WARNF("(crypto) Failed to start hash operation: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_update(&md_ctx, input, input_len); + if (ret != 0) { + WARNF("(crypto) Failed to update hash: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_finish(&md_ctx, output); + if (ret != 0) { + WARNF("(crypto) Failed to finish hash: -0x%04x", -ret); + goto cleanup; + } + +cleanup: + mbedtls_md_free(&md_ctx); + return ret; +} + +// Ciphers +typedef struct { + const char *name; + mbedtls_cipher_id_t id; +} ciphers_map_t; + +static const ciphers_map_t supported_ciphers[] = { + {"AES-128-CBC", MBEDTLS_CIPHER_AES_128_CBC}, // + {"AES-192-CBC", MBEDTLS_CIPHER_AES_192_CBC}, // + {"AES-256-CBC", MBEDTLS_CIPHER_AES_256_CBC}, // + {"AES-128-CTR", MBEDTLS_CIPHER_AES_128_CTR}, // + {"AES-192-CTR", MBEDTLS_CIPHER_AES_192_CTR}, // + {"AES-256-CTR", MBEDTLS_CIPHER_AES_256_CTR}, // + {"AES-128-GCM", MBEDTLS_CIPHER_AES_128_GCM}, // + {"AES-192-GCM", MBEDTLS_CIPHER_AES_192_GCM}, // + {"AES-256-GCM", MBEDTLS_CIPHER_AES_256_GCM}, // + {NULL, 0}}; + +// Strong RNG using mbedtls_entropy_context and mbedtls_ctr_drbg_context +int GenerateRandom(void *ctx, unsigned char *output, size_t len) { + static mbedtls_entropy_context entropy; + static mbedtls_ctr_drbg_context ctr_drbg; + static int initialized = 0; + int ret; + const char *pers = "redbean_entropy"; + + if (!initialized) { + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + + ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, + (const unsigned char *)pers, strlen(pers)); + if (ret != 0) { + // Clean up on failure + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + return -1; + } + initialized = 1; + } + // mbedtls_ctr_drbg_random returns 0 on success + ret = mbedtls_ctr_drbg_random(&ctr_drbg, output, len); + if (ret != 0) { + // If DRBG fails, reinitialize on next call + initialized = 0; + return -1; + } + return 0; +} + +// RSA + +// Generate RSA Key Pair +static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, + char **public_key_pem, size_t *public_key_len, + unsigned int key_length) { + int rc; + mbedtls_pk_context key; + mbedtls_pk_init(&key); + + // Initialize as RSA key + if ((rc = mbedtls_pk_setup(&key, + mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { + WARNF("(crypto) Failed to setup key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return false; + } + + // Generate RSA key + if ((rc = mbedtls_rsa_gen_key(mbedtls_pk_rsa(key), GenerateRandom, 0, + key_length, 65537)) != 0) { + WARNF("(crypto) Failed to generate key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return false; + } + + // Write private key to PEM + *private_key_len = 16000; // Buffer size for private key + *private_key_pem = calloc(1, *private_key_len); + if ((rc = mbedtls_pk_write_key_pem(&key, (unsigned char *)*private_key_pem, + *private_key_len)) != 0) { + WARNF("(crypto) Failed to write private key (grep -0x%04x)", -rc); + free(*private_key_pem); + mbedtls_pk_free(&key); + return false; + } + *private_key_len = strlen(*private_key_pem); + + // Write public key to PEM + *public_key_len = 8000; // Buffer size for public key + *public_key_pem = calloc(1, *public_key_len); + if ((rc = mbedtls_pk_write_pubkey_pem(&key, (unsigned char *)*public_key_pem, + *public_key_len)) != 0) { + WARNF("(crypto) Failed to write public key (grep -0x%04x)", -rc); + free(*private_key_pem); + free(*public_key_pem); + mbedtls_pk_free(&key); + return false; + } + *public_key_len = strlen(*public_key_pem); + + mbedtls_pk_free(&key); + return true; +} +static int LuaRSAGenerateKeyPair(lua_State *L) { + int bits = 2048; + // If no arguments, or first argument is nil, default to 2048 + if (lua_gettop(L) == 0 || lua_isnoneornil(L, 1)) { + bits = 2048; + } else if (lua_gettop(L) == 1 && lua_type(L, 1) == LUA_TNUMBER) { + bits = (int)lua_tointeger(L, 1); + } else { + bits = (int)luaL_optinteger(L, 2, 2048); + } + // Check if key length is valid (only 2048 or 4096 bits are allowed) + if (bits != 2048 && bits != 4096) { + lua_pushnil(L); + lua_pushfstring(L, + "Invalid RSA key length: %d. Only 2048 or 4096 bits key " + "lengths are supported", + bits); + return 2; + } + + char *private_key, *public_key; + size_t private_len, public_len; + + // Call the C function to generate the key pair + if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, + bits)) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate RSA key pair"); + return 2; + } + + // Push results to Lua + lua_pushstring(L, private_key); + lua_pushstring(L, public_key); + + // Clean up + free(private_key); + free(public_key); + + return 2; +} + +// Helper to get string field from options table for RSA +// static const char *parse_rsa_options(lua_State *L, int options_idx) { +// const char *padding = "pkcs1"; // default +// if (lua_istable(L, options_idx)) { +// lua_getfield(L, options_idx, "padding"); +// if (lua_isstring(L, -1)) { +// padding = lua_tostring(L, -1); +// } +// lua_pop(L, 1); +// } +// return padding; +// } + +static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, + size_t data_len, size_t *out_len) { + int rc; + + // Parse public key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_public_key(&key, + (const unsigned char *)public_key_pem, + strlen(public_key_pem) + 1)) != 0) { + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate output buffer + size_t key_size = mbedtls_pk_get_len(&key); + unsigned char *output = calloc(1, key_size); + if (!output) { + mbedtls_pk_free(&key); + return NULL; + } + + // Encrypt data + if ((rc = mbedtls_rsa_pkcs1_encrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, + MBEDTLS_RSA_PUBLIC, data_len, data, + output)) != 0) { + WARNF("(crypto) Encryption failed (grep -0x%04x)", -rc); + free(output); + mbedtls_pk_free(&key); + return NULL; + } + + *out_len = key_size; + mbedtls_pk_free(&key); + return (char *)output; +} +static int LuaRSAEncrypt(lua_State *L) { + // Args: key, plaintext, options table + size_t keylen, ptlen; + // Ensure key is a string + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + const char *key = luaL_checklstring(L, 1, &keylen); + // Ensure plaintext is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + const unsigned char *plaintext = + (const unsigned char *)luaL_checklstring(L, 2, &ptlen); + // int options_idx = 3; + // const char *padding = parse_rsa_options(L, options_idx); + size_t out_len; + + char *encrypted = RSAEncrypt(key, plaintext, ptlen, &out_len); + if (!encrypted) { + lua_pushnil(L); + lua_pushstring(L, "Encryption failed"); + return 2; + } + + lua_pushlstring(L, encrypted, out_len); + free(encrypted); + + return 1; +} +static char *RSADecrypt(const char *private_key_pem, + const unsigned char *encrypted_data, + size_t encrypted_len, size_t *out_len) { + int rc; + + // Parse private key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0); + if (rc != 0) { + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate output buffer + size_t key_size = mbedtls_pk_get_len(&key); + unsigned char *output = calloc(1, key_size); + if (!output) { + mbedtls_pk_free(&key); + return NULL; + } + + // Decrypt data + size_t output_len = 0; + rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, + MBEDTLS_RSA_PRIVATE, &output_len, + encrypted_data, output, key_size); + if (rc != 0) { + WARNF("(crypto) Decryption failed (grep -0x%04x)", -rc); + free(output); + mbedtls_pk_free(&key); + return NULL; + } + + *out_len = output_len; + mbedtls_pk_free(&key); + return (char *)output; +} +static int LuaRSADecrypt(lua_State *L) { + // Args: key, ciphertext, options table + size_t keylen, ctlen; + // Ensure key is a string + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + const char *key = luaL_checklstring(L, 1, &keylen); + // Ensure ciphertext is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext must be a string"); + return 2; + } + const unsigned char *ciphertext = + (const unsigned char *)luaL_checklstring(L, 2, &ctlen); + // int options_idx = 3; + // const char *padding = parse_rsa_options(L, options_idx); + size_t out_len; + + char *decrypted = RSADecrypt(key, ciphertext, ctlen, &out_len); + if (!decrypted) { + lua_pushnil(L); + lua_pushstring(L, "Decryption failed"); + return 2; + } + + lua_pushlstring(L, decrypted, out_len); + free(decrypted); + + return 1; +} + +// RSA Signing +static char *RSASign(const char *private_key_pem, const unsigned char *data, + size_t data_len, const char *hash_name, size_t *sig_len) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 0; + mbedtls_md_type_t hash_algo; + unsigned char *signature; + + // Determine hash algorithm + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return NULL; + } else { + hash_len = get_hash_size_from_md_type(hash_algo); + } + } + + // Parse private key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0)) != 0) { + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Hash the message + if ((rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, + hash)) != 0) { + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate buffer for signature + signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE); + if (!signature) { + mbedtls_pk_free(&key); + return NULL; + } + + // Sign the hash + if ((rc = mbedtls_pk_sign(&key, hash_algo, hash, hash_len, signature, sig_len, + GenerateRandom, 0)) != 0) { + free(signature); + mbedtls_pk_free(&key); + return NULL; + } + + // Clean up + mbedtls_pk_free(&key); + + return (char *)signature; +} +static int LuaRSASign(lua_State *L) { + size_t msg_len, key_len; + const char *msg, *key_pem, *hash_name = NULL; + unsigned char *signature; + size_t sig_len = 0; + + // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + if (lua_type(L, 1) == LUA_TTABLE) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string, got a table instead"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + + // Optional hash algorithm parameter + if (!lua_isnoneornil(L, 3)) { + hash_name = luaL_checkstring(L, 3); + } + + // Call the C implementation + signature = (unsigned char *)RSASign(key_pem, (const unsigned char *)msg, + msg_len, hash_name, &sig_len); + + if (!signature) { + return luaL_error(L, "failed to sign message"); + } + + // Return the signature as a Lua string + lua_pushlstring(L, (char *)signature, sig_len); + + // Clean up + free(signature); + + return 1; +} + +// RSA PSS Signing +static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data, + size_t data_len, const char *hash_name, size_t *sig_len, + int salt_len) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 0; + mbedtls_md_type_t hash_algo; + unsigned char *signature; + + // Determine hash algorithm + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return NULL; + } else { + hash_len = get_hash_size_from_md_type(hash_algo); + } + } + + // Parse private key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0); + if (rc != 0) { + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Hash the message + rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, hash); + if (rc != 0) { + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate buffer for signature + signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE); + if (!signature) { + mbedtls_pk_free(&key); + return NULL; + } + + // Get RSA context + mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, hash_algo); + + // Sign the hash using PSS + rc = mbedtls_rsa_rsassa_pss_sign(rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, + hash_algo, (unsigned int)hash_len, hash, + signature); + if (rc != 0) { + free(signature); + mbedtls_pk_free(&key); + return NULL; + } + + *sig_len = mbedtls_pk_get_len(&key); + + // Clean up + mbedtls_pk_free(&key); + + return (char *)signature; +} + +static int LuaRSAPSSSign(lua_State *L) { + size_t key_len, msg_len; + const char *key_pem, *msg; + unsigned char *signature; + size_t sig_len = 0; + + // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + + // Get parameters from Lua + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + + // Optional hash algorithm parameter + const char *hash_name = luaL_optstring(L, 3, "sha256"); + + // Optional salt length parameter + int salt_len = luaL_optinteger(L, 4, -1); + + // Call the C implementation + signature = + (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, msg_len, + hash_name, &sig_len, salt_len); + + if (!signature) { + return luaL_error(L, "failed to sign message (PSS)"); + } + + // Return the signature as a Lua string + lua_pushlstring(L, (char *)signature, sig_len); + free(signature); + return 1; +} + +static int RSAVerify(const char *public_key_pem, const unsigned char *data, + size_t data_len, const unsigned char *signature, + size_t sig_len, const char *hash_name) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 0; + mbedtls_md_type_t hash_algo; + + // Determine hash algorithm + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return -1; + } else { + hash_len = get_hash_size_from_md_type(hash_algo); + } + } + + // Parse public key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_public_key(&key, + (const unsigned char *)public_key_pem, + strlen(public_key_pem) + 1)) != 0) { + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return -1; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return -1; + } + + // Hash the message + if ((rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, + hash)) != 0) { + mbedtls_pk_free(&key); + return -1; + } + + // Verify the signature + rc = mbedtls_pk_verify(&key, hash_algo, hash, hash_len, signature, sig_len); + + // Clean up + mbedtls_pk_free(&key); + + return rc; // 0 means success (valid signature) +} +static int LuaRSAVerify(lua_State *L) { + size_t msg_len, key_len, sig_len; + const char *msg, *key_pem, *signature, *hash_name = NULL; + int result; + + // Get parameters from Lua + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + signature = luaL_checklstring(L, 3, &sig_len); + + // Optional hash algorithm parameter + if (!lua_isnoneornil(L, 4)) { + hash_name = luaL_checkstring(L, 4); + } + + // Call the C implementation + result = RSAVerify(key_pem, (const unsigned char *)msg, msg_len, + (const unsigned char *)signature, sig_len, hash_name); + + // Return boolean result (0 means valid signature) + lua_pushboolean(L, result == 0); + + return 1; +} + +// RSA PSS Verification +static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, + size_t data_len, const unsigned char *signature, + size_t sig_len, const char *hash_name, + int expected_salt_len) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 0; + mbedtls_md_type_t hash_algo; + + // Determine hash algorithm + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return -1; + } else { + hash_len = get_hash_size_from_md_type(hash_algo); + } + } + + // Parse public key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_public_key(&key, + (const unsigned char *)public_key_pem, + strlen(public_key_pem) + 1)) != 0) { + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return -1; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return -1; + } + + // Hash the message + if ((rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, + hash)) != 0) { + mbedtls_pk_free(&key); + return -1; + } + + // Get RSA context + mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + + // Verify the signature using PSS + rc = mbedtls_rsa_rsassa_pss_verify(rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, + hash_algo, (unsigned int)hash_len, hash, + signature); + + // Clean up + mbedtls_pk_free(&key); + + return rc; // 0 means success (valid signature) +} + +static int LuaRSAPSSVerify(lua_State *L) { + // Args: key, msg, signature, hash_algo (optional), salt_len (optional + // crypto.verify('rsapss', key, msg, signature, hash_algo, salt_len) + size_t msg_len, key_len, sig_len; + const char *msg, *key_pem, *signature; + const char *hash_name = NULL; // Default to SHA-256 + int expected_salt_len = -1; + int result; + + // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + // Ensure signature is a string + if (lua_type(L, 3) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Signature must be a string"); + return 2; + } + // Get parameters from Lua + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + signature = luaL_checklstring(L, 3, &sig_len); + // Optional hash algorithm parameter + if (lua_isstring(L, 4)) { + hash_name = luaL_checkstring(L, 4); + // Optional salt length parameter + expected_salt_len = luaL_optinteger(L, 5, 32); + } else if (lua_isnumber(L, 4)) { + // If it's a number, treat it as the expected salt length + expected_salt_len = (int)lua_tointeger(L, 4); + } + // Call the C implementation + result = RSAPSSVerify(key_pem, (const unsigned char *)msg, msg_len, + (const unsigned char *)signature, sig_len, hash_name, + expected_salt_len); + + // Return boolean result (0 means valid signature) + lua_pushboolean(L, result == 0); + + return 1; +} + +// Elliptic Curve Cryptography Functions + +// Generate an ECDSA key pair and return in PEM format +static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, + char **pub_key_pem) { + mbedtls_pk_context key; + unsigned char output_buf[16000]; + int ret; + mbedtls_ecp_group_id curve_id; + + // Initialize output parameters to NULL in case of early return + if (priv_key_pem) + *priv_key_pem = NULL; + if (pub_key_pem) + *pub_key_pem = NULL; + + // Use secp256r1 as default if curve_name is NULL or empty + if (curve_name == NULL || curve_name[0] == '\0') { + curve_id = MBEDTLS_ECP_DP_SECP256R1; + VERBOSEF("(crypto) No curve specified, using default: secp256r1"); + } else { + // Find the curve by name + curve_id = find_curve_by_name(curve_name); + if (curve_id == MBEDTLS_ECP_DP_NONE) { + WARNF("(crypto) Unknown curve: '%s'", curve_name); + return -1; + } + } + + mbedtls_pk_init(&key); + + // Generate the key with the specified curve + ret = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)); + if (ret != 0) { + WARNF("(crypto) Failed to setup key: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_ecp_gen_key(curve_id, mbedtls_pk_ec(key), GenerateRandom, 0); + if (ret != 0) { + WARNF("(crypto) Failed to generate key: -0x%04x", -ret); + goto cleanup; + } + + // Generate private key PEM + if (priv_key_pem != NULL) { + memset(output_buf, 0, sizeof(output_buf)); + ret = mbedtls_pk_write_key_pem(&key, output_buf, sizeof(output_buf)); + if (ret != 0) { + WARNF("(crypto) Failed to write private key: -0x%04x", -ret); + goto cleanup; + } + *priv_key_pem = strdup((char *)output_buf); + if (*priv_key_pem == NULL) { + WARNF("(crypto) Failed to allocate memory for private key PEM"); + ret = -1; + goto cleanup; + } + } + + // Generate public key PEM + if (pub_key_pem != NULL) { + memset(output_buf, 0, sizeof(output_buf)); + ret = mbedtls_pk_write_pubkey_pem(&key, output_buf, sizeof(output_buf)); + if (ret != 0) { + WARNF("(crypto) Failed to write public key: -0x%04x", -ret); + goto cleanup; + } + *pub_key_pem = strdup((char *)output_buf); + if (*pub_key_pem == NULL) { + WARNF("(crypto) Failed to allocate memory for public key PEM"); + ret = -1; + goto cleanup; + } + } + +cleanup: + mbedtls_pk_free(&key); + if (ret != 0) { + // Clean up on error + if (priv_key_pem && *priv_key_pem) { + free(*priv_key_pem); + *priv_key_pem = NULL; + } + if (pub_key_pem && *pub_key_pem) { + free(*pub_key_pem); + *pub_key_pem = NULL; + } + } + return ret; +} +static int LuaECDSAGenerateKeyPair(lua_State *L) { + const char *curve_name = NULL; + char *priv_key_pem = NULL; + char *pub_key_pem = NULL; + + // Check if curve name is provided + if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { + curve_name = luaL_checkstring(L, 1); + } + + int ret = ECDSAGenerateKeyPair(curve_name, &priv_key_pem, &pub_key_pem); + + if (ret == 0) { + lua_pushstring(L, priv_key_pem); + lua_pushstring(L, pub_key_pem); + free(priv_key_pem); + free(pub_key_pem); + return 2; + } else { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } +} + +// Sign a message using an ECDSA private key in PEM format +static int ECDSASign(const char *priv_key_pem, const char *message, + mbedtls_md_type_t hash_alg, unsigned char **signature, + size_t *sig_len) { + mbedtls_pk_context key; + unsigned char hash[64]; // Max hash size (SHA-512) + size_t hash_size; + int ret; + + *signature = NULL; + *sig_len = 0; + + if (!priv_key_pem) { + WARNF("(crypto) Private key is NULL"); + return -1; + } + + // Get the length of the PEM string (excluding null terminator) + size_t key_len = strlen(priv_key_pem); + if (key_len == 0) { + WARNF("(crypto) Private key is empty"); + return -1; + } + + // Get hash size for the selected algorithm + hash_size = get_hash_size_from_md_type(hash_alg); + + mbedtls_pk_init(&key); + + // Parse the private key from PEM directly without creating a copy + ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem, + key_len + 1, NULL, 0); + + if (ret != 0) { + WARNF("(crypto) Failed to parse private key: -0x%04x", -ret); + goto cleanup; + } + + // Compute hash of the message using the specified algorithm + ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), + hash, sizeof(hash)); + if (ret != 0) { + WARNF("(crypto) Failed to compute message hash"); + goto cleanup; + } + + // Allocate memory for signature (max size for ECDSA) + *signature = malloc(MBEDTLS_ECDSA_MAX_LEN); + if (*signature == NULL) { + WARNF("(crypto) Failed to allocate memory for signature"); + ret = -1; + goto cleanup; + } + + // Sign the hash using GenerateRandom + ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, *signature, sig_len, + GenerateRandom, 0); + + if (ret != 0) { + WARNF("(crypto) Failed to sign message: -0x%04x", -ret); + free(*signature); + *signature = NULL; + *sig_len = 0; + goto cleanup; + } + +cleanup: + mbedtls_pk_free(&key); + return ret; +} // Lua binding for signing a message +static int LuaECDSASign(lua_State *L) { + // Correct order: priv_key, message, hash_name (default sha256) + const char *priv_key_pem = luaL_checkstring(L, 1); + const char *message = luaL_checkstring(L, 2); + const char *hash_name = luaL_optstring(L, 3, "sha256"); + mbedtls_md_type_t hash_algo; + + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return -1; + } + } + + unsigned char *signature = NULL; + size_t sig_len = 0; + + int ret = ECDSASign(priv_key_pem, message, hash_algo, &signature, &sig_len); + + if (ret == 0) { + lua_pushlstring(L, (const char *)signature, sig_len); + free(signature); + } else { + lua_pushnil(L); + } + + return 1; +} + +// Verify a signature using an ECDSA public key in PEM format +static int ECDSAVerify(const char *pub_key_pem, const char *message, + const unsigned char *signature, size_t sig_len, + mbedtls_md_type_t hash_alg) { + mbedtls_pk_context key; + unsigned char hash[64]; // Max hash size (SHA-512) + size_t hash_size; + int ret; + + if (!pub_key_pem) { + WARNF("(crypto) Public key is NULL"); + return -1; + } + + // Get the length of the PEM string (excluding null terminator) + size_t key_len = strlen(pub_key_pem); + if (key_len == 0) { + WARNF("(crypto) Public key is empty"); + return -1; + } + + // Get hash size for the selected algorithm + hash_size = get_hash_size_from_md_type(hash_alg); + + mbedtls_pk_init(&key); + + // Parse the public key from PEM + ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pub_key_pem, + key_len + 1); + if (ret != 0) { + WARNF("(crypto) Failed to parse public key: -0x%04x", -ret); + goto cleanup; + } + + // Compute hash of the message using the specified algorithm + ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), + hash, sizeof(hash)); + if (ret != 0) { + WARNF("(crypto) Failed to compute message hash"); + goto cleanup; + } + + // Verify the signature + ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, signature, sig_len); + if (ret != 0) { + WARNF("(crypto) Signature verification failed: -0x%04x", -ret); + goto cleanup; + } + +cleanup: + mbedtls_pk_free(&key); + return ret; +} +static int LuaECDSAVerify(lua_State *L) { + // Correct order: pub_key, message, signature, hash_name (default sha256) + const char *pub_key_pem = luaL_checkstring(L, 1); + const char *message = luaL_checkstring(L, 2); + size_t sig_len; + const unsigned char *signature = + (const unsigned char *)luaL_checklstring(L, 3, &sig_len); + const char *hash_name = luaL_optstring(L, 4, "sha256"); + mbedtls_md_type_t hash_algo; + + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + lua_pushboolean(L, false); + } + } + + int ret = ECDSAVerify(pub_key_pem, message, signature, sig_len, hash_algo); + + lua_pushboolean(L, ret == 0); + return 1; +} + +// AES + +// AES key generation helper +static int LuaAesGenerateKey(lua_State *L) { + int keybits = 128; + if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { + keybits = luaL_checkinteger(L, 1); + } + int keylen = keybits / 8; + if ((keybits != 128 && keybits != 192 && keybits != 256) || + (keylen != 16 && keylen != 24 && keylen != 32)) { + lua_pushnil(L); + lua_pushstring(L, "AES key length must be 128, 192, or 256 bits"); + return 2; + } + unsigned char key[32]; + // Generate random key + if (GenerateRandom(NULL, key, keylen) != 0) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate random key"); + return 2; + } + lua_pushlstring(L, (const char *)key, keylen); + return 1; +} + +// Helper to get string field from options table +typedef struct { + const char *mode; + const unsigned char *iv; + size_t ivlen; + const unsigned char *tag; + size_t taglen; + const unsigned char *aad; + size_t aadlen; +} aes_options_t; + +static void parse_aes_options(lua_State *L, int options_idx, + aes_options_t *opts) { + opts->mode = NULL; + opts->iv = NULL; + opts->ivlen = 0; + opts->tag = NULL; + opts->taglen = 0; + opts->aad = NULL; + opts->aadlen = 0; + + int mode_field_found = 0; + + if (lua_istable(L, options_idx)) { + // Get mode + lua_getfield(L, options_idx, "mode"); + if (!lua_isnil(L, -1)) { + mode_field_found = 1; + const char *mode = lua_tostring(L, -1); + if (mode && + (strcasecmp(mode, "cbc") == 0 || strcasecmp(mode, "gcm") == 0 || + strcasecmp(mode, "ctr") == 0)) { + opts->mode = mode; + } else { + opts->mode = NULL; // Invalid mode + } + } + lua_pop(L, 1); + + // Get IV + lua_getfield(L, options_idx, "iv"); + if (lua_isstring(L, -1)) { + size_t ivlen; + opts->iv = (const unsigned char *)lua_tolstring(L, -1, &ivlen); + opts->ivlen = ivlen; + } + lua_pop(L, 1); + + // Get tag (for GCM) + lua_getfield(L, options_idx, "tag"); + if (lua_isstring(L, -1)) { + size_t taglen; + opts->tag = (const unsigned char *)lua_tolstring(L, -1, &taglen); + opts->taglen = taglen; + } + lua_pop(L, 1); + + // Get aad (for GCM) + lua_getfield(L, options_idx, "aad"); + if (lua_isstring(L, -1)) { + size_t aadlen; + opts->aad = (const unsigned char *)lua_tolstring(L, -1, &aadlen); + opts->aadlen = aadlen; + } + lua_pop(L, 1); + } + + // Only default to cbc if mode field was not found at all + if (!mode_field_found) { + opts->mode = "cbc"; + } +} +// AES encryption supporting CBC, GCM, and CTR modes +static int LuaAesEncrypt(lua_State *L) { + // Args: key, plaintext, options table + size_t keylen, ptlen; + + // Get parameters from Lua + // Ensure key is a string + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure plaintext is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + // Ensure options is a table or nil + if (!lua_istable(L, 3) && !lua_isnil(L, 3)) { + lua_pushnil(L); + lua_pushstring(L, "Options must be a table or nil"); + return 2; + } + + const unsigned char *key = + (const unsigned char *)luaL_checklstring(L, 1, &keylen); + const unsigned char *plaintext = + (const unsigned char *)luaL_checklstring(L, 2, &ptlen); + int options_idx = 3; + aes_options_t opts; + parse_aes_options(L, options_idx, &opts); + const char *mode = opts.mode; + if (!mode) { + lua_pushnil(L); + lua_pushstring(L, + "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } + const unsigned char *iv = opts.iv; + size_t ivlen = opts.ivlen; + unsigned char *gen_iv = NULL; + int iv_was_generated = 0; + int ret = 0; + unsigned char *output = NULL; + int is_gcm = 0, is_ctr = 0, is_cbc = 0; + if (strcasecmp(mode, "cbc") == 0) { + is_cbc = 1; + } else if (strcasecmp(mode, "gcm") == 0) { + is_gcm = 1; + } else if (strcasecmp(mode, "ctr") == 0) { + is_ctr = 1; + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } + + // If IV is not provided, auto-generate + if (!iv) { + if (is_gcm) { + ivlen = 12; + } else { + ivlen = 16; + } + gen_iv = malloc(ivlen); + if (!gen_iv) { + lua_pushnil(L); + lua_pushstring(L, "Failed to allocate IV"); + return 2; + } + + // Generate random IV + if (GenerateRandom(NULL, gen_iv, ivlen) != 0) { + free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "Failed to generate random IV"); + return 2; + } + iv = gen_iv; + iv_was_generated = 1; + } + + // Validate IV/nonce length + if (is_cbc || is_ctr) { + if (opts.iv && opts.ivlen != 16) { + if (iv_was_generated) + free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) { + if (iv_was_generated) + free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; + } + } + + if (is_cbc) { + // PKCS7 padding + size_t block_size = 16; + size_t padlen = block_size - (ptlen % block_size); + size_t ctlen = ptlen + padlen; + unsigned char *input = malloc(ctlen); + if (!input) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + memcpy(input, plaintext, ptlen); + memset(input + ptlen, (unsigned char)padlen, padlen); + output = malloc(ctlen); + if (!output) { + free(input); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(input); + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char iv_copy[16]; + memcpy(iv_copy, iv, 16); + ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy, + input, output); + mbedtls_aes_free(&aes); + free(input); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CBC encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + lua_pushlstring(L, (const char *)iv, ivlen); + free(output); + if (iv_was_generated) + free(gen_iv); + return 2; + } else if (is_ctr) { + // CTR mode: no padding + output = malloc(ptlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char nonce_counter[16]; + unsigned char stream_block[16]; + size_t nc_off = 0; + memcpy(nonce_counter, iv, 16); + memset(stream_block, 0, 16); + ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter, + stream_block, plaintext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CTR encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ptlen); + lua_pushlstring(L, (const char *)iv, ivlen); + free(output); + if (iv_was_generated) + free(gen_iv); + return 2; + } else if (is_gcm) { + // GCM mode: authenticated encryption + size_t taglen = 16; + unsigned char tag[16]; + output = malloc(ptlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_gcm_context gcm; + mbedtls_gcm_init(&gcm); + ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_gcm_free(&gcm); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES GCM key"); + return 2; + } + ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen, + NULL, 0, plaintext, output, taglen, tag); + mbedtls_gcm_free(&gcm); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES GCM encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ptlen); + lua_pushlstring(L, (const char *)iv, ivlen); + lua_pushlstring(L, (const char *)tag, taglen); + free(output); + if (iv_was_generated) + free(gen_iv); + return 3; + } + lua_pushnil(L); + lua_pushstring(L, "Internal error in AES encrypt"); + return 2; +} + +// AES decryption supporting CBC, GCM, and CTR modes +static int LuaAesDecrypt(lua_State *L) { + // Args: key, ciphertext, options table + size_t keylen, ctlen; + // Ensure key is a string + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + + // Ensure ciphertext is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext must be a string"); + return 2; + } + const unsigned char *key = + (const unsigned char *)luaL_checklstring(L, 1, &keylen); + const unsigned char *ciphertext = + (const unsigned char *)luaL_checklstring(L, 2, &ctlen); + int options_idx = 3; + aes_options_t opts; + parse_aes_options(L, options_idx, &opts); + const char *mode = opts.mode; + if (!mode) { + lua_pushnil(L); + lua_pushstring(L, + "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } + const unsigned char *iv = opts.iv; + size_t ivlen = opts.ivlen; + const unsigned char *tag = opts.tag; + size_t taglen = opts.taglen; + const unsigned char *aad = opts.aad; + size_t aadlen = opts.aadlen; + int is_gcm = 0, is_ctr = 0, is_cbc = 0; + if (strcasecmp(mode, "cbc") == 0) { + is_cbc = 1; + } else if (strcasecmp(mode, "gcm") == 0) { + is_gcm = 1; + } else if (strcasecmp(mode, "ctr") == 0) { + is_ctr = 1; + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } + // Validate key length (16, 24, 32 bytes) + if (keylen != 16 && keylen != 24 && keylen != 32) { + lua_pushnil(L); + lua_pushstring(L, "AES key must be 16, 24, or 32 bytes"); + return 2; + } + // Validate IV/nonce length + if (is_cbc || is_ctr) { + if (ivlen != 16) { + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (ivlen < 12 || ivlen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; + } + } + + // GCM: require tag and optional AAD + if (is_gcm) { + if (!tag || taglen < 12 || taglen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM tag must be 12-16 bytes"); + return 2; + } + } + + int ret = 0; + unsigned char *output = NULL; + + if (is_cbc) { + // Ciphertext must be a multiple of block size + if (ctlen == 0 || (ctlen % 16) != 0) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext length must be a multiple of 16"); + return 2; + } + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_dec(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES decryption key"); + return 2; + } + unsigned char iv_copy[16]; + memcpy(iv_copy, iv, 16); + ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy, + ciphertext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CBC decryption failed"); + return 2; + } + // PKCS7 unpadding + if (ctlen == 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Decrypted data is empty"); + return 2; + } + unsigned char pad = output[ctlen - 1]; + if (pad == 0 || pad > 16) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Invalid PKCS7 padding"); + return 2; + } + for (size_t i = 0; i < pad; ++i) { + if (output[ctlen - 1 - i] != pad) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Invalid PKCS7 padding"); + return 2; + } + } + size_t ptlen = ctlen - pad; + lua_pushlstring(L, (const char *)output, ptlen); + free(output); + return 1; + } else if (is_ctr) { + // CTR mode: no padding + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char nonce_counter[16]; + unsigned char stream_block[16]; + size_t nc_off = 0; + memcpy(nonce_counter, iv, 16); + memset(stream_block, 0, 16); + ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter, + stream_block, ciphertext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CTR decryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + free(output); + return 1; + } else if (is_gcm) { + // GCM mode: authenticated decryption + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_gcm_context gcm; + mbedtls_gcm_init(&gcm); + ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_gcm_free(&gcm); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES GCM key"); + return 2; + } + ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag, + taglen, ciphertext, output); + mbedtls_gcm_free(&gcm); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES GCM decryption failed or authentication failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + free(output); + return 1; + } + lua_pushnil(L); + lua_pushstring(L, "Internal error in AES decrypt"); + return 2; +} + +// JWK functions +// Helper: convert base64url to standard base64 (in-place) +static void base64url_to_base64(char *input) { + if (!input) + return; + // Replace URL-safe characters with standard base64 characters + for (char *p = input; *p; ++p) { + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; + } + // Add padding if necessary + size_t len = strlen(input); + int mod = len % 4; + if (mod) { + for (int i = 0; i < 4 - mod; ++i) + input[len + i] = '='; + input[len + 4 - mod] = '\0'; + } +} + +// Helper: convert standard base64 to base64url (in-place) +static void base64_to_base64url(char *input) { + if (!input) + return; + size_t len = strlen(input); + // Replace standard base64 characters with URL-safe characters + for (size_t i = 0; i < len; i++) { + if (input[i] == '+') + input[i] = '-'; + else if (input[i] == '/') + input[i] = '_'; + } + // Remove padding + while (len > 0 && input[len - 1] == '=') { + input[--len] = '\0'; + } +} + +// Helper: encode binary to base64url +static char *b64url_encode(const unsigned char *data, size_t len) { + size_t b64_len; + mbedtls_base64_encode(NULL, 0, &b64_len, data, len); + char *b64 = malloc(b64_len + 1); + if (!b64) + return NULL; + mbedtls_base64_encode((unsigned char *)b64, b64_len, &b64_len, data, len); + b64[b64_len] = '\0'; + base64_to_base64url(b64); + return b64; +} + +// Convert JWK key to PEM (string) format +static int LuaConvertJwkToPem(lua_State *L) { + luaL_checktype(L, 1, LUA_TTABLE); + const char *kty; + + if (lua_isnoneornil(L, 1) || lua_type(L, 1) != LUA_TTABLE) { + lua_pushnil(L); + lua_pushstring(L, "Expected a JWK table, got nil"); + return 2; + } + + lua_getfield(L, 1, "kty"); + kty = lua_tostring(L, -1); + if (!kty) { + lua_pushnil(L); + lua_pushstring(L, "Missing 'kty' in JWK"); + return 2; + } + + int ret = -1; + char *pem = NULL; + mbedtls_pk_context pk; + mbedtls_pk_init(&pk); + + if (strcasecmp(kty, "RSA") == 0) { + // RSA JWK: n, e (base64url), optionally d, p, q, dp, dq, qi + lua_getfield(L, 1, "n"); + lua_getfield(L, 1, "e"); + const char *n_b64 = lua_tostring(L, -2); + const char *e_b64 = lua_tostring(L, -1); + if (!n_b64 || !*n_b64) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'n' in JWK"); + return 2; + } + if (!e_b64 || !*e_b64) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'e' in JWK"); + return 2; + } + // Optional private fields + lua_getfield(L, 1, "d"); + lua_getfield(L, 1, "p"); + lua_getfield(L, 1, "q"); + lua_getfield(L, 1, "dp"); + lua_getfield(L, 1, "dq"); + lua_getfield(L, 1, "qi"); + const char *d_b64 = lua_tostring(L, -6); + const char *p_b64 = lua_tostring(L, -5); + const char *q_b64 = lua_tostring(L, -4); + const char *dp_b64 = lua_tostring(L, -3); + const char *dq_b64 = lua_tostring(L, -2); + const char *qi_b64 = lua_tostring(L, -1); + int has_private = d_b64 && *d_b64; + // Decode base64url to binary + size_t n_len, e_len; + unsigned char n_bin[1024], e_bin[16]; + char *n_b64_std = strdup(n_b64); + char *e_b64_std = strdup(e_b64); + base64url_to_base64(n_b64_std); + base64url_to_base64(e_b64_std); + if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, + (const unsigned char *)n_b64_std, + strlen(n_b64_std)) != 0 || + mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len, + (const unsigned char *)e_b64_std, + strlen(e_b64_std)) != 0) { + free(n_b64_std); + free(e_b64_std); + lua_pushnil(L); + lua_pushstring(L, "Base64 decode failed"); + return 2; + } + free(n_b64_std); + free(e_b64_std); + // Build RSA context in pk + if ((ret = mbedtls_pk_setup( + + &pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { + lua_pushnil(L); + lua_pushstring(L, "mbedtls_pk_setup failed"); + return 2; + } + mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk); + mbedtls_rsa_init(rsa, MBEDTLS_RSA_PKCS_V15, 0); + mbedtls_mpi_read_binary(&rsa->N, n_bin, n_len); + mbedtls_mpi_read_binary(&rsa->E, e_bin, e_len); + rsa->len = n_len; + if (has_private) { + // Decode and set private fields + size_t d_len, p_len, q_len, dp_len, dq_len, qi_len; + unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512], + dq_bin[512], qi_bin[512]; +// Decode all private fields (skip if NULL) +#define DECODE_B64URL(var, b64, bin, binlen) \ + if (b64 && *b64) { \ + char *b64_std = strdup(b64); \ + base64url_to_base64(b64_std); \ + mbedtls_base64_decode(bin, sizeof(bin), &binlen, \ + (const unsigned char *)b64_std, strlen(b64_std)); \ + free(b64_std); \ + } + DECODE_B64URL(d, d_b64, d_bin, d_len); + DECODE_B64URL(p, p_b64, p_bin, p_len); + DECODE_B64URL(q, q_b64, q_bin, q_len); + DECODE_B64URL(dp, dp_b64, dp_bin, dp_len); + DECODE_B64URL(dq, dq_b64, dq_bin, dq_len); + DECODE_B64URL(qi, qi_b64, qi_bin, qi_len); + mbedtls_mpi_read_binary(&rsa->D, d_bin, d_len); + mbedtls_mpi_read_binary(&rsa->P, p_bin, p_len); + mbedtls_mpi_read_binary(&rsa->Q, q_bin, q_len); + mbedtls_mpi_read_binary(&rsa->DP, dp_bin, dp_len); + mbedtls_mpi_read_binary(&rsa->DQ, dq_bin, dq_len); + mbedtls_mpi_read_binary(&rsa->QP, qi_bin, qi_len); + } + // Write PEM + unsigned char buf[4096]; + if (has_private) { + ret = mbedtls_pk_write_key_pem(&pk, buf, sizeof(buf)); + } else { + ret = mbedtls_pk_write_pubkey_pem(&pk, buf, sizeof(buf)); + } + if (ret != 0) { + mbedtls_pk_free(&pk); + lua_pushnil(L); + lua_pushstring(L, "PEM write failed"); + return 2; + } + pem = strdup((char *)buf); + mbedtls_pk_free(&pk); + lua_pushstring(L, pem); + free(pem); + return 1; + } else if (strcasecmp(kty, "EC") == 0) { + // EC JWK: crv, x, y (base64url), optionally d + lua_getfield(L, 1, "crv"); + lua_getfield(L, 1, "x"); + lua_getfield(L, 1, "y"); + lua_getfield(L, 1, "d"); + const char *crv = lua_tostring(L, -4); + const char *x_b64 = lua_tostring(L, -3); + const char *y_b64 = lua_tostring(L, -2); + const char *d_b64 = lua_tostring(L, -1); + if (!crv || !*crv) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'crv' in JWK"); + return 2; + } + if (!x_b64 || !*x_b64) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'x' in JWK"); + return 2; + } + if (!y_b64 || !*y_b64) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'y' in JWK"); + return 2; + } + int has_private = d_b64 && *d_b64; + mbedtls_ecp_group_id gid = find_curve_by_name(crv); + if (gid == MBEDTLS_ECP_DP_NONE) { + lua_pushnil(L); + lua_pushstring(L, "Unknown curve"); + return 2; + } + size_t x_len, y_len; + unsigned char x_bin[72], y_bin[72]; + char *x_b64_std = strdup(x_b64); + char *y_b64_std = strdup(y_b64); + base64url_to_base64(x_b64_std); + base64url_to_base64(y_b64_std); + if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, + (const unsigned char *)x_b64_std, + strlen(x_b64_std)) != 0 || + mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len, + (const unsigned char *)y_b64_std, + strlen(y_b64_std)) != 0) { + free(x_b64_std); + free(y_b64_std); + lua_pushnil(L); + lua_pushstring(L, "Base64 decode failed"); + return 2; + } + free(x_b64_std); + free(y_b64_std); + if ((ret = mbedtls_pk_setup( + &pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) { + lua_pushnil(L); + lua_pushstring(L, "mbedtls_pk_setup failed"); + return 2; + } + mbedtls_ecp_keypair *ec = mbedtls_pk_ec(pk); + mbedtls_ecp_keypair_init(ec); + mbedtls_ecp_group_load(&ec->grp, gid); + mbedtls_mpi_read_binary(&ec->Q.X, x_bin, x_len); + mbedtls_mpi_read_binary(&ec->Q.Y, y_bin, y_len); + mbedtls_mpi_lset(&ec->Q.Z, 1); + if (has_private) { + size_t d_len; + unsigned char d_bin[72]; + DECODE_B64URL(d, d_b64, d_bin, d_len); + mbedtls_mpi_read_binary(&ec->d, d_bin, d_len); + } + unsigned char buf[4096]; + if (has_private) { + ret = mbedtls_pk_write_key_pem(&pk, buf, sizeof(buf)); + } else { + ret = mbedtls_pk_write_pubkey_pem(&pk, buf, sizeof(buf)); + } + if (ret != 0) { + mbedtls_pk_free(&pk); + lua_pushnil(L); + lua_pushstring(L, "PEM write failed"); + return 2; + } + pem = strdup((char *)buf); + mbedtls_pk_free(&pk); + lua_pushstring(L, pem); + free(pem); + return 1; + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported kty"); + return 2; + } +} + +static int LuaConvertPemToJwk(lua_State *L) { + const char *pem_key = luaL_checkstring(L, 1); + int has_claims = 0; + if (!lua_isnoneornil(L, 2) && lua_istable(L, 2)) { + has_claims = 1; + } + + mbedtls_pk_context key; + mbedtls_pk_init(&key); + int ret; + + // Parse the PEM key + if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key, + strlen(pem_key) + 1, NULL, 0)) != 0 && + (ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key, + strlen(pem_key) + 1)) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret); + mbedtls_pk_free(&key); + return 2; + } + + lua_newtable(L); + + if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_RSA) { + lua_pushstring(L, "RSA"); + lua_setfield(L, -2, "kty"); + const mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + size_t n_len = mbedtls_mpi_size(&rsa->N); + size_t e_len = mbedtls_mpi_size(&rsa->E); + unsigned char *n = malloc(n_len); + unsigned char *e = malloc(e_len); + if (!n || !e) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(n); + free(e); + mbedtls_pk_free(&key); + return 2; + } + mbedtls_mpi_write_binary(&rsa->N, n, n_len); + mbedtls_mpi_write_binary(&rsa->E, e, e_len); + char *n_b64 = b64url_encode(n, n_len); + char *e_b64 = b64url_encode(e, e_len); + lua_pushstring(L, n_b64); + lua_setfield(L, -2, "n"); + lua_pushstring(L, e_b64); + lua_setfield(L, -2, "e"); + // If private key, add private fields + if (mbedtls_rsa_check_privkey(rsa) == 0 && rsa->D.p) { + size_t d_len = mbedtls_mpi_size(&rsa->D); + size_t p_len = mbedtls_mpi_size(&rsa->P); + size_t q_len = mbedtls_mpi_size(&rsa->Q); + size_t dp_len = mbedtls_mpi_size(&rsa->DP); + size_t dq_len = mbedtls_mpi_size(&rsa->DQ); + size_t qi_len = mbedtls_mpi_size(&rsa->QP); + unsigned char *d = malloc(d_len), *p = malloc(p_len), *q = malloc(q_len), + *dp = malloc(dp_len), *dq = malloc(dq_len), + *qi = malloc(qi_len); + if (!d || !p || !q || !dp || !dq || !qi) { + free(d); + free(p); + free(q); + free(dp); + free(dq); + free(qi); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + mbedtls_pk_free(&key); + return 2; + } + mbedtls_mpi_write_binary(&rsa->D, d, d_len); + mbedtls_mpi_write_binary(&rsa->P, p, p_len); + mbedtls_mpi_write_binary(&rsa->Q, q, q_len); + mbedtls_mpi_write_binary(&rsa->DP, dp, dp_len); + mbedtls_mpi_write_binary(&rsa->DQ, dq, dq_len); + mbedtls_mpi_write_binary(&rsa->QP, qi, qi_len); + char *d_b64 = b64url_encode(d, d_len); + char *p_b64 = b64url_encode(p, p_len); + char *q_b64 = b64url_encode(q, q_len); + char *dp_b64 = b64url_encode(dp, dp_len); + char *dq_b64 = b64url_encode(dq, dq_len); + char *qi_b64 = b64url_encode(qi, qi_len); + lua_pushstring(L, d_b64); + lua_setfield(L, -2, "d"); + lua_pushstring(L, p_b64); + lua_setfield(L, -2, "p"); + lua_pushstring(L, q_b64); + lua_setfield(L, -2, "q"); + lua_pushstring(L, dp_b64); + lua_setfield(L, -2, "dp"); + lua_pushstring(L, dq_b64); + lua_setfield(L, -2, "dq"); + lua_pushstring(L, qi_b64); + lua_setfield(L, -2, "qi"); + free(d); + free(p); + free(q); + free(dp); + free(dq); + free(qi); + free(d_b64); + free(p_b64); + free(q_b64); + free(dp_b64); + free(dq_b64); + free(qi_b64); + } + free(n); + free(e); + free(n_b64); + free(e_b64); + } else if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_ECKEY) { + // Handle ECDSA keys + const mbedtls_ecp_keypair *ec = mbedtls_pk_ec(key); + const mbedtls_ecp_point *Q = &ec->Q; + size_t x_len = (ec->grp.pbits + 7) / 8; + size_t y_len = (ec->grp.pbits + 7) / 8; + unsigned char *x = malloc(x_len); + unsigned char *y = malloc(y_len); + if (!x || !y) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(x); + free(y); + mbedtls_pk_free(&key); + return 2; + } + mbedtls_mpi_write_binary(&Q->X, x, x_len); + mbedtls_mpi_write_binary(&Q->Y, y, y_len); + char *x_b64 = b64url_encode(x, x_len); + char *y_b64 = b64url_encode(y, y_len); + // Set kty and crv for EC keys + lua_pushstring(L, "EC"); + lua_setfield(L, -2, "kty"); + const mbedtls_ecp_curve_info *curve_info = + mbedtls_ecp_curve_info_from_grp_id(ec->grp.id); + if (curve_info && curve_info->name) { + lua_pushstring(L, curve_info->name); + lua_setfield(L, -2, "crv"); + } else { + lua_pushstring(L, "unknown"); + lua_setfield(L, -2, "crv"); + } + lua_pushstring(L, x_b64); + lua_setfield(L, -2, "x"); + lua_pushstring(L, y_b64); + lua_setfield(L, -2, "y"); + // If private key, add 'd' + if (mbedtls_ecp_check_privkey(&ec->grp, &ec->d) == 0 && ec->d.p) { + size_t d_len = mbedtls_mpi_size(&ec->d); + unsigned char *d = malloc(d_len); + if (!d) { + free(x); + free(y); + free(x_b64); + free(y_b64); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + mbedtls_pk_free(&key); + return 2; + } + mbedtls_mpi_write_binary(&ec->d, d, d_len); + char *d_b64 = b64url_encode(d, d_len); + lua_pushstring(L, d_b64); + lua_setfield(L, -2, "d"); + free(d); + free(d_b64); + } + free(x); + free(y); + free(x_b64); + free(y_b64); + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported key type"); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_pk_free(&key); + + // Merge additional claims if provided and compatible with RFC7517 + if (has_claims) { + static const char *allowed[] = {"kty", "use", "sig", "key_ops", + "alg", "kid", "x5u", "x5c", + "x5t", "x5t#S256", NULL}; + lua_pushnil(L); // first key + while (lua_next(L, 2) != 0) { + const char *k = lua_tostring(L, -2); + int allowed_key = 0; + for (int i = 0; allowed[i]; ++i) { + if (strcmp(k, allowed[i]) == 0) { + allowed_key = 1; + break; + } + } + if (allowed_key) { + lua_pushvalue(L, -2); + lua_insert(L, -2); + lua_settable(L, -4); + } else { + lua_pop(L, 1); + } + } + } + + return 1; +} + +// CSR creation Function +static int LuaGenerateCSR(lua_State *L) { + const char *key_pem = luaL_checkstring(L, 1); + const char *subject_name; + const char *san_list = luaL_optstring(L, 3, NULL); + + if (lua_isnoneornil(L, 2)) { + subject_name = ""; + } else { + subject_name = luaL_checkstring(L, 2); + } + + if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') { + lua_pushnil(L); + lua_pushstring(L, "Subject name or SANs are required"); + return 2; + } + mbedtls_pk_context key; + mbedtls_x509write_csr req; + char buf[4096]; + int ret; + + mbedtls_pk_init(&key); + mbedtls_x509write_csr_init(&req); + + if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem, + strlen(key_pem) + 1, NULL, 0)) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to parse key: %d", ret); + return 2; + } + + mbedtls_x509write_csr_set_subject_name(&req, subject_name); + mbedtls_x509write_csr_set_key(&req, &key); + mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256); + + if (san_list) { + if ((ret = mbedtls_x509write_csr_set_extension( + &req, MBEDTLS_OID_SUBJECT_ALT_NAME, + MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), + (const unsigned char *)san_list, strlen(san_list))) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to set SANs: %d", ret); + return 2; + } + } + + if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf), + NULL, NULL)) < 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to write CSR: %d", ret); + return 2; + } + + lua_pushstring(L, buf); + + mbedtls_pk_free(&key); + mbedtls_x509write_csr_free(&req); + + return 1; +} + +// LuaCrypto compatible API +static int LuaCryptoSign(lua_State *L) { + // Type of signature (e.g., "rsa", "ecdsa", "rsa-pss") + const char *dtype = luaL_checkstring(L, 1); + // Remove the first argument (key type or cipher type) before dispatching + lua_remove(L, 1); + + if (strcasecmp(dtype, "rsa") == 0) { + return LuaRSASign(L); + } else if (strcasecmp(dtype, "rsa-pss") == 0 || + strcasecmp(dtype, "rsapss") == 0) { + return LuaRSAPSSSign(L); + } else if (strcasecmp(dtype, "ecdsa") == 0) { + return LuaECDSASign(L); + } else { + return luaL_error(L, "Unsupported signature type: %s", dtype); + } +} + +static int LuaCryptoVerify(lua_State *L) { + // Type of signature (e.g., "rsa", "ecdsa", "rsa-pss") + const char *dtype = luaL_checkstring(L, 1); + // Remove the first argument (key type or cipher type) before dispatching + lua_remove(L, 1); + + if (strcasecmp(dtype, "rsa") == 0) { + return LuaRSAVerify(L); + } else if (strcasecmp(dtype, "rsa-pss") == 0 || + strcasecmp(dtype, "rsapss") == 0) { + return LuaRSAPSSVerify(L); + } else if (strcasecmp(dtype, "ecdsa") == 0) { + return LuaECDSAVerify(L); + } else { + return luaL_error(L, "Unsupported signature type: %s", dtype); + } +} + +static int LuaCryptoEncrypt(lua_State *L) { + // Args: cipher_type, key, msg, options table + const char *cipher = luaL_checkstring(L, 1); + // Remove cipher_type from stack, so key is at 1, msg at 2, options at 3 + lua_remove(L, 1); + + if (strcasecmp(cipher, "rsa") == 0) { + return LuaRSAEncrypt(L); + } else if (strcasecmp(cipher, "aes") == 0) { + return LuaAesEncrypt(L); + } else { + return luaL_error(L, "Unsupported cipher type: %s", cipher); + } +} + +static int LuaCryptoDecrypt(lua_State *L) { + // Args: cipher_type, key, ciphertext, options table + const char *cipher = luaL_checkstring(L, 1); + // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3 + lua_remove(L, 1); + + if (strcasecmp(cipher, "rsa") == 0) { + return LuaRSADecrypt(L); + } else if (strcasecmp(cipher, "aes") == 0) { + return LuaAesDecrypt(L); + } else { + return luaL_error(L, "Unsupported cipher type: %s", cipher); + } +} + +static int LuaCryptoGenerateKeyPair(lua_State *L) { + // If the first argument is a number, treat it as RSA key length + if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) { + // Call LuaRSAGenerateKeyPair with the number as the key length + return LuaRSAGenerateKeyPair(L); + } + // Otherwise, get the key type from the first argument, default to "rsa" + const char *type = luaL_optstring(L, 1, "rsa"); + lua_remove(L, 1); + + if (strcasecmp(type, "rsa") == 0) { + return LuaRSAGenerateKeyPair(L); + } else if (strcasecmp(type, "ecdsa") == 0) { + return LuaECDSAGenerateKeyPair(L); + } else if (strcasecmp(type, "aes") == 0) { + return LuaAesGenerateKey(L); + } else { + return luaL_error(L, "Unsupported key type: %s", type); + } +} + +// Returns a Lua table array of supported digests and ciphers (strings), +// depending on the type argument: +// "ciphers" - returns list of ciphers supported by crypto.encrypt and +// crypto.decrypt "digests" - returns list of digests in supported_digests +// "curves" - returns list of curves in supported_curves +// If no argument is provided, returns a table with all three types +static int LuaList(lua_State *L) { + // Create a new table to hold the result + lua_newtable(L); + + // No argument provided - return all types in a structured table + if (lua_isnoneornil(L, 1)) { + // Create subtable for digests + lua_pushstring(L, "digests"); + lua_newtable(L); + const digest_map_t *digest = supported_digests; + int i = 1; + while (digest->name != NULL) { + lua_pushstring(L, digest->name); + lua_rawseti(L, -2, i++); + digest++; + } + lua_settable(L, -3); + + // Create subtable for curves + lua_pushstring(L, "curves"); + lua_newtable(L); + const curve_map_t *curve = supported_curves; + i = 1; + while (curve->name != NULL) { + lua_pushstring(L, curve->name); + lua_rawseti(L, -2, i++); + curve++; + } + lua_settable(L, -3); + + // Create subtable for ciphers + lua_pushstring(L, "ciphers"); + lua_newtable(L); + const ciphers_map_t *cipher = supported_ciphers; + i = 1; + while (cipher->name != NULL) { + lua_pushstring(L, cipher->name); + lua_rawseti(L, -2, i++); + cipher++; + } + lua_settable(L, -3); + + return 1; + } + + // Argument provided - handle specific type + const char *type = luaL_checkstring(L, 1); + + if (strcasecmp(type, "curves") == 0) { + // List all available curves + const curve_map_t *curve = supported_curves; + int i = 1; + + while (curve->name != NULL) { + lua_pushstring(L, curve->name); + lua_rawseti(L, -2, i++); + curve++; + } + } else if (strcasecmp(type, "digests") == 0) { + // List all available digests + const digest_map_t *digest = supported_digests; + int i = 1; + + while (digest->name != NULL) { + lua_pushstring(L, digest->name); + lua_rawseti(L, -2, i++); + digest++; + } + } else if (strcasecmp(type, "ciphers") == 0) { + // List all available ciphers + const ciphers_map_t *cipher = supported_ciphers; + int i = 1; + + while (cipher->name != NULL) { + lua_pushstring(L, cipher->name); + lua_rawseti(L, -2, i++); + cipher++; + } + } else { + // Invalid type, return empty table + lua_pushstring(L, "Invalid type. Use 'ciphers', 'digests', or 'curves'"); + lua_setfield(L, -2, "error"); + } + + return 1; // Return the table +} + +static const luaL_Reg kLuaCrypto[] = { + {"sign", LuaCryptoSign}, // + {"verify", LuaCryptoVerify}, // + {"encrypt", LuaCryptoEncrypt}, // + {"decrypt", LuaCryptoDecrypt}, // + {"generateKeyPair", LuaCryptoGenerateKeyPair}, // + {"convertJwkToPem", LuaConvertJwkToPem}, // + {"convertPemToJwk", LuaConvertPemToJwk}, // + {"generateCsr", LuaGenerateCSR}, // + {"list", LuaList}, // + {0}, // +}; + +int LuaCrypto(lua_State *L) { + luaL_newlib(L, kLuaCrypto); + return 1; +} diff --git a/tool/net/lcrypto.h b/tool/net/lcrypto.h new file mode 100644 index 00000000000..0e1cac87239 --- /dev/null +++ b/tool/net/lcrypto.h @@ -0,0 +1,9 @@ +#ifndef COSMOPOLITAN_TOOL_NET_LCRYPTO_H_ +#define COSMOPOLITAN_TOOL_NET_LCRYPTO_H_ +#include "third_party/lua/lauxlib.h" +COSMOPOLITAN_C_START_ + +int LuaCrypto(lua_State *L); + +COSMOPOLITAN_C_END_ +#endif /* COSMOPOLITAN_TOOL_NET_LCRYPTO_H_ */ diff --git a/tool/net/lfuncs.h b/tool/net/lfuncs.h index 7bc3fc748ff..4fcbd0fa51e 100644 --- a/tool/net/lfuncs.h +++ b/tool/net/lfuncs.h @@ -8,6 +8,7 @@ int LuaMaxmind(lua_State *); int LuaRe(lua_State *); int luaopen_argon2(lua_State *); int luaopen_lsqlite3(lua_State *); +int LuaCrypto(lua_State *); int LuaBarf(lua_State *); int LuaBenchmark(lua_State *); diff --git a/tool/net/redbean.c b/tool/net/redbean.c index 93816d1aa43..c7b9de60134 100644 --- a/tool/net/redbean.c +++ b/tool/net/redbean.c @@ -5426,6 +5426,7 @@ static const luaL_Reg kLuaLibs[] = { {"path", LuaPath}, // {"re", LuaRe}, // {"unix", LuaUnix}, // + {"crypto", LuaCrypto}, // }; static void LuaSetArgv(lua_State *L) {