DEV Community

Cover image for How I Used TPM for Key Encryption in Rust (Using Windows APIs)
tsuruko
tsuruko

Posted on

How I Used TPM for Key Encryption in Rust (Using Windows APIs)

I implemented TPM-based key encryption (key wrapping) on Windows using the windows-sys crate.

At first, I tried using the SRK (Storage Root Key) for key wrapping with the tss-esapi crate, but since that crate is designed for Linux, using it on Windows would require a more complex setup.
For this reason, I chose the windows-sys crate for the implementation.

Table of Contents


Cargo.toml

[dependencies]
windows-sys = { version = "0.61", features = [
    "Win32_Security_Cryptography",
    "Win32_Foundation",
] }
zeroize = "1.8"
# Only needed by the `main` demo (`rand::rng()` is the rand 0.9 API).
rand = "0.9"
Enter fullscreen mode Exit fullscreen mode

Implementation

windows-sys exposes raw FFI bindings to the Windows API, so every call into it must be made inside an unsafe block.

use std::{
    ffi::{OsStr, c_void},
    os::windows::ffi::OsStrExt,
    ptr,
};

use windows_sys::{
    core::HRESULT,
    Win32::{
        Foundation::NTE_BAD_KEYSET, 
        Security::Cryptography::*,
    },
};
use zeroize::Zeroize;

/// Name under which the RSA wrapping key is persisted in (and reopened from)
/// the Microsoft Platform Crypto Provider; converted to NUL-terminated UTF-16
/// before being passed to NCrypt.
const KEY_NAME: &str = "RSA_KEY";

/// Opens the Microsoft Platform Crypto Provider (the TPM-backed key storage
/// provider).
///
/// On success the caller owns the returned handle and must release it with
/// `NCryptFreeObject`. On failure no handle is produced, and the non-zero
/// status from `NCryptOpenStorageProvider` is returned.
fn open_provider() -> Result<NCRYPT_PROV_HANDLE, HRESULT> {
    let mut provider: NCRYPT_PROV_HANDLE = 0;

    // SAFETY: `provider` is a live local that receives the handle, the
    // provider name is a NUL-terminated constant, and dwflags is reserved (0).
    let status =
        unsafe { NCryptOpenStorageProvider(&mut provider, MS_PLATFORM_CRYPTO_PROVIDER, 0) };

    match status {
        0 => Ok(provider),
        failure => Err(failure),
    }
}

/// Builds the OAEP padding parameters shared by wrapping and unwrapping:
/// SHA-256 as the OAEP hash and no label.
fn create_padding_info() -> BCRYPT_OAEP_PADDING_INFO {
    let oaep_hash = BCRYPT_SHA256_ALGORITHM;
    BCRYPT_OAEP_PADDING_INFO {
        cbLabel: 0,
        pbLabel: ptr::null_mut(),
        pszAlgId: oaep_hash,
    }
}

/// Encodes `s` as UTF-16 code units followed by a terminating 0, which is the
/// form the wide-string (`PCWSTR`) NCrypt parameters expect.
fn to_utf16(s: &str) -> Vec<u16> {
    OsStr::new(s)
        .encode_wide()
        .chain(std::iter::once(0))
        .collect()
}

// The target key is assumed to be 32 bytes long (e.g. an AES-256 key),
// which is the generally recommended size.

// The status is an i32 (HRESULT) value,
// and 0 means the function was successful.

/// Wraps a 32-byte key (e.g. an AES-256 key) by RSA-OAEP/SHA-256 encrypting it
/// with the persisted TPM key named `KEY_NAME`, creating that TPM key on first
/// use.
///
/// `target_key` is taken by value. Because `[u8; 32]` is `Copy`, the caller
/// keeps its own copy, which this function cannot wipe. Only this function's
/// local copy is zeroized, and that happens on every return path, including
/// errors.
///
/// # Errors
///
/// Returns the non-zero status (`HRESULT`) of the first NCrypt call that fails.
fn wrap_key(mut target_key: [u8; 32]) -> Result<Vec<u8>, HRESULT> {
    /// Opens the TPM key (creating and finalizing it if it isn't registered
    /// yet) and encrypts `plaintext` with it. Every handle acquired here is
    /// released before returning.
    fn encrypt_with_tpm_key(plaintext: &[u8; 32]) -> Result<Vec<u8>, HRESULT> {
        let hprov = open_provider()?;

        let mut hkey: NCRYPT_KEY_HANDLE = 0;
        let key_name = to_utf16(KEY_NAME);

        // SAFETY: every pointer handed to NCrypt refers to a local that
        // outlives the call (`hkey`, `key_name`, `plaintext`, `padding_info`,
        // `wrapped_key`, `size`). `key_name` is NUL-terminated, every length
        // passed matches its buffer, and each handle is freed exactly once on
        // every path below.
        unsafe {
            // get the key that's registered in the TPM
            let status = NCryptOpenKey(hprov, &mut hkey, key_name.as_ptr(), 0, 0);
            match status {
                0 => {}
                NTE_BAD_KEYSET => {
                    // the key name isn't registered in the TPM yet:
                    // create a persisted RSA key with that name
                    let status = NCryptCreatePersistedKey(
                        hprov,
                        &mut hkey,
                        BCRYPT_RSA_ALGORITHM,
                        key_name.as_ptr(),
                        0,
                        0,
                    );
                    if status != 0 {
                        // no key handle was produced, so only the provider is freed
                        NCryptFreeObject(hprov);
                        return Err(status);
                    }

                    // the new key must be finalized before it can be used
                    let status = NCryptFinalizeKey(hkey, 0);
                    if status != 0 {
                        NCryptFreeObject(hkey);
                        NCryptFreeObject(hprov);
                        return Err(status);
                    }
                }
                _ => {
                    NCryptFreeObject(hprov);
                    return Err(status);
                }
            }

            let padding_info = create_padding_info();
            let padding_ptr = &padding_info as *const _ as *const c_void;
            let input_len = plaintext.len() as u32; // always 32

            // first call: query the ciphertext size (null output buffer, size 0)
            let mut size: u32 = 0;
            let status = NCryptEncrypt(
                hkey,
                plaintext.as_ptr(),
                input_len,
                padding_ptr,
                ptr::null_mut(),
                0,
                &mut size,
                NCRYPT_PAD_OAEP_FLAG,
            );
            if status != 0 {
                NCryptFreeObject(hkey);
                NCryptFreeObject(hprov);
                return Err(status);
            }

            // second call: wrap the key into a buffer of the reported size;
            // `size` is updated to the number of bytes actually written
            let mut wrapped_key = vec![0u8; size as usize];
            let status = NCryptEncrypt(
                hkey,
                plaintext.as_ptr(),
                input_len,
                padding_ptr,
                wrapped_key.as_mut_ptr(),
                size,
                &mut size,
                NCRYPT_PAD_OAEP_FLAG,
            );

            // both handles are released whether or not encryption succeeded
            NCryptFreeObject(hkey);
            NCryptFreeObject(hprov);
            if status != 0 {
                return Err(status);
            }

            // keep only the bytes actually written by the second call
            wrapped_key.truncate(size as usize);
            Ok(wrapped_key)
        }
    }

    let result = encrypt_with_tpm_key(&target_key);
    // Wipe the local plaintext copy on success *and* on failure.
    target_key.zeroize();
    result
}

fn unwrap_key(target_key: Vec<u8>) -> Result<Vec<u8>, HRESULT> {
    // almost the same process as encryption 
    unsafe {
        let hprov = open_provider()?;

        let mut hkey: NCRYPT_KEY_HANDLE = 0;
        let key_name = to_utf16(KEY_NAME);

        let status = NCryptOpenKey(
            hprov,
            &mut hkey,
            key_name.as_ptr(),
            0,
            0,
        );
        if status != 0 {
            NCryptFreeObject(hprov);
            return Err(status)
        }

        // get decrypted data size
        let mut size: u32 = 0;
        let padding_info = create_padding_info();

        let status = NCryptDecrypt(
            hkey,
            target_key.as_ptr(),
            target_key.len() as u32,
            &padding_info as *const _ as *const c_void,
            ptr::null_mut(),
            0,
            &mut size,
            NCRYPT_PAD_OAEP_FLAG,
        );
        if status != 0 {
            NCryptFreeObject(hkey);
            NCryptFreeObject(hprov);
            return Err(status)      
        }

        // unwap the target key
        let mut unwrapped_key = vec![0u8; size as usize];

        let status = NCryptDecrypt(
            hkey,
            target_key.as_ptr(),
            target_key.len() as u32,
            &padding_info as *const _ as *const c_void,
            unwrapped_key.as_mut_ptr(),
            size,
            &mut size,
            NCRYPT_PAD_OAEP_FLAG,
        );
        if status != 0 {
            NCryptFreeObject(hkey);
            NCryptFreeObject(hprov);
            return Err(status)      
        }

        // adjust the buffer to the actual decrypted data size
        unwrapped_key.truncate(size as usize);
        NCryptFreeObject(hkey);
        NCryptFreeObject(hprov);

        Ok(unwrapped_key)  
    }
}
Enter fullscreen mode Exit fullscreen mode

Try it out from a main function! (This demo also needs the rand crate, e.g. rand = "0.9" in Cargo.toml, for rand::rng().)

use rand::RngCore;

fn main() {
    let mut key = [0u8; 32];
    let mut rng = rand::rng();
    rng.fill_bytes(&mut key);

    println!("Original Key: {:?}", key);

    let wrapped_key = match wrap_key(key) {
        Ok(k) => {
            println!("Wrapped Key: {:?}", k);
            k
        },
        Err(e) => {
            println!("Error: {e}");
            return
        }, 
    };

    match unwrap_key(wrapped_key) {
        Ok(k) => println!("Unwrapped Key: {:?}", k),
        Err(e) => println!("Error: {e}"), 
    }
}
Enter fullscreen mode Exit fullscreen mode

The unwrapped key will match the original key!

If you want to delete the registered key:

/// Deletes the persisted TPM key named `KEY_NAME` from the Platform Crypto
/// Provider.
///
/// # Errors
///
/// Returns the NCryptOpenKey status if the key can't be opened (e.g. it
/// doesn't exist), or the NCryptDeleteKey status if deletion fails.
fn delete_key() -> Result<(), HRESULT> {
    let hprov = open_provider()?;

    let name_w = to_utf16(KEY_NAME);
    let mut hkey: NCRYPT_KEY_HANDLE = 0;

    // SAFETY: `hkey` and `name_w` are live locals; `name_w` is NUL-terminated.
    let open_status = unsafe { NCryptOpenKey(hprov, &mut hkey, name_w.as_ptr(), 0, 0) };
    if open_status != 0 {
        // SAFETY: `hprov` is valid and freed only here on this path.
        unsafe { NCryptFreeObject(hprov) };
        return Err(open_status);
    }

    // SAFETY: `hkey` was just opened and has not been freed.
    let delete_status = unsafe { NCryptDeleteKey(hkey, 0) };

    // A successful NCryptDeleteKey already released `hkey`; after a failure the
    // key handle is still ours to free. The provider is freed either way.
    // SAFETY: each handle below is freed at most once.
    unsafe {
        if delete_status != 0 {
            NCryptFreeObject(hkey);
        }
        NCryptFreeObject(hprov);
    }

    if delete_status == 0 {
        Ok(())
    } else {
        Err(delete_status)
    }
}
Enter fullscreen mode Exit fullscreen mode

Explanation

I referred to the official Microsoft documentation to decide which functions to use and understand the meaning of each argument, as well as the official windows-sys documentation to check argument types!

References:

I'm going to explain each function and argument used in that code, which I learned from these official documents.


1: Open the TPM provider

First, you need to open the provider using the NCryptOpenStorageProvider function.

Microsoft Documentation:

windows-sys Documentation:

phprovider needs a mutable reference to an NCRYPT_PROV_HANDLE (a usize) that receives the provider handle.

pszprovidername has 3 possible options:

MS_PLATFORM_CRYPTO_PROVIDER is used as the target in this case, but you can use MS_KEY_STORAGE_PROVIDER instead if the PC doesn't have a TPM (note that keys stored by that software provider are not protected by a TPM).

The dwflags parameter is reserved, so it should always be set to 0 for now.

// receives the provider handle on success
let mut hprov: NCRYPT_PROV_HANDLE = 0;

// open the Microsoft Platform Crypto Provider; a non-zero status means the
// call failed and `hprov` must not be used or freed
let status = NCryptOpenStorageProvider(
    &mut hprov,
    MS_PLATFORM_CRYPTO_PROVIDER,
    0, // dwflags: reserved, must be 0
);
Enter fullscreen mode Exit fullscreen mode

Some of the return codes:

If the function fails, you shouldn’t access or release the handle!

↓ It’s described in the Microsoft documentation


2: Create a persisted key (If it doesn't exist)

The NCryptCreatePersistedKey can create a persisted key with a given name.

Microsoft Documentation:

windows-sys Documentation:

phkey needs to be a mutable usize reference that receives the key handle.

pszalgid needs to be a null-terminated Unicode string that specifies the cryptographic algorithm.

↓ You can check out the standard CNG Algorithm Identifiers from the link below!

CNG Algorithm Identifiers

I chose BCRYPT_RSA_ALGORITHM since it’s for key wrapping.

pszkeyname needs a pointer to a null-terminated UTF-16 string containing the key name.
If this parameter is std::ptr::null(), the function creates an ephemeral key that isn't persisted.

/// Encodes `s` as UTF-16 code units followed by a terminating 0, which is the
/// form the wide-string (`PCWSTR`) NCrypt parameters expect.
fn to_utf16(s: &str) -> Vec<u16> {
    OsStr::new(s)
        .encode_wide()
        .chain(std::iter::once(0))
        .collect()
}
Enter fullscreen mode Exit fullscreen mode

encode_wide() converts an &OsStr into UTF-16 code units.
Since pszkeyname must be null-terminated, a 0 is appended to the end of the buffer.

dwlegacykeyspec has 3 possible options:

In CNG, this value should be set to 0, since the other options are for CryptoAPI. If you set it to anything other than 0, it'll cause an error (NTE_NOT_SUPPORTED).

dwflags has 5 options:

But these options aren't commonly used for general TPM operations.
If no options are needed, set this value to 0.

let mut hkey: NCRYPT_KEY_HANDLE = 0; // receives the new key handle
let key_name = to_utf16(KEY_NAME);   // NUL-terminated UTF-16 key name

let status_2 = NCryptCreatePersistedKey(
    hprov,
    &mut hkey,
    BCRYPT_RSA_ALGORITHM, // pszalgid: create an RSA key
    key_name.as_ptr(),    // pszkeyname: non-null, so the key is persisted
    0,                    // dwlegacykeyspec: must be 0 for CNG
    0,                    // dwflags: no options
);
if status_2 != 0 {
    NCryptFreeObject(hprov); // release provider handle; no key handle was produced
    return Err(status_2)
}
Enter fullscreen mode Exit fullscreen mode

If the function fails, you should call the NCryptFreeObject function to release the provider handle but not the key handle!

Some of the return codes:


NCryptSetProperty function can be used to set key properties, but it’s skipped here.

After creating the key, you need to call the NCryptFinalizeKey function to use it.

Microsoft Documentation:

windows-sys Documentation:

dwflags has 3 options:

NCRYPT_SILENT_FLAG is used to prevent any user interface from being displayed, but TPM doesn't have one.

If no options are needed, set this value to 0.

// finalize the newly created key; it can't be used before this succeeds
let status_2 = NCryptFinalizeKey(
    hkey,
    0, // dwflags: no options
);
if status_2 != 0 {
    // the key handle is valid at this point, so release it as well
    NCryptFreeObject(hkey);
    NCryptFreeObject(hprov);
    return Err(status_2)
}
Enter fullscreen mode Exit fullscreen mode

If the function fails, call NCryptFreeObject to release both the key handle and the provider handle! Unlike a NCryptCreatePersistedKey failure, the key handle is valid here and must be freed too.

Some of the return codes:


2 (alternative): Get the key from the provider (if it already exists)

If the key already exists, you can retrieve it from the provider using the NCryptOpenKey function.

Microsoft Documentation:

windows-sys Documentation:

These parameters are the same as those of NCryptCreatePersistedKey, except for the cryptographic algorithm.

let mut hkey: NCRYPT_KEY_HANDLE = 0; // receives the key handle
let key_name = to_utf16(KEY_NAME);   // NUL-terminated UTF-16 key name

// open the persisted key by name; returns NTE_BAD_KEYSET if no key with
// that name exists yet
let status = NCryptOpenKey(
    hprov,
    &mut hkey,
    key_name.as_ptr(),
    0, // dwlegacykeyspec: 0 for CNG
    0, // dwflags: no options
);
Enter fullscreen mode Exit fullscreen mode

If no key with the specified name exists, the function returns NTE_BAD_KEYSET.
You can handle return codes with match or if statements by comparing the constant with the returned value!

if status != 0 {
    match status {
        NTE_BAD_KEYSET => {
            // this error means the key name isn't registered in the TPM

            // create a persisted RSA key with the given name
            let status_2 = NCryptCreatePersistedKey(
                hprov,
                &mut hkey,
                BCRYPT_RSA_ALGORITHM,
                key_name.as_ptr(),
                0,
                0,
            );
            if status_2 != 0 {
                // no key handle was produced, so only the provider is released
                NCryptFreeObject(hprov);
                return Err(status_2)
            }

            let status_2 = NCryptFinalizeKey(
                hkey,
                0,
            );
            if status_2 != 0 {
                // the key handle is valid here, so release it before the provider
                NCryptFreeObject(hkey);
                NCryptFreeObject(hprov);
                return Err(status_2)
            }
        },
        _ => {
            // any other failure: no key handle, release the provider only
            NCryptFreeObject(hprov);
            return Err(status)
        },
    }
}
Enter fullscreen mode Exit fullscreen mode

Some of the return codes:


3: Get the encrypted data size

You need to get the size of the encrypted data before performing the encryption.

Microsoft Documentation:

windows-sys Documentation:

pbinput needs a pointer to the u8 buffer that contains the data to be encrypted.

cbinput needs the number of bytes in that buffer, as a u32.


ppaddinginfo is used for asymmetric keys and it needs a BCRYPT_OAEP_PADDING_INFO structure.

Microsoft Documentation:

windows-sys Documentation:

pszAlgId identifies the hash algorithm used to create the OAEP padding.

↓ You can check them out from the link below

CNG Algorithm Identifiers

pbLabel is used when you need to add a label; it needs a pointer to a u8 buffer containing the label data used when creating the padding.
If it’s not needed, set this value to std::ptr::null_mut().

cbLabel is also used when you need to add a label; it needs the number of bytes in the pbLabel buffer as a u32 value.

If it’s not needed, set this value to 0.

/// Builds the OAEP padding parameters shared by wrapping and unwrapping:
/// SHA-256 as the OAEP hash and no label.
fn create_padding_info() -> BCRYPT_OAEP_PADDING_INFO {
    let oaep_hash = BCRYPT_SHA256_ALGORITHM;
    BCRYPT_OAEP_PADDING_INFO {
        cbLabel: 0,
        pbLabel: ptr::null_mut(),
        pszAlgId: oaep_hash,
    }
}
Enter fullscreen mode Exit fullscreen mode

I used BCRYPT_SHA256_ALGORITHM as a hashing algorithm in that code.


pboutput needs a reference to a mutable u8 buffer to receive the encrypted data.

If you only want to get the size of the encrypted data, set it to std::ptr::null_mut().

cboutput needs a u32 value that specifies the size of the buffer to receive the encrypted data.

The documentation says this parameter is ignored when pboutput is std::ptr::null_mut(), but in my testing any value other than 0 caused an NTE_INVALID_PARAMETER error.

pcbresult needs a reference to a mutable u32.

If pboutput is set to std::ptr::null_mut(), it receives the buffer size required for the encrypted data; otherwise, it receives the number of bytes actually written to pboutput.

dwflags has 4 options for asymmetric keys:

NCRYPT_PAD_OAEP_FLAG is recommended for modern security.

It needs a reference to a BCRYPT_OAEP_PADDING_INFO structure when used.

For symmetric keys, set this value to 0.

// `size` receives the ciphertext buffer size needed by the second call
let mut size: u32 = 0;
let padding_info = create_padding_info();

let status = NCryptEncrypt(
    hkey,
    target_key.as_ptr(),
    target_key.len() as u32,
    &padding_info as *const _ as *const c_void, // OAEP parameters for NCRYPT_PAD_OAEP_FLAG
    ptr::null_mut(), // set to null when getting the data size
    0,               // set to 0 when getting the data size
    &mut size,
    NCRYPT_PAD_OAEP_FLAG,
);
if status != 0 {
    // both handles are valid at this point, so release both
    NCryptFreeObject(hkey);
    NCryptFreeObject(hprov);
    return Err(status)
}
Enter fullscreen mode Exit fullscreen mode

If the function fails, you should call the NCryptFreeObject function to release both the provider and key handles!

Some of the return codes:


4: Wrap the target key using the created key

Finally, wrap the key using the NCryptEncrypt function!

// allocate the output buffer using the size reported by the first call
let mut wrapped_key = vec![0u8; size as usize];

let status = NCryptEncrypt(
    hkey,
    target_key.as_ptr(),
    target_key.len() as u32,
    &padding_info as *const _ as *const c_void,
    wrapped_key.as_mut_ptr(), // receives the ciphertext
    size,                     // cboutput: size of wrapped_key
    &mut size,                // pcbresult: bytes actually written
    NCRYPT_PAD_OAEP_FLAG,
);
if status != 0 {
    NCryptFreeObject(hkey);
    NCryptFreeObject(hprov);
    return Err(status)
}
Enter fullscreen mode Exit fullscreen mode

The size you got from the previous process needs to be set to cboutput for the buffer size and pcbresult for the actual encrypted data size.

If the function fails, you should call the NCryptFreeObject function to release both the provider and key handles!

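Then truncate wrapped_key to the number of bytes actually written (reported in size by the second call). The buffer was allocated using the size from the first call, which can be larger than the actual output.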

// keep only the bytes actually written, then release both handles
wrapped_key.truncate(size as usize);
NCryptFreeObject(hkey);
NCryptFreeObject(hprov);
Enter fullscreen mode Exit fullscreen mode

Lastly, don't forget to release these handles before returning from the function!!


Decrypt the target key

Decrypting the encrypted key is almost the same process as encryption — just replace the NCryptEncrypt with NCryptDecrypt function.

// first call: query the plaintext buffer size (null output buffer, size 0)
let mut size: u32 = 0;
let padding_info = create_padding_info(); // same OAEP parameters that wrapping used

let status = NCryptDecrypt(
    hkey,
    target_key.as_ptr(),
    target_key.len() as u32,
    &padding_info as *const _ as *const c_void,
    ptr::null_mut(),
    0,
    &mut size,
    NCRYPT_PAD_OAEP_FLAG,
);
if status != 0 {
    NCryptFreeObject(hkey);
    NCryptFreeObject(hprov);
    return Err(status)
}

// unwrap the target key: the second call decrypts into a buffer of that size
// and updates `size` to the number of bytes actually written
let mut unwrapped_key = vec![0u8; size as usize];

let status = NCryptDecrypt(
    hkey,
    target_key.as_ptr(),
    target_key.len() as u32,
    &padding_info as *const _ as *const c_void,
    unwrapped_key.as_mut_ptr(),
    size,
    &mut size,
    NCRYPT_PAD_OAEP_FLAG,
);
if status != 0 {
    NCryptFreeObject(hkey);
    NCryptFreeObject(hprov);
    return Err(status)
}
// (afterwards, truncate `unwrapped_key` to `size` and free both handles)
Enter fullscreen mode Exit fullscreen mode

Some of the return codes:


Delete registered key

You can delete the registered key using the NCryptDeleteKey function; the steps up to retrieving the key are the same as before.

Microsoft Documentation:

windows-sys Documentation:

dwflags only supports NCRYPT_SILENT_FLAG option.
If you don't need it, set this value to 0.

// delete the persisted key from the provider
let status = NCryptDeleteKey(
    hkey,
    0, // dwflags: no options
);
if status != 0 {
    // deletion failed: free the key handle explicitly, then the provider
    NCryptFreeObject(hkey);
    NCryptFreeObject(hprov);
    Err(status)
} else {
    // deletion succeeded: NCryptDeleteKey already released `hkey`,
    // so only the provider handle is freed here
    NCryptFreeObject(hprov);
    Ok(())
}
Enter fullscreen mode Exit fullscreen mode

If the function succeeds, it frees the key handle automatically, so you must not free it again.
You just need to call the NCryptFreeObject to release the provider handle before returning the value.

↓ It’s described in the Microsoft documentation

Some of the return codes:


Final Thoughts

At first, since I didn’t know much about C/C++, I struggled to figure out how to read the Windows API function signatures...
But through the implementation, I was able to understand them better as well as how TPM works!
Next, I'm going to try implementing TPM for macOS and Linux as well.

Thanks for reading!

Top comments (0)