[v2,01/11] virt: sev-guest: Use AES GCM crypto library

Message ID 20230326144701.3039598-2-nikunj@amd.com
State New
Headers
Series: Add Secure TSC support for SNP guests

Commit Message

Nikunj A. Dadhania March 26, 2023, 2:46 p.m. UTC
  SEV-SNP guests with SecureTSC enabled need to send a TSC_INFO SNP
Guest message to the AMD security processor before the smpboot phase
starts. Details from the TSC_INFO response have to be programmed in
the VMSA before the secondary CPUs are brought up.

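To support Secure TSC, the message encryption code will be moved out of
sev-guest, because the crypto API is not yet available that early. To make
the diffs easier to review, first convert the Crypto API usage over to the
AES GCM library before moving it.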

Link: https://lore.kernel.org/all/20221103192259.2229-1-ardb@kernel.org
CC: Ard Biesheuvel <ardb@kernel.org>
Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
---
 drivers/virt/coco/sev-guest/Kconfig     |   3 +-
 drivers/virt/coco/sev-guest/sev-guest.c | 172 +++++++-----------------
 drivers/virt/coco/sev-guest/sev-guest.h |   3 +
 3 files changed, 53 insertions(+), 125 deletions(-)
  

Comments

Tom Lendacky April 3, 2023, 7:09 p.m. UTC | #1
On 3/26/23 09:46, Nikunj A Dadhania wrote:
> SEV-SNP guests with SecureTSC enabled need to send a TSC_INFO SNP
> Guest message to the AMD security processor before the smpboot phase
> starts. Details from the TSC_INFO response have to be programmed in
> the VMSA before the secondary CPUs are brought up.
> 
> Start using the AES GCM library implementation as the crypto API is not
> available yet this early.

This isn't quite true, yet. You should add that the encryption code will 
be moved out of sev-guest to support Secure TSC, but to make the diffs 
easier to review, convert the Crypto API usage over to AES GCM library 
usage before moving it.

Thanks,
Tom

> 
> Link: https://lore.kernel.org/all/20221103192259.2229-1-ardb@kernel.org
> CC: Ard Biesheuvel <ardb@kernel.org>
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> ---
>   drivers/virt/coco/sev-guest/Kconfig     |   3 +-
>   drivers/virt/coco/sev-guest/sev-guest.c | 172 +++++++-----------------
>   drivers/virt/coco/sev-guest/sev-guest.h |   3 +
>   3 files changed, 53 insertions(+), 125 deletions(-)
> 
> diff --git a/drivers/virt/coco/sev-guest/Kconfig b/drivers/virt/coco/sev-guest/Kconfig
> index f9db0799ae67..bcc760bfb468 100644
> --- a/drivers/virt/coco/sev-guest/Kconfig
> +++ b/drivers/virt/coco/sev-guest/Kconfig
> @@ -2,8 +2,7 @@ config SEV_GUEST
>   	tristate "AMD SEV Guest driver"
>   	default m
>   	depends on AMD_MEM_ENCRYPT
> -	select CRYPTO_AEAD2
> -	select CRYPTO_GCM
> +	select CRYPTO_LIB_AESGCM
>   	help
>   	  SEV-SNP firmware provides the guest a mechanism to communicate with
>   	  the PSP without risk from a malicious hypervisor who wishes to read,
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 46f1a8d558b0..57af908bafba 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -16,8 +16,7 @@
>   #include <linux/miscdevice.h>
>   #include <linux/set_memory.h>
>   #include <linux/fs.h>
> -#include <crypto/aead.h>
> -#include <linux/scatterlist.h>
> +#include <crypto/gcm.h>
>   #include <linux/psp-sev.h>
>   #include <uapi/linux/sev-guest.h>
>   #include <uapi/linux/psp-sev.h>
> @@ -28,24 +27,16 @@
>   #include "sev-guest.h"
>   
>   #define DEVICE_NAME	"sev-guest"
> -#define AAD_LEN		48
> -#define MSG_HDR_VER	1
>   
>   #define SNP_REQ_MAX_RETRY_DURATION	(60*HZ)
>   #define SNP_REQ_RETRY_DELAY		(2*HZ)
>   
> -struct snp_guest_crypto {
> -	struct crypto_aead *tfm;
> -	u8 *iv, *authtag;
> -	int iv_len, a_len;
> -};
> -
>   struct snp_guest_dev {
>   	struct device *dev;
>   	struct miscdevice misc;
>   
>   	void *certs_data;
> -	struct snp_guest_crypto *crypto;
> +	struct aesgcm_ctx *ctx;
>   	struct snp_guest_msg *request, *response;
>   	struct snp_secrets_page_layout *layout;
>   	struct snp_req_data input;
> @@ -60,6 +51,15 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
>   /* Mutex to serialize the shared buffer access and command handling. */
>   static DEFINE_MUTEX(snp_cmd_mutex);
>   
> +static inline unsigned int get_ctx_authsize(struct snp_guest_dev *snp_dev)
> +{
> +	if (snp_dev && snp_dev->ctx)
> +		return snp_dev->ctx->authsize;
> +
> +	WARN_ONCE(1, "Unable to get crypto authsize\n");
> +	return 0;
> +}
> +
>   static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
>   {
>   	char zero_key[VMPCK_KEY_LEN] = {0};
> @@ -144,132 +144,59 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>   	return container_of(dev, struct snp_guest_dev, misc);
>   }
>   
> -static struct snp_guest_crypto *init_crypto(struct snp_guest_dev *snp_dev, u8 *key, size_t keylen)
> +static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
>   {
> -	struct snp_guest_crypto *crypto;
> +	struct aesgcm_ctx *ctx;
>   
> -	crypto = kzalloc(sizeof(*crypto), GFP_KERNEL_ACCOUNT);
> -	if (!crypto)
> +	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
> +	if (!ctx)
>   		return NULL;
>   
> -	crypto->tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
> -	if (IS_ERR(crypto->tfm))
> -		goto e_free;
> -
> -	if (crypto_aead_setkey(crypto->tfm, key, keylen))
> -		goto e_free_crypto;
> -
> -	crypto->iv_len = crypto_aead_ivsize(crypto->tfm);
> -	crypto->iv = kmalloc(crypto->iv_len, GFP_KERNEL_ACCOUNT);
> -	if (!crypto->iv)
> -		goto e_free_crypto;
> -
> -	if (crypto_aead_authsize(crypto->tfm) > MAX_AUTHTAG_LEN) {
> -		if (crypto_aead_setauthsize(crypto->tfm, MAX_AUTHTAG_LEN)) {
> -			dev_err(snp_dev->dev, "failed to set authsize to %d\n", MAX_AUTHTAG_LEN);
> -			goto e_free_iv;
> -		}
> +	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
> +		pr_err("SNP: crypto init failed\n");
> +		kfree(ctx);
> +		return NULL;
>   	}
>   
> -	crypto->a_len = crypto_aead_authsize(crypto->tfm);
> -	crypto->authtag = kmalloc(crypto->a_len, GFP_KERNEL_ACCOUNT);
> -	if (!crypto->authtag)
> -		goto e_free_iv;
> -
> -	return crypto;
> -
> -e_free_iv:
> -	kfree(crypto->iv);
> -e_free_crypto:
> -	crypto_free_aead(crypto->tfm);
> -e_free:
> -	kfree(crypto);
> -
> -	return NULL;
> +	return ctx;
>   }
>   
> -static void deinit_crypto(struct snp_guest_crypto *crypto)
> -{
> -	crypto_free_aead(crypto->tfm);
> -	kfree(crypto->iv);
> -	kfree(crypto->authtag);
> -	kfree(crypto);
> -}
> -
> -static int enc_dec_message(struct snp_guest_crypto *crypto, struct snp_guest_msg *msg,
> -			   u8 *src_buf, u8 *dst_buf, size_t len, bool enc)
> -{
> -	struct snp_guest_msg_hdr *hdr = &msg->hdr;
> -	struct scatterlist src[3], dst[3];
> -	DECLARE_CRYPTO_WAIT(wait);
> -	struct aead_request *req;
> -	int ret;
> -
> -	req = aead_request_alloc(crypto->tfm, GFP_KERNEL);
> -	if (!req)
> -		return -ENOMEM;
> -
> -	/*
> -	 * AEAD memory operations:
> -	 * +------ AAD -------+------- DATA -----+---- AUTHTAG----+
> -	 * |  msg header      |  plaintext       |  hdr->authtag  |
> -	 * | bytes 30h - 5Fh  |    or            |                |
> -	 * |                  |   cipher         |                |
> -	 * +------------------+------------------+----------------+
> -	 */
> -	sg_init_table(src, 3);
> -	sg_set_buf(&src[0], &hdr->algo, AAD_LEN);
> -	sg_set_buf(&src[1], src_buf, hdr->msg_sz);
> -	sg_set_buf(&src[2], hdr->authtag, crypto->a_len);
> -
> -	sg_init_table(dst, 3);
> -	sg_set_buf(&dst[0], &hdr->algo, AAD_LEN);
> -	sg_set_buf(&dst[1], dst_buf, hdr->msg_sz);
> -	sg_set_buf(&dst[2], hdr->authtag, crypto->a_len);
> -
> -	aead_request_set_ad(req, AAD_LEN);
> -	aead_request_set_tfm(req, crypto->tfm);
> -	aead_request_set_callback(req, 0, crypto_req_done, &wait);
> -
> -	aead_request_set_crypt(req, src, dst, len, crypto->iv);
> -	ret = crypto_wait_req(enc ? crypto_aead_encrypt(req) : crypto_aead_decrypt(req), &wait);
> -
> -	aead_request_free(req);
> -	return ret;
> -}
> -
> -static int __enc_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
> +static int __enc_payload(struct aesgcm_ctx *ctx, struct snp_guest_msg *msg,
>   			 void *plaintext, size_t len)
>   {
> -	struct snp_guest_crypto *crypto = snp_dev->crypto;
>   	struct snp_guest_msg_hdr *hdr = &msg->hdr;
> +	u8 iv[GCM_AES_IV_SIZE] = {};
>   
> -	memset(crypto->iv, 0, crypto->iv_len);
> -	memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
> +	if (WARN_ON((hdr->msg_sz + ctx->authsize) > sizeof(msg->payload)))
> +		return -EBADMSG;
>   
> -	return enc_dec_message(crypto, msg, plaintext, msg->payload, len, true);
> +	memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
> +	aesgcm_encrypt(ctx, msg->payload, plaintext, len, &hdr->algo, AAD_LEN,
> +		       iv, hdr->authtag);
> +	return 0;
>   }
>   
> -static int dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
> +static int dec_payload(struct aesgcm_ctx *ctx, struct snp_guest_msg *msg,
>   		       void *plaintext, size_t len)
>   {
> -	struct snp_guest_crypto *crypto = snp_dev->crypto;
>   	struct snp_guest_msg_hdr *hdr = &msg->hdr;
> +	u8 iv[GCM_AES_IV_SIZE] = {};
>   
> -	/* Build IV with response buffer sequence number */
> -	memset(crypto->iv, 0, crypto->iv_len);
> -	memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
> -
> -	return enc_dec_message(crypto, msg, msg->payload, plaintext, len, false);
> +	memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
> +	if (aesgcm_decrypt(ctx, plaintext, msg->payload, len, &hdr->algo,
> +			   AAD_LEN, iv, hdr->authtag))
> +		return 0;
> +	else
> +		return -EBADMSG;
>   }
>   
>   static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
>   {
> -	struct snp_guest_crypto *crypto = snp_dev->crypto;
>   	struct snp_guest_msg *resp = snp_dev->response;
>   	struct snp_guest_msg *req = snp_dev->request;
>   	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
>   	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
> +	struct aesgcm_ctx *ctx = snp_dev->ctx;
>   
>   	dev_dbg(snp_dev->dev, "response [seqno %lld type %d version %d sz %d]\n",
>   		resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version, resp_hdr->msg_sz);
> @@ -287,11 +214,11 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
>   	 * If the message size is greater than our buffer length then return
>   	 * an error.
>   	 */
> -	if (unlikely((resp_hdr->msg_sz + crypto->a_len) > sz))
> +	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
>   		return -EBADMSG;
>   
>   	/* Decrypt the payload */
> -	return dec_payload(snp_dev, resp, payload, resp_hdr->msg_sz + crypto->a_len);
> +	return dec_payload(ctx, resp, payload, resp_hdr->msg_sz);
>   }
>   
>   static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
> @@ -318,7 +245,7 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
>   	dev_dbg(snp_dev->dev, "request [seqno %lld type %d version %d sz %d]\n",
>   		hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
>   
> -	return __enc_payload(snp_dev, req, payload, sz);
> +	return __enc_payload(snp_dev->ctx, req, payload, sz);
>   }
>   
>   static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, __u64 *fw_err)
> @@ -446,7 +373,6 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, in
>   
>   static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
>   {
> -	struct snp_guest_crypto *crypto = snp_dev->crypto;
>   	struct snp_report_resp *resp;
>   	struct snp_report_req req;
>   	int rc, resp_len;
> @@ -464,7 +390,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
>   	 * response payload. Make sure that it has enough space to cover the
>   	 * authtag.
>   	 */
> -	resp_len = sizeof(resp->data) + crypto->a_len;
> +	resp_len = sizeof(resp->data) + get_ctx_authsize(snp_dev);
>   	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
>   	if (!resp)
>   		return -ENOMEM;
> @@ -485,7 +411,6 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
>   
>   static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
>   {
> -	struct snp_guest_crypto *crypto = snp_dev->crypto;
>   	struct snp_derived_key_resp resp = {0};
>   	struct snp_derived_key_req req;
>   	int rc, resp_len;
> @@ -502,7 +427,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
>   	 * response payload. Make sure that it has enough space to cover the
>   	 * authtag.
>   	 */
> -	resp_len = sizeof(resp.data) + crypto->a_len;
> +	resp_len = sizeof(resp.data) + get_ctx_authsize(snp_dev);
>   	if (sizeof(buf) < resp_len)
>   		return -ENOMEM;
>   
> @@ -527,7 +452,6 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
>   
>   static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
>   {
> -	struct snp_guest_crypto *crypto = snp_dev->crypto;
>   	struct snp_ext_report_req req;
>   	struct snp_report_resp *resp;
>   	int ret, npages = 0, resp_len;
> @@ -565,7 +489,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
>   	 * response payload. Make sure that it has enough space to cover the
>   	 * authtag.
>   	 */
> -	resp_len = sizeof(resp->data) + crypto->a_len;
> +	resp_len = sizeof(resp->data) + get_ctx_authsize(snp_dev);
>   	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
>   	if (!resp)
>   		return -ENOMEM;
> @@ -777,8 +701,8 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>   		goto e_free_response;
>   
>   	ret = -EIO;
> -	snp_dev->crypto = init_crypto(snp_dev, snp_dev->vmpck, VMPCK_KEY_LEN);
> -	if (!snp_dev->crypto)
> +	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
> +	if (!snp_dev->ctx)
>   		goto e_free_cert_data;
>   
>   	misc = &snp_dev->misc;
> @@ -793,11 +717,13 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>   
>   	ret =  misc_register(misc);
>   	if (ret)
> -		goto e_free_cert_data;
> +		goto e_free_ctx;
>   
>   	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
>   	return 0;
>   
> +e_free_ctx:
> +	kfree(snp_dev->ctx);
>   e_free_cert_data:
>   	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
>   e_free_response:
> @@ -816,7 +742,7 @@ static int __exit sev_guest_remove(struct platform_device *pdev)
>   	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
>   	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
>   	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
> -	deinit_crypto(snp_dev->crypto);
> +	kfree(snp_dev->ctx);
>   	misc_deregister(&snp_dev->misc);
>   
>   	return 0;
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/drivers/virt/coco/sev-guest/sev-guest.h
> index 21bda26fdb95..ceb798a404d6 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.h
> +++ b/drivers/virt/coco/sev-guest/sev-guest.h
> @@ -13,6 +13,9 @@
>   #include <linux/types.h>
>   
>   #define MAX_AUTHTAG_LEN		32
> +#define AUTHTAG_LEN		16
> +#define AAD_LEN			48
> +#define MSG_HDR_VER		1
>   
>   /* See SNP spec SNP_GUEST_REQUEST section for the structure */
>   enum msg_type {
  
Nikunj A. Dadhania April 5, 2023, 5:10 a.m. UTC | #2
On 4/4/2023 12:39 AM, Tom Lendacky wrote:
> On 3/26/23 09:46, Nikunj A Dadhania wrote:
>> SEV-SNP guests with SecureTSC enabled need to send a TSC_INFO SNP
>> Guest message to the AMD security processor before the smpboot phase
>> starts. Details from the TSC_INFO response have to be programmed in
>> the VMSA before the secondary CPUs are brought up.
>>
>> Start using the AES GCM library implementation as the crypto API is not
>> available yet this early.
> 
> This isn't quite true, yet. You should add that the encryption code will 
> be moved out of sev-guest to support Secure TSC, but to make the diffs
> easier to review, convert the Crypto API usage over to AES GCM library
> usage before moving it.

Yes, will change it accordingly.

Regards
Nikunj
  

Patch

diff --git a/drivers/virt/coco/sev-guest/Kconfig b/drivers/virt/coco/sev-guest/Kconfig
index f9db0799ae67..bcc760bfb468 100644
--- a/drivers/virt/coco/sev-guest/Kconfig
+++ b/drivers/virt/coco/sev-guest/Kconfig
@@ -2,8 +2,7 @@  config SEV_GUEST
 	tristate "AMD SEV Guest driver"
 	default m
 	depends on AMD_MEM_ENCRYPT
-	select CRYPTO_AEAD2
-	select CRYPTO_GCM
+	select CRYPTO_LIB_AESGCM
 	help
 	  SEV-SNP firmware provides the guest a mechanism to communicate with
 	  the PSP without risk from a malicious hypervisor who wishes to read,
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 46f1a8d558b0..57af908bafba 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -16,8 +16,7 @@ 
 #include <linux/miscdevice.h>
 #include <linux/set_memory.h>
 #include <linux/fs.h>
-#include <crypto/aead.h>
-#include <linux/scatterlist.h>
+#include <crypto/gcm.h>
 #include <linux/psp-sev.h>
 #include <uapi/linux/sev-guest.h>
 #include <uapi/linux/psp-sev.h>
@@ -28,24 +27,16 @@ 
 #include "sev-guest.h"
 
 #define DEVICE_NAME	"sev-guest"
-#define AAD_LEN		48
-#define MSG_HDR_VER	1
 
 #define SNP_REQ_MAX_RETRY_DURATION	(60*HZ)
 #define SNP_REQ_RETRY_DELAY		(2*HZ)
 
-struct snp_guest_crypto {
-	struct crypto_aead *tfm;
-	u8 *iv, *authtag;
-	int iv_len, a_len;
-};
-
 struct snp_guest_dev {
 	struct device *dev;
 	struct miscdevice misc;
 
 	void *certs_data;
-	struct snp_guest_crypto *crypto;
+	struct aesgcm_ctx *ctx;
 	struct snp_guest_msg *request, *response;
 	struct snp_secrets_page_layout *layout;
 	struct snp_req_data input;
@@ -60,6 +51,15 @@  MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
 /* Mutex to serialize the shared buffer access and command handling. */
 static DEFINE_MUTEX(snp_cmd_mutex);
 
+static inline unsigned int get_ctx_authsize(struct snp_guest_dev *snp_dev)
+{
+	if (snp_dev && snp_dev->ctx)
+		return snp_dev->ctx->authsize;
+
+	WARN_ONCE(1, "Unable to get crypto authsize\n");
+	return 0;
+}
+
 static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
 {
 	char zero_key[VMPCK_KEY_LEN] = {0};
@@ -144,132 +144,59 @@  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
 	return container_of(dev, struct snp_guest_dev, misc);
 }
 
-static struct snp_guest_crypto *init_crypto(struct snp_guest_dev *snp_dev, u8 *key, size_t keylen)
+static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
 {
-	struct snp_guest_crypto *crypto;
+	struct aesgcm_ctx *ctx;
 
-	crypto = kzalloc(sizeof(*crypto), GFP_KERNEL_ACCOUNT);
-	if (!crypto)
+	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
+	if (!ctx)
 		return NULL;
 
-	crypto->tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
-	if (IS_ERR(crypto->tfm))
-		goto e_free;
-
-	if (crypto_aead_setkey(crypto->tfm, key, keylen))
-		goto e_free_crypto;
-
-	crypto->iv_len = crypto_aead_ivsize(crypto->tfm);
-	crypto->iv = kmalloc(crypto->iv_len, GFP_KERNEL_ACCOUNT);
-	if (!crypto->iv)
-		goto e_free_crypto;
-
-	if (crypto_aead_authsize(crypto->tfm) > MAX_AUTHTAG_LEN) {
-		if (crypto_aead_setauthsize(crypto->tfm, MAX_AUTHTAG_LEN)) {
-			dev_err(snp_dev->dev, "failed to set authsize to %d\n", MAX_AUTHTAG_LEN);
-			goto e_free_iv;
-		}
+	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
+		pr_err("SNP: crypto init failed\n");
+		kfree(ctx);
+		return NULL;
 	}
 
-	crypto->a_len = crypto_aead_authsize(crypto->tfm);
-	crypto->authtag = kmalloc(crypto->a_len, GFP_KERNEL_ACCOUNT);
-	if (!crypto->authtag)
-		goto e_free_iv;
-
-	return crypto;
-
-e_free_iv:
-	kfree(crypto->iv);
-e_free_crypto:
-	crypto_free_aead(crypto->tfm);
-e_free:
-	kfree(crypto);
-
-	return NULL;
+	return ctx;
 }
 
-static void deinit_crypto(struct snp_guest_crypto *crypto)
-{
-	crypto_free_aead(crypto->tfm);
-	kfree(crypto->iv);
-	kfree(crypto->authtag);
-	kfree(crypto);
-}
-
-static int enc_dec_message(struct snp_guest_crypto *crypto, struct snp_guest_msg *msg,
-			   u8 *src_buf, u8 *dst_buf, size_t len, bool enc)
-{
-	struct snp_guest_msg_hdr *hdr = &msg->hdr;
-	struct scatterlist src[3], dst[3];
-	DECLARE_CRYPTO_WAIT(wait);
-	struct aead_request *req;
-	int ret;
-
-	req = aead_request_alloc(crypto->tfm, GFP_KERNEL);
-	if (!req)
-		return -ENOMEM;
-
-	/*
-	 * AEAD memory operations:
-	 * +------ AAD -------+------- DATA -----+---- AUTHTAG----+
-	 * |  msg header      |  plaintext       |  hdr->authtag  |
-	 * | bytes 30h - 5Fh  |    or            |                |
-	 * |                  |   cipher         |                |
-	 * +------------------+------------------+----------------+
-	 */
-	sg_init_table(src, 3);
-	sg_set_buf(&src[0], &hdr->algo, AAD_LEN);
-	sg_set_buf(&src[1], src_buf, hdr->msg_sz);
-	sg_set_buf(&src[2], hdr->authtag, crypto->a_len);
-
-	sg_init_table(dst, 3);
-	sg_set_buf(&dst[0], &hdr->algo, AAD_LEN);
-	sg_set_buf(&dst[1], dst_buf, hdr->msg_sz);
-	sg_set_buf(&dst[2], hdr->authtag, crypto->a_len);
-
-	aead_request_set_ad(req, AAD_LEN);
-	aead_request_set_tfm(req, crypto->tfm);
-	aead_request_set_callback(req, 0, crypto_req_done, &wait);
-
-	aead_request_set_crypt(req, src, dst, len, crypto->iv);
-	ret = crypto_wait_req(enc ? crypto_aead_encrypt(req) : crypto_aead_decrypt(req), &wait);
-
-	aead_request_free(req);
-	return ret;
-}
-
-static int __enc_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
+static int __enc_payload(struct aesgcm_ctx *ctx, struct snp_guest_msg *msg,
 			 void *plaintext, size_t len)
 {
-	struct snp_guest_crypto *crypto = snp_dev->crypto;
 	struct snp_guest_msg_hdr *hdr = &msg->hdr;
+	u8 iv[GCM_AES_IV_SIZE] = {};
 
-	memset(crypto->iv, 0, crypto->iv_len);
-	memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
+	if (WARN_ON((hdr->msg_sz + ctx->authsize) > sizeof(msg->payload)))
+		return -EBADMSG;
 
-	return enc_dec_message(crypto, msg, plaintext, msg->payload, len, true);
+	memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
+	aesgcm_encrypt(ctx, msg->payload, plaintext, len, &hdr->algo, AAD_LEN,
+		       iv, hdr->authtag);
+	return 0;
 }
 
-static int dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
+static int dec_payload(struct aesgcm_ctx *ctx, struct snp_guest_msg *msg,
 		       void *plaintext, size_t len)
 {
-	struct snp_guest_crypto *crypto = snp_dev->crypto;
 	struct snp_guest_msg_hdr *hdr = &msg->hdr;
+	u8 iv[GCM_AES_IV_SIZE] = {};
 
-	/* Build IV with response buffer sequence number */
-	memset(crypto->iv, 0, crypto->iv_len);
-	memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
-
-	return enc_dec_message(crypto, msg, msg->payload, plaintext, len, false);
+	memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
+	if (aesgcm_decrypt(ctx, plaintext, msg->payload, len, &hdr->algo,
+			   AAD_LEN, iv, hdr->authtag))
+		return 0;
+	else
+		return -EBADMSG;
 }
 
 static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
 {
-	struct snp_guest_crypto *crypto = snp_dev->crypto;
 	struct snp_guest_msg *resp = snp_dev->response;
 	struct snp_guest_msg *req = snp_dev->request;
 	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
 	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
+	struct aesgcm_ctx *ctx = snp_dev->ctx;
 
 	dev_dbg(snp_dev->dev, "response [seqno %lld type %d version %d sz %d]\n",
 		resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version, resp_hdr->msg_sz);
@@ -287,11 +214,11 @@  static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
 	 * If the message size is greater than our buffer length then return
 	 * an error.
 	 */
-	if (unlikely((resp_hdr->msg_sz + crypto->a_len) > sz))
+	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
 		return -EBADMSG;
 
 	/* Decrypt the payload */
-	return dec_payload(snp_dev, resp, payload, resp_hdr->msg_sz + crypto->a_len);
+	return dec_payload(ctx, resp, payload, resp_hdr->msg_sz);
 }
 
 static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
@@ -318,7 +245,7 @@  static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
 	dev_dbg(snp_dev->dev, "request [seqno %lld type %d version %d sz %d]\n",
 		hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
 
-	return __enc_payload(snp_dev, req, payload, sz);
+	return __enc_payload(snp_dev->ctx, req, payload, sz);
 }
 
 static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, __u64 *fw_err)
@@ -446,7 +373,6 @@  static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, in
 
 static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
 {
-	struct snp_guest_crypto *crypto = snp_dev->crypto;
 	struct snp_report_resp *resp;
 	struct snp_report_req req;
 	int rc, resp_len;
@@ -464,7 +390,7 @@  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
 	 * response payload. Make sure that it has enough space to cover the
 	 * authtag.
 	 */
-	resp_len = sizeof(resp->data) + crypto->a_len;
+	resp_len = sizeof(resp->data) + get_ctx_authsize(snp_dev);
 	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
 	if (!resp)
 		return -ENOMEM;
@@ -485,7 +411,6 @@  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
 
 static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
 {
-	struct snp_guest_crypto *crypto = snp_dev->crypto;
 	struct snp_derived_key_resp resp = {0};
 	struct snp_derived_key_req req;
 	int rc, resp_len;
@@ -502,7 +427,7 @@  static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
 	 * response payload. Make sure that it has enough space to cover the
 	 * authtag.
 	 */
-	resp_len = sizeof(resp.data) + crypto->a_len;
+	resp_len = sizeof(resp.data) + get_ctx_authsize(snp_dev);
 	if (sizeof(buf) < resp_len)
 		return -ENOMEM;
 
@@ -527,7 +452,6 @@  static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
 
 static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
 {
-	struct snp_guest_crypto *crypto = snp_dev->crypto;
 	struct snp_ext_report_req req;
 	struct snp_report_resp *resp;
 	int ret, npages = 0, resp_len;
@@ -565,7 +489,7 @@  static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 	 * response payload. Make sure that it has enough space to cover the
 	 * authtag.
 	 */
-	resp_len = sizeof(resp->data) + crypto->a_len;
+	resp_len = sizeof(resp->data) + get_ctx_authsize(snp_dev);
 	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
 	if (!resp)
 		return -ENOMEM;
@@ -777,8 +701,8 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 		goto e_free_response;
 
 	ret = -EIO;
-	snp_dev->crypto = init_crypto(snp_dev, snp_dev->vmpck, VMPCK_KEY_LEN);
-	if (!snp_dev->crypto)
+	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
+	if (!snp_dev->ctx)
 		goto e_free_cert_data;
 
 	misc = &snp_dev->misc;
@@ -793,11 +717,13 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 
 	ret =  misc_register(misc);
 	if (ret)
-		goto e_free_cert_data;
+		goto e_free_ctx;
 
 	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
 	return 0;
 
+e_free_ctx:
+	kfree_sensitive(snp_dev->ctx);	/* ctx holds the expanded VMPCK key */
 e_free_cert_data:
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
 e_free_response:
@@ -816,7 +742,7 @@  static int __exit sev_guest_remove(struct platform_device *pdev)
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
 	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
 	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
-	deinit_crypto(snp_dev->crypto);
+	kfree_sensitive(snp_dev->ctx);	/* scrub expanded AES/GHASH keys */
 	misc_deregister(&snp_dev->misc);
 
 	return 0;
diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/drivers/virt/coco/sev-guest/sev-guest.h
index 21bda26fdb95..ceb798a404d6 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.h
+++ b/drivers/virt/coco/sev-guest/sev-guest.h
@@ -13,6 +13,9 @@ 
 #include <linux/types.h>
 
 #define MAX_AUTHTAG_LEN		32
+#define AUTHTAG_LEN		16
+#define AAD_LEN			48
+#define MSG_HDR_VER		1
 
 /* See SNP spec SNP_GUEST_REQUEST section for the structure */
 enum msg_type {