On Wed, Dec 20, 2023 at 08:43:43PM +0530, Nikunj A Dadhania wrote:
> @@ -307,11 +197,16 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
> * If the message size is greater than our buffer length then return
> * an error.
> */
> - if (unlikely((resp_hdr->msg_sz + crypto->a_len) > sz))
> + if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
> return -EBADMSG;
>
> /* Decrypt the payload */
> - return dec_payload(snp_dev, resp, payload, resp_hdr->msg_sz + crypto->a_len);
> + memcpy(iv, &resp_hdr->msg_seqno, sizeof(resp_hdr->msg_seqno));
sizeof(iv) != sizeof(resp_hdr->msg_seqno); the seqno happens to fit in iv now.
However, for protection against future bugs, this should be:
memcpy(iv, &resp_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_hdr->msg_seqno)));
> + if (!aesgcm_decrypt(ctx, payload, resp->payload, resp_hdr->msg_sz,
> + &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
> + return -EBADMSG;
> +
> + return 0;
> }
>
> static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
> @@ -319,6 +214,8 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
> {
> struct snp_guest_msg *req = &snp_dev->secret_request;
> struct snp_guest_msg_hdr *hdr = &req->hdr;
> + struct aesgcm_ctx *ctx = snp_dev->ctx;
> + u8 iv[GCM_AES_IV_SIZE] = {};
>
> memset(req, 0, sizeof(*req));
>
> @@ -338,7 +235,14 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
> dev_dbg(snp_dev->dev, "request [seqno %lld type %d version %d sz %d]\n",
> hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
>
> - return __enc_payload(snp_dev, req, payload, sz);
> + if (WARN_ON((sz + ctx->authsize) > sizeof(req->payload)))
> + return -EBADMSG;
> +
> + memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
Ditto.
> + aesgcm_encrypt(ctx, req->payload, payload, sz, &hdr->algo, AAD_LEN,
> + iv, hdr->authtag);
> +
> + return 0;
On 1/25/2024 4:06 PM, Borislav Petkov wrote:
> On Wed, Dec 20, 2023 at 08:43:43PM +0530, Nikunj A Dadhania wrote:
>> @@ -307,11 +197,16 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
>> * If the message size is greater than our buffer length then return
>> * an error.
>> */
>> - if (unlikely((resp_hdr->msg_sz + crypto->a_len) > sz))
>> + if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
>> return -EBADMSG;
>>
>> /* Decrypt the payload */
>> - return dec_payload(snp_dev, resp, payload, resp_hdr->msg_sz + crypto->a_len);
>> + memcpy(iv, &resp_hdr->msg_seqno, sizeof(resp_hdr->msg_seqno));
>
> sizeof(iv) != sizeof(resp_hdr->msg_seqno); the seqno happens to fit in iv now.
>
> However, for protection against future bugs, this should be:
>
> memcpy(iv, &resp_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_hdr->msg_seqno)));
Sure, will change.
>
>> + if (!aesgcm_decrypt(ctx, payload, resp->payload, resp_hdr->msg_sz,
>> + &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
>> + return -EBADMSG;
>> +
>> + return 0;
>> }
>>
>> static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
>> @@ -319,6 +214,8 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
>> {
>> struct snp_guest_msg *req = &snp_dev->secret_request;
>> struct snp_guest_msg_hdr *hdr = &req->hdr;
>> + struct aesgcm_ctx *ctx = snp_dev->ctx;
>> + u8 iv[GCM_AES_IV_SIZE] = {};
>>
>> memset(req, 0, sizeof(*req));
>>
>> @@ -338,7 +235,14 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
>> dev_dbg(snp_dev->dev, "request [seqno %lld type %d version %d sz %d]\n",
>> hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
>>
>> - return __enc_payload(snp_dev, req, payload, sz);
>> + if (WARN_ON((sz + ctx->authsize) > sizeof(req->payload)))
>> + return -EBADMSG;
>> +
>> + memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
>
> Ditto.
Sure.
>
>> + aesgcm_encrypt(ctx, req->payload, payload, sz, &hdr->algo, AAD_LEN,
>> + iv, hdr->authtag);
>> +
>> + return 0;
>
Thanks,
Nikunj
@@ -2,9 +2,7 @@ config SEV_GUEST
tristate "AMD SEV Guest driver"
default m
depends on AMD_MEM_ENCRYPT
- select CRYPTO
- select CRYPTO_AEAD2
- select CRYPTO_GCM
+ select CRYPTO_LIB_AESGCM
select TSM_REPORTS
help
SEV-SNP firmware provides the guest a mechanism to communicate with
@@ -17,8 +17,7 @@
#include <linux/set_memory.h>
#include <linux/fs.h>
#include <linux/tsm.h>
-#include <crypto/aead.h>
-#include <linux/scatterlist.h>
+#include <crypto/gcm.h>
#include <linux/psp-sev.h>
#include <linux/sockptr.h>
#include <linux/cleanup.h>
@@ -32,24 +31,16 @@
#include "sev-guest.h"
#define DEVICE_NAME "sev-guest"
-#define AAD_LEN 48
-#define MSG_HDR_VER 1
#define SNP_REQ_MAX_RETRY_DURATION (60*HZ)
#define SNP_REQ_RETRY_DELAY (2*HZ)
-struct snp_guest_crypto {
- struct crypto_aead *tfm;
- u8 *iv, *authtag;
- int iv_len, a_len;
-};
-
struct snp_guest_dev {
struct device *dev;
struct miscdevice misc;
void *certs_data;
- struct snp_guest_crypto *crypto;
+ struct aesgcm_ctx *ctx;
/* request and response are in unencrypted memory */
struct snp_guest_msg *request, *response;
@@ -161,132 +152,31 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
return container_of(dev, struct snp_guest_dev, misc);
}
-static struct snp_guest_crypto *init_crypto(struct snp_guest_dev *snp_dev, u8 *key, size_t keylen)
+static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
{
- struct snp_guest_crypto *crypto;
+ struct aesgcm_ctx *ctx;
- crypto = kzalloc(sizeof(*crypto), GFP_KERNEL_ACCOUNT);
- if (!crypto)
+ ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
+ if (!ctx)
return NULL;
- crypto->tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
- if (IS_ERR(crypto->tfm))
- goto e_free;
-
- if (crypto_aead_setkey(crypto->tfm, key, keylen))
- goto e_free_crypto;
-
- crypto->iv_len = crypto_aead_ivsize(crypto->tfm);
- crypto->iv = kmalloc(crypto->iv_len, GFP_KERNEL_ACCOUNT);
- if (!crypto->iv)
- goto e_free_crypto;
-
- if (crypto_aead_authsize(crypto->tfm) > MAX_AUTHTAG_LEN) {
- if (crypto_aead_setauthsize(crypto->tfm, MAX_AUTHTAG_LEN)) {
- dev_err(snp_dev->dev, "failed to set authsize to %d\n", MAX_AUTHTAG_LEN);
- goto e_free_iv;
- }
+ if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
+ pr_err("Crypto context initialization failed\n");
+ kfree(ctx);
+ return NULL;
}
- crypto->a_len = crypto_aead_authsize(crypto->tfm);
- crypto->authtag = kmalloc(crypto->a_len, GFP_KERNEL_ACCOUNT);
- if (!crypto->authtag)
- goto e_free_iv;
-
- return crypto;
-
-e_free_iv:
- kfree(crypto->iv);
-e_free_crypto:
- crypto_free_aead(crypto->tfm);
-e_free:
- kfree(crypto);
-
- return NULL;
-}
-
-static void deinit_crypto(struct snp_guest_crypto *crypto)
-{
- crypto_free_aead(crypto->tfm);
- kfree(crypto->iv);
- kfree(crypto->authtag);
- kfree(crypto);
-}
-
-static int enc_dec_message(struct snp_guest_crypto *crypto, struct snp_guest_msg *msg,
- u8 *src_buf, u8 *dst_buf, size_t len, bool enc)
-{
- struct snp_guest_msg_hdr *hdr = &msg->hdr;
- struct scatterlist src[3], dst[3];
- DECLARE_CRYPTO_WAIT(wait);
- struct aead_request *req;
- int ret;
-
- req = aead_request_alloc(crypto->tfm, GFP_KERNEL);
- if (!req)
- return -ENOMEM;
-
- /*
- * AEAD memory operations:
- * +------ AAD -------+------- DATA -----+---- AUTHTAG----+
- * | msg header | plaintext | hdr->authtag |
- * | bytes 30h - 5Fh | or | |
- * | | cipher | |
- * +------------------+------------------+----------------+
- */
- sg_init_table(src, 3);
- sg_set_buf(&src[0], &hdr->algo, AAD_LEN);
- sg_set_buf(&src[1], src_buf, hdr->msg_sz);
- sg_set_buf(&src[2], hdr->authtag, crypto->a_len);
-
- sg_init_table(dst, 3);
- sg_set_buf(&dst[0], &hdr->algo, AAD_LEN);
- sg_set_buf(&dst[1], dst_buf, hdr->msg_sz);
- sg_set_buf(&dst[2], hdr->authtag, crypto->a_len);
-
- aead_request_set_ad(req, AAD_LEN);
- aead_request_set_tfm(req, crypto->tfm);
- aead_request_set_callback(req, 0, crypto_req_done, &wait);
-
- aead_request_set_crypt(req, src, dst, len, crypto->iv);
- ret = crypto_wait_req(enc ? crypto_aead_encrypt(req) : crypto_aead_decrypt(req), &wait);
-
- aead_request_free(req);
- return ret;
-}
-
-static int __enc_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
- void *plaintext, size_t len)
-{
- struct snp_guest_crypto *crypto = snp_dev->crypto;
- struct snp_guest_msg_hdr *hdr = &msg->hdr;
-
- memset(crypto->iv, 0, crypto->iv_len);
- memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
-
- return enc_dec_message(crypto, msg, plaintext, msg->payload, len, true);
-}
-
-static int dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
- void *plaintext, size_t len)
-{
- struct snp_guest_crypto *crypto = snp_dev->crypto;
- struct snp_guest_msg_hdr *hdr = &msg->hdr;
-
- /* Build IV with response buffer sequence number */
- memset(crypto->iv, 0, crypto->iv_len);
- memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
-
- return enc_dec_message(crypto, msg, msg->payload, plaintext, len, false);
+ return ctx;
}
static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
{
- struct snp_guest_crypto *crypto = snp_dev->crypto;
struct snp_guest_msg *resp = &snp_dev->secret_response;
struct snp_guest_msg *req = &snp_dev->secret_request;
struct snp_guest_msg_hdr *req_hdr = &req->hdr;
struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
+ struct aesgcm_ctx *ctx = snp_dev->ctx;
+ u8 iv[GCM_AES_IV_SIZE] = {};
dev_dbg(snp_dev->dev, "response [seqno %lld type %d version %d sz %d]\n",
resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version, resp_hdr->msg_sz);
@@ -307,11 +197,16 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
* If the message size is greater than our buffer length then return
* an error.
*/
- if (unlikely((resp_hdr->msg_sz + crypto->a_len) > sz))
+ if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
return -EBADMSG;
/* Decrypt the payload */
- return dec_payload(snp_dev, resp, payload, resp_hdr->msg_sz + crypto->a_len);
+ memcpy(iv, &resp_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_hdr->msg_seqno)));
+ if (!aesgcm_decrypt(ctx, payload, resp->payload, resp_hdr->msg_sz,
+ &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
+ return -EBADMSG;
+
+ return 0;
}
static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
@@ -319,6 +214,8 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
{
struct snp_guest_msg *req = &snp_dev->secret_request;
struct snp_guest_msg_hdr *hdr = &req->hdr;
+ struct aesgcm_ctx *ctx = snp_dev->ctx;
+ u8 iv[GCM_AES_IV_SIZE] = {};
memset(req, 0, sizeof(*req));
@@ -338,7 +235,14 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
dev_dbg(snp_dev->dev, "request [seqno %lld type %d version %d sz %d]\n",
hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
- return __enc_payload(snp_dev, req, payload, sz);
+ if (WARN_ON((sz + ctx->authsize) > sizeof(req->payload)))
+ return -EBADMSG;
+
+ memcpy(iv, &hdr->msg_seqno, min(sizeof(iv), sizeof(hdr->msg_seqno)));
+ aesgcm_encrypt(ctx, req->payload, payload, sz, &hdr->algo, AAD_LEN,
+ iv, hdr->authtag);
+
+ return 0;
}
static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
@@ -486,7 +390,6 @@ struct snp_req_resp {
static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
{
- struct snp_guest_crypto *crypto = snp_dev->crypto;
struct snp_report_req *req = &snp_dev->req.report;
struct snp_report_resp *resp;
int rc, resp_len;
@@ -504,7 +407,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
* response payload. Make sure that it has enough space to cover the
* authtag.
*/
- resp_len = sizeof(resp->data) + crypto->a_len;
+ resp_len = sizeof(resp->data) + snp_dev->ctx->authsize;
resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
if (!resp)
return -ENOMEM;
@@ -526,7 +429,6 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
{
struct snp_derived_key_req *req = &snp_dev->req.derived_key;
- struct snp_guest_crypto *crypto = snp_dev->crypto;
struct snp_derived_key_resp resp = {0};
int rc, resp_len;
/* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
@@ -542,7 +444,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
* response payload. Make sure that it has enough space to cover the
* authtag.
*/
- resp_len = sizeof(resp.data) + crypto->a_len;
+ resp_len = sizeof(resp.data) + snp_dev->ctx->authsize;
if (sizeof(buf) < resp_len)
return -ENOMEM;
@@ -569,7 +471,6 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
{
struct snp_ext_report_req *req = &snp_dev->req.ext_report;
- struct snp_guest_crypto *crypto = snp_dev->crypto;
struct snp_report_resp *resp;
int ret, npages = 0, resp_len;
sockptr_t certs_address;
@@ -612,7 +513,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
* response payload. Make sure that it has enough space to cover the
* authtag.
*/
- resp_len = sizeof(resp->data) + crypto->a_len;
+ resp_len = sizeof(resp->data) + snp_dev->ctx->authsize;
resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
if (!resp)
return -ENOMEM;
@@ -954,8 +855,8 @@ static int __init sev_guest_probe(struct platform_device *pdev)
goto e_free_response;
ret = -EIO;
- snp_dev->crypto = init_crypto(snp_dev, snp_dev->vmpck, VMPCK_KEY_LEN);
- if (!snp_dev->crypto)
+ snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
+ if (!snp_dev->ctx)
goto e_free_cert_data;
misc = &snp_dev->misc;
@@ -978,11 +879,13 @@ static int __init sev_guest_probe(struct platform_device *pdev)
ret = misc_register(misc);
if (ret)
- goto e_free_cert_data;
+ goto e_free_ctx;
dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
return 0;
+e_free_ctx:
+ kfree_sensitive(snp_dev->ctx);
e_free_cert_data:
free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
e_free_response:
@@ -1001,7 +904,7 @@ static int __exit sev_guest_remove(struct platform_device *pdev)
free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
- deinit_crypto(snp_dev->crypto);
+ kfree_sensitive(snp_dev->ctx);
misc_deregister(&snp_dev->misc);
return 0;
@@ -13,6 +13,9 @@
#include <linux/types.h>
#define MAX_AUTHTAG_LEN 32
+#define AUTHTAG_LEN 16
+#define AAD_LEN 48
+#define MSG_HDR_VER 1
/* See SNP spec SNP_GUEST_REQUEST section for the structure */
enum msg_type {