[v6,07/16] x86/sev: Move and reorganize sev guest request api

Message ID 20231128125959.1810039-8-nikunj@amd.com
State New
Headers
Series Add Secure TSC support for SNP guests |

Commit Message

Nikunj A. Dadhania Nov. 28, 2023, 12:59 p.m. UTC
  For enabling Secure TSC, SEV-SNP guests need to communicate with the
AMD Security Processor early during boot. Many of the required
functions are implemented in the sev-guest driver and therefore not
available at early boot. Move the required functions and provide
API to the sev guest driver for sending guest message and vmpck
routines.

As there is no external caller for snp_issue_guest_request() anymore,
make it static and drop the prototype from sev-guest.h.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
---
 arch/x86/Kconfig                        |   1 +
 arch/x86/include/asm/sev-guest.h        |  91 ++++-
 arch/x86/include/asm/sev.h              |  10 -
 arch/x86/kernel/sev.c                   | 451 +++++++++++++++++++++-
 drivers/virt/coco/sev-guest/Kconfig     |   1 -
 drivers/virt/coco/sev-guest/sev-guest.c | 479 +-----------------------
 6 files changed, 550 insertions(+), 483 deletions(-)
  

Comments

kernel test robot Nov. 28, 2023, 10:50 p.m. UTC | #1
Hi Nikunj,

kernel test robot noticed the following build warnings:

[auto build test WARNING on tip/x86/mm]
[also build test WARNING on linus/master v6.7-rc3 next-20231128]
[cannot apply to tip/x86/core kvm/queue kvm/linux-next]
[If your patch is applied to the wrong git tree, kindly drop us a note.
And when submitting patch, we suggest to use '--base' as documented in
https://git-scm.com/docs/git-format-patch#_base_tree_information]

url:    https://github.com/intel-lab-lkp/linux/commits/Nikunj-A-Dadhania/virt-sev-guest-Move-mutex-to-SNP-guest-device-structure/20231128-220026
base:   tip/x86/mm
patch link:    https://lore.kernel.org/r/20231128125959.1810039-8-nikunj%40amd.com
patch subject: [PATCH v6 07/16] x86/sev: Move and reorganize sev guest request api
config: x86_64-allyesconfig (https://download.01.org/0day-ci/archive/20231129/202311290340.dJ5EmU8J-lkp@intel.com/config)
compiler: clang version 16.0.4 (https://github.com/llvm/llvm-project.git ae42196bc493ffe877a7e3dff8be32035dea4d07)
reproduce (this is a W=1 build): (https://download.01.org/0day-ci/archive/20231129/202311290340.dJ5EmU8J-lkp@intel.com/reproduce)

If you fix the issue in a separate patch/commit (i.e. not just a new version of
the same patch/commit), kindly add following tags
| Reported-by: kernel test robot <lkp@intel.com>
| Closes: https://lore.kernel.org/oe-kbuild-all/202311290340.dJ5EmU8J-lkp@intel.com/

All warnings (new ones prefixed by >>):

>> drivers/virt/coco/sev-guest/sev-guest.c:450:6: warning: variable 'ret' is used uninitialized whenever 'if' condition is true [-Wsometimes-uninitialized]
           if (!snp_dev->certs_data)
               ^~~~~~~~~~~~~~~~~~~~
   drivers/virt/coco/sev-guest/sev-guest.c:480:9: note: uninitialized use occurs here
           return ret;
                  ^~~
   drivers/virt/coco/sev-guest/sev-guest.c:450:2: note: remove the 'if' if its condition is always false
           if (!snp_dev->certs_data)
           ^~~~~~~~~~~~~~~~~~~~~~~~~
   drivers/virt/coco/sev-guest/sev-guest.c:424:9: note: initialize the variable 'ret' to silence this warning
           int ret;
                  ^
                   = 0
   1 warning generated.


vim +450 drivers/virt/coco/sev-guest/sev-guest.c

f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  418  
2bf93ffbb97e06 drivers/virt/coco/sevguest/sevguest.c   Tom Lendacky          2022-04-20  419  static int __init sev_guest_probe(struct platform_device *pdev)
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  420  {
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  421  	struct device *dev = &pdev->dev;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  422  	struct snp_guest_dev *snp_dev;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  423  	struct miscdevice *misc;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  424  	int ret;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  425  
d6fd48eff7506b drivers/virt/coco/sev-guest/sev-guest.c Borislav Petkov (AMD  2023-02-15  426) 	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
d6fd48eff7506b drivers/virt/coco/sev-guest/sev-guest.c Borislav Petkov (AMD  2023-02-15  427) 		return -ENODEV;
d6fd48eff7506b drivers/virt/coco/sev-guest/sev-guest.c Borislav Petkov (AMD  2023-02-15  428) 
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  429  	snp_dev = devm_kzalloc(&pdev->dev, sizeof(struct snp_guest_dev), GFP_KERNEL);
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  430  	if (!snp_dev)
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  431  		return -ENOMEM;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  432  
523ae6405daace drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  433  	if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
523ae6405daace drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  434  		dev_err(dev, "invalid vmpck id %u\n", vmpck_id);
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  435  		ret = -EINVAL;
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  436  		goto e_free_snpdev;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  437  	}
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  438  
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  439  	if (snp_setup_psp_messaging(snp_dev)) {
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  440  		dev_err(dev, "Unable to setup PSP messaging vmpck id %u\n", snp_dev->vmpck_id);
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  441  		ret = -ENODEV;
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  442  		goto e_free_snpdev;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  443  	}
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  444  
4ec0ddf1cc3c0c drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  445  	mutex_init(&snp_dev->cmd_mutex);
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  446  	platform_set_drvdata(pdev, snp_dev);
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  447  	snp_dev->dev = dev;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  448  
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  449  	snp_dev->certs_data = alloc_shared_pages(SEV_FW_BLOB_MAX_SIZE);
d80b494f712317 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07 @450  	if (!snp_dev->certs_data)
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  451  		goto e_free_ctx;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  452  
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  453  	misc = &snp_dev->misc;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  454  	misc->minor = MISC_DYNAMIC_MINOR;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  455  	misc->name = DEVICE_NAME;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  456  	misc->fops = &snp_guest_fops;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  457  
f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  458  	ret = tsm_register(&sev_tsm_ops, snp_dev, &tsm_report_extra_type);
f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  459  	if (ret)
f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  460  		goto e_free_cert_data;
f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  461  
f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  462  	ret = devm_add_action_or_reset(&pdev->dev, unregister_sev_tsm, NULL);
f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  463  	if (ret)
f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  464  		goto e_free_cert_data;
f47906782c7629 drivers/virt/coco/sev-guest/sev-guest.c Dan Williams          2023-10-10  465  
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  466  	ret =  misc_register(misc);
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  467  	if (ret)
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  468  		goto e_free_cert_data;
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  469  
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  470  	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", snp_dev->vmpck_id);
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  471  
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  472  	return 0;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  473  
d80b494f712317 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  474  e_free_cert_data:
d80b494f712317 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  475  	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  476  e_free_ctx:
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  477  	kfree(snp_dev->ctx);
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  478  e_free_snpdev:
81b918a5844565 drivers/virt/coco/sev-guest/sev-guest.c Nikunj A Dadhania     2023-11-28  479  	kfree(snp_dev);
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  480  	return ret;
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  481  }
fce96cf0443083 drivers/virt/coco/sevguest/sevguest.c   Brijesh Singh         2022-03-07  482
  
kernel test robot Nov. 29, 2023, 2:40 a.m. UTC | #2
Hi Nikunj,

kernel test robot noticed the following build warnings:

[auto build test WARNING on tip/x86/mm]
[also build test WARNING on linus/master v6.7-rc3 next-20231128]
[cannot apply to tip/x86/core kvm/queue kvm/linux-next]
[If your patch is applied to the wrong git tree, kindly drop us a note.
And when submitting patch, we suggest to use '--base' as documented in
https://git-scm.com/docs/git-format-patch#_base_tree_information]

url:    https://github.com/intel-lab-lkp/linux/commits/Nikunj-A-Dadhania/virt-sev-guest-Move-mutex-to-SNP-guest-device-structure/20231128-220026
base:   tip/x86/mm
patch link:    https://lore.kernel.org/r/20231128125959.1810039-8-nikunj%40amd.com
patch subject: [PATCH v6 07/16] x86/sev: Move and reorganize sev guest request api
config: x86_64-rhel-8.3-rust (https://download.01.org/0day-ci/archive/20231129/202311290851.yrAyZYIl-lkp@intel.com/config)
compiler: clang version 16.0.4 (https://github.com/llvm/llvm-project.git ae42196bc493ffe877a7e3dff8be32035dea4d07)
reproduce (this is a W=1 build): (https://download.01.org/0day-ci/archive/20231129/202311290851.yrAyZYIl-lkp@intel.com/reproduce)

If you fix the issue in a separate patch/commit (i.e. not just a new version of
the same patch/commit), kindly add following tags
| Reported-by: kernel test robot <lkp@intel.com>
| Closes: https://lore.kernel.org/oe-kbuild-all/202311290851.yrAyZYIl-lkp@intel.com/

All warnings (new ones prefixed by >>):

>> arch/x86/kernel/sev.c:2404:6: warning: variable 'ret' is used uninitialized whenever 'if' condition is true [-Wsometimes-uninitialized]
           if (!pdata->layout) {
               ^~~~~~~~~~~~~~
   arch/x86/kernel/sev.c:2446:9: note: uninitialized use occurs here
           return ret;
                  ^~~
   arch/x86/kernel/sev.c:2404:2: note: remove the 'if' if its condition is always false
           if (!pdata->layout) {
           ^~~~~~~~~~~~~~~~~~~~~
   arch/x86/kernel/sev.c:2380:9: note: initialize the variable 'ret' to silence this warning
           int ret;
                  ^
                   = 0
   1 warning generated.


vim +2404 arch/x86/kernel/sev.c

  2376	
  2377	int snp_setup_psp_messaging(struct snp_guest_dev *snp_dev)
  2378	{
  2379		struct sev_guest_platform_data *pdata;
  2380		int ret;
  2381	
  2382		if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP)) {
  2383			pr_err("SNP not supported\n");
  2384			return 0;
  2385		}
  2386	
  2387		if (platform_data) {
  2388			pr_debug("SNP platform data already initialized.\n");
  2389			goto create_ctx;
  2390		}
  2391	
  2392		if (!secrets_pa) {
  2393			pr_err("SNP secrets page not found\n");
  2394			return -ENODEV;
  2395		}
  2396	
  2397		pdata = kzalloc(sizeof(struct sev_guest_platform_data), GFP_KERNEL);
  2398		if (!pdata) {
  2399			pr_err("Allocation of SNP guest platform data failed\n");
  2400			return -ENOMEM;
  2401		}
  2402	
  2403		pdata->layout = (__force void *)ioremap_encrypted(secrets_pa, PAGE_SIZE);
> 2404		if (!pdata->layout) {
  2405			pr_err("Failed to map SNP secrets page.\n");
  2406			goto e_free_pdata;
  2407		}
  2408	
  2409		ret = -ENOMEM;
  2410		/* Allocate the shared page used for the request and response message. */
  2411		pdata->request = alloc_shared_pages(sizeof(struct snp_guest_msg));
  2412		if (!pdata->request)
  2413			goto e_unmap;
  2414	
  2415		pdata->response = alloc_shared_pages(sizeof(struct snp_guest_msg));
  2416		if (!pdata->response)
  2417			goto e_free_request;
  2418	
  2419		/* initial the input address for guest request */
  2420		pdata->input.req_gpa = __pa(pdata->request);
  2421		pdata->input.resp_gpa = __pa(pdata->response);
  2422		platform_data = pdata;
  2423	
  2424	create_ctx:
  2425		ret = -EIO;
  2426		snp_dev->ctx = snp_init_crypto(snp_dev->vmpck_id);
  2427		if (!snp_dev->ctx) {
  2428			pr_err("SNP crypto context initialization failed\n");
  2429			platform_data = NULL;
  2430			goto e_free_response;
  2431		}
  2432	
  2433		snp_dev->pdata = platform_data;
  2434	
  2435		return 0;
  2436	
  2437	e_free_response:
  2438		free_shared_pages(pdata->response, sizeof(struct snp_guest_msg));
  2439	e_free_request:
  2440		free_shared_pages(pdata->request, sizeof(struct snp_guest_msg));
  2441	e_unmap:
  2442		iounmap(pdata->layout);
  2443	e_free_pdata:
  2444		kfree(pdata);
  2445	
  2446		return ret;
  2447	}
  2448	EXPORT_SYMBOL_GPL(snp_setup_psp_messaging);
  2449
  
Dionna Amalie Glaze Dec. 5, 2023, 5:13 p.m. UTC | #3
On Tue, Nov 28, 2023 at 5:01 AM Nikunj A Dadhania <nikunj@amd.com> wrote:
>
> For enabling Secure TSC, SEV-SNP guests need to communicate with the
> AMD Security Processor early during boot. Many of the required
> functions are implemented in the sev-guest driver and therefore not
> available at early boot. Move the required functions and provide
> API to the sev guest driver for sending guest message and vmpck
> routines.
>
> As there is no external caller for snp_issue_guest_request() anymore,
> make it static and drop the prototype from sev-guest.h.
>
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> ---
>  arch/x86/Kconfig                        |   1 +
>  arch/x86/include/asm/sev-guest.h        |  91 ++++-
>  arch/x86/include/asm/sev.h              |  10 -
>  arch/x86/kernel/sev.c                   | 451 +++++++++++++++++++++-
>  drivers/virt/coco/sev-guest/Kconfig     |   1 -
>  drivers/virt/coco/sev-guest/sev-guest.c | 479 +-----------------------
>  6 files changed, 550 insertions(+), 483 deletions(-)
>
> diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig
> index 3762f41bb092..b8f374ec5651 100644
> --- a/arch/x86/Kconfig
> +++ b/arch/x86/Kconfig
> @@ -1534,6 +1534,7 @@ config AMD_MEM_ENCRYPT
>         select ARCH_HAS_CC_PLATFORM
>         select X86_MEM_ENCRYPT
>         select UNACCEPTED_MEMORY
> +       select CRYPTO_LIB_AESGCM
>         help
>           Say yes to enable support for the encryption of system memory.
>           This requires an AMD processor that supports Secure Memory
> diff --git a/arch/x86/include/asm/sev-guest.h b/arch/x86/include/asm/sev-guest.h
> index 27cc15ad6131..16bf25c14e6f 100644
> --- a/arch/x86/include/asm/sev-guest.h
> +++ b/arch/x86/include/asm/sev-guest.h
> @@ -11,6 +11,11 @@
>  #define __VIRT_SEVGUEST_H__
>
>  #include <linux/types.h>
> +#include <linux/miscdevice.h>
> +#include <asm/sev.h>
> +
> +#define SNP_REQ_MAX_RETRY_DURATION    (60*HZ)
> +#define SNP_REQ_RETRY_DELAY           (2*HZ)
>
>  #define MAX_AUTHTAG_LEN                32
>  #define AUTHTAG_LEN            16
> @@ -58,11 +63,52 @@ struct snp_guest_msg_hdr {
>         u8 rsvd3[35];
>  } __packed;
>
> +/* SNP Guest message request */
> +struct snp_req_data {
> +       unsigned long req_gpa;
> +       unsigned long resp_gpa;
> +};
> +
>  struct snp_guest_msg {
>         struct snp_guest_msg_hdr hdr;
>         u8 payload[4000];
>  } __packed;
>
> +struct sev_guest_platform_data {
> +       /* request and response are in unencrypted memory */
> +       struct snp_guest_msg *request;
> +       struct snp_guest_msg *response;
> +
> +       struct snp_secrets_page_layout *layout;
> +       struct snp_req_data input;
> +};
> +
> +struct snp_guest_dev {
> +       struct device *dev;
> +       struct miscdevice misc;
> +
> +       /* Mutex to serialize the shared buffer access and command handling. */
> +       struct mutex cmd_mutex;
> +
> +       void *certs_data;
> +       struct aesgcm_ctx *ctx;
> +
> +       /*
> +        * Avoid information leakage by double-buffering shared messages
> +        * in fields that are in regular encrypted memory
> +        */
> +       struct snp_guest_msg secret_request;
> +       struct snp_guest_msg secret_response;
> +
> +       struct sev_guest_platform_data *pdata;
> +       union {
> +               struct snp_report_req report;
> +               struct snp_derived_key_req derived_key;
> +               struct snp_ext_report_req ext_report;
> +       } req;
> +       unsigned int vmpck_id;
> +};
> +
>  struct snp_guest_req {
>         void *req_buf;
>         size_t req_sz;
> @@ -79,6 +125,47 @@ struct snp_guest_req {
>         u8 msg_type;
>  };
>
> -int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
> -                           struct snp_guest_request_ioctl *rio);
> +int snp_setup_psp_messaging(struct snp_guest_dev *snp_dev);
> +int snp_send_guest_request(struct snp_guest_dev *dev, struct snp_guest_req *req,
> +                          struct snp_guest_request_ioctl *rio);
> +bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id);
> +bool snp_is_vmpck_empty(unsigned int vmpck_id);
> +
> +static inline void free_shared_pages(void *buf, size_t sz)
> +{
> +       unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
> +       int ret;
> +
> +       if (!buf)
> +               return;
> +
> +       ret = set_memory_encrypted((unsigned long)buf, npages);
> +       if (ret) {
> +               WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
> +               return;
> +       }
> +
> +       __free_pages(virt_to_page(buf), get_order(sz));
> +}
> +
> +static inline void *alloc_shared_pages(size_t sz)
> +{
> +       unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
> +       struct page *page;
> +       int ret;
> +
> +       page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(sz));
> +       if (!page)
> +               return NULL;
> +
> +       ret = set_memory_decrypted((unsigned long)page_address(page), npages);
> +       if (ret) {
> +               pr_err("%s: failed to mark page shared, ret=%d\n", __func__, ret);
> +               __free_pages(page, get_order(sz));
> +               return NULL;
> +       }
> +
> +       return page_address(page);
> +}
> +
>  #endif /* __VIRT_SEVGUEST_H__ */
> diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
> index 78465a8c7dc6..783150458864 100644
> --- a/arch/x86/include/asm/sev.h
> +++ b/arch/x86/include/asm/sev.h
> @@ -93,16 +93,6 @@ extern bool handle_vc_boot_ghcb(struct pt_regs *regs);
>
>  #define RMPADJUST_VMSA_PAGE_BIT                BIT(16)
>
> -/* SNP Guest message request */
> -struct snp_req_data {
> -       unsigned long req_gpa;
> -       unsigned long resp_gpa;
> -};
> -
> -struct sev_guest_platform_data {
> -       u64 secrets_gpa;
> -};
> -
>  /*
>   * The secrets page contains 96-bytes of reserved field that can be used by
>   * the guest OS. The guest OS uses the area to save the message sequence
> diff --git a/arch/x86/kernel/sev.c b/arch/x86/kernel/sev.c
> index 479ea61f40f3..a413add2fd2c 100644
> --- a/arch/x86/kernel/sev.c
> +++ b/arch/x86/kernel/sev.c
> @@ -24,6 +24,7 @@
>  #include <linux/io.h>
>  #include <linux/psp-sev.h>
>  #include <uapi/linux/sev-guest.h>
> +#include <crypto/gcm.h>
>
>  #include <asm/cpu_entry_area.h>
>  #include <asm/stacktrace.h>
> @@ -2150,8 +2151,8 @@ static int __init init_sev_config(char *str)
>  }
>  __setup("sev=", init_sev_config);
>
> -int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
> -                           struct snp_guest_request_ioctl *rio)
> +static int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
> +                                  struct snp_guest_request_ioctl *rio)
>  {
>         struct ghcb_state state;
>         struct es_em_ctxt ctxt;
> @@ -2218,7 +2219,6 @@ int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *inpu
>
>         return ret;
>  }
> -EXPORT_SYMBOL_GPL(snp_issue_guest_request);
>
>  static struct platform_device sev_guest_device = {
>         .name           = "sev-guest",
> @@ -2227,22 +2227,451 @@ static struct platform_device sev_guest_device = {
>
>  static int __init snp_init_platform_device(void)
>  {
> -       struct sev_guest_platform_data data;
> -
>         if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
>                 return -ENODEV;
>
> -       if (!secrets_pa)
> +       if (platform_device_register(&sev_guest_device))
>                 return -ENODEV;
>
> -       data.secrets_gpa = secrets_pa;
> -       if (platform_device_add_data(&sev_guest_device, &data, sizeof(data)))
> +       pr_info("SNP guest platform device initialized.\n");
> +       return 0;
> +}
> +device_initcall(snp_init_platform_device);
> +
> +static struct sev_guest_platform_data *platform_data;
> +
> +static inline u8 *snp_get_vmpck(unsigned int vmpck_id)
> +{
> +       if (!platform_data)
> +               return NULL;
> +
> +       return platform_data->layout->vmpck0 + vmpck_id * VMPCK_KEY_LEN;
> +}
> +
> +static inline u32 *snp_get_os_area_msg_seqno(unsigned int vmpck_id)
> +{
> +       if (!platform_data)
> +               return NULL;
> +
> +       return &platform_data->layout->os_area.msg_seqno_0 + vmpck_id;
> +}
> +
> +bool snp_is_vmpck_empty(unsigned int vmpck_id)
> +{
> +       char zero_key[VMPCK_KEY_LEN] = {0};
> +       u8 *key = snp_get_vmpck(vmpck_id);
> +
> +       if (key)
> +               return !memcmp(key, zero_key, VMPCK_KEY_LEN);
> +
> +       return true;
> +}
> +EXPORT_SYMBOL_GPL(snp_is_vmpck_empty);
> +
> +/*
> + * If an error is received from the host or AMD Secure Processor (ASP) there
> + * are two options. Either retry the exact same encrypted request or discontinue
> + * using the VMPCK.
> + *
> + * This is because in the current encryption scheme GHCB v2 uses AES-GCM to
> + * encrypt the requests. The IV for this scheme is the sequence number. GCM
> + * cannot tolerate IV reuse.
> + *
> + * The ASP FW v1.51 only increments the sequence numbers on a successful
> + * guest<->ASP back and forth and only accepts messages at its exact sequence
> + * number.
> + *
> + * So if the sequence number were to be reused the encryption scheme is
> + * vulnerable. If the sequence number were incremented for a fresh IV the ASP
> + * will reject the request.
> + */
> +static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
> +{
> +       u8 *key = snp_get_vmpck(snp_dev->vmpck_id);
> +
> +       pr_alert("Disabling vmpck_id %u to prevent IV reuse.\n", snp_dev->vmpck_id);
> +       memzero_explicit(key, VMPCK_KEY_LEN);
> +}
> +
> +static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
> +{
> +       u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev->vmpck_id);
> +       u64 count;
> +
> +       if (!os_area_msg_seqno) {
> +               pr_err("SNP unable to get message sequence counter\n");
> +               return 0;
> +       }
> +
> +       lockdep_assert_held(&snp_dev->cmd_mutex);
> +
> +       /* Read the current message sequence counter from secrets pages */
> +       count = *os_area_msg_seqno;
> +
> +       return count + 1;
> +}
> +
> +/* Return a non-zero on success */
> +static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
> +{
> +       u64 count = __snp_get_msg_seqno(snp_dev);
> +
> +       /*
> +        * The message sequence counter for the SNP guest request is a  64-bit
> +        * value but the version 2 of GHCB specification defines a 32-bit storage
> +        * for it. If the counter exceeds the 32-bit value then return zero.
> +        * The caller should check the return value, but if the caller happens to
> +        * not check the value and use it, then the firmware treats zero as an
> +        * invalid number and will fail the  message request.
> +        */
> +       if (count >= UINT_MAX) {
> +               pr_err("SNP request message sequence counter overflow\n");
> +               return 0;
> +       }
> +
> +       return count;
> +}
> +
> +static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
> +{
> +       u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev->vmpck_id);
> +
> +       if (!os_area_msg_seqno) {
> +               pr_err("SNP unable to get message sequence counter\n");
> +               return;
> +       }
> +
> +       lockdep_assert_held(&snp_dev->cmd_mutex);
> +
> +       /*
> +        * The counter is also incremented by the PSP, so increment it by 2
> +        * and save in secrets page.
> +        */
> +       *os_area_msg_seqno += 2;
> +}
> +
> +static struct aesgcm_ctx *snp_init_crypto(unsigned int vmpck_id)
> +{
> +       struct aesgcm_ctx *ctx;
> +       u8 *key;
> +
> +       if (snp_is_vmpck_empty(vmpck_id)) {
> +               pr_err("VM communication key VMPCK%u is null\n", vmpck_id);
> +               return NULL;
> +       }
> +
> +       ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
> +       if (!ctx)
> +               return NULL;
> +
> +       key = snp_get_vmpck(vmpck_id);
> +       if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
> +               pr_err("Crypto context initialization failed\n");
> +               kfree(ctx);
> +               return NULL;
> +       }
> +
> +       return ctx;
> +}
> +
> +int snp_setup_psp_messaging(struct snp_guest_dev *snp_dev)
> +{
> +       struct sev_guest_platform_data *pdata;
> +       int ret;
> +
> +       if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP)) {

Note that this may be going away in favor of an
cpu_feature_enabled(X86_FEATURE_...) check given Kirill's "[PATCH]
x86/coco, x86/sev: Use cpu_feature_enabled() to detect SEV guest
flavor"

> +               pr_err("SNP not supported\n");
> +               return 0;
> +       }
> +
> +       if (platform_data) {
> +               pr_debug("SNP platform data already initialized.\n");
> +               goto create_ctx;
> +       }
> +
> +       if (!secrets_pa) {
> +               pr_err("SNP secrets page not found\n");
>                 return -ENODEV;
> +       }
>
> -       if (platform_device_register(&sev_guest_device))
> +       pdata = kzalloc(sizeof(struct sev_guest_platform_data), GFP_KERNEL);
> +       if (!pdata) {
> +               pr_err("Allocation of SNP guest platform data failed\n");
> +               return -ENOMEM;
> +       }
> +
> +       pdata->layout = (__force void *)ioremap_encrypted(secrets_pa, PAGE_SIZE);
> +       if (!pdata->layout) {
> +               pr_err("Failed to map SNP secrets page.\n");
> +               goto e_free_pdata;
> +       }
> +
> +       ret = -ENOMEM;
> +       /* Allocate the shared page used for the request and response message. */
> +       pdata->request = alloc_shared_pages(sizeof(struct snp_guest_msg));
> +       if (!pdata->request)
> +               goto e_unmap;
> +
> +       pdata->response = alloc_shared_pages(sizeof(struct snp_guest_msg));
> +       if (!pdata->response)
> +               goto e_free_request;
> +
> +       /* initial the input address for guest request */
> +       pdata->input.req_gpa = __pa(pdata->request);
> +       pdata->input.resp_gpa = __pa(pdata->response);
> +       platform_data = pdata;
> +
> +create_ctx:
> +       ret = -EIO;
> +       snp_dev->ctx = snp_init_crypto(snp_dev->vmpck_id);
> +       if (!snp_dev->ctx) {
> +               pr_err("SNP crypto context initialization failed\n");
> +               platform_data = NULL;
> +               goto e_free_response;
> +       }
> +
> +       snp_dev->pdata = platform_data;
> +
> +       return 0;
> +
> +e_free_response:
> +       free_shared_pages(pdata->response, sizeof(struct snp_guest_msg));
> +e_free_request:
> +       free_shared_pages(pdata->request, sizeof(struct snp_guest_msg));
> +e_unmap:
> +       iounmap(pdata->layout);
> +e_free_pdata:
> +       kfree(pdata);
> +
> +       return ret;
> +}
> +EXPORT_SYMBOL_GPL(snp_setup_psp_messaging);
> +
> +static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *guest_req,
> +                                 struct sev_guest_platform_data *pdata)
> +{
> +       struct snp_guest_msg *resp = &snp_dev->secret_response;
> +       struct snp_guest_msg *req = &snp_dev->secret_request;
> +       struct snp_guest_msg_hdr *req_hdr = &req->hdr;
> +       struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
> +       struct aesgcm_ctx *ctx = snp_dev->ctx;
> +       u8 iv[GCM_AES_IV_SIZE] = {};
> +
> +       pr_debug("response [seqno %lld type %d version %d sz %d]\n",
> +                resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version,
> +                resp_hdr->msg_sz);
> +
> +       /* Copy response from shared memory to encrypted memory. */
> +       memcpy(resp, pdata->response, sizeof(*resp));
> +
> +       /* Verify that the sequence counter is incremented by 1 */
> +       if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
> +               return -EBADMSG;
> +
> +       /* Verify response message type and version number. */
> +       if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
> +           resp_hdr->msg_version != req_hdr->msg_version)
> +               return -EBADMSG;
> +
> +       /*
> +        * If the message size is greater than our buffer length then return
> +        * an error.
> +        */
> +       if (unlikely((resp_hdr->msg_sz + ctx->authsize) > guest_req->resp_sz))
> +               return -EBADMSG;
> +
> +       /* Decrypt the payload */
> +       memcpy(iv, &resp_hdr->msg_seqno, sizeof(resp_hdr->msg_seqno));
> +       if (!aesgcm_decrypt(ctx, guest_req->resp_buf, resp->payload, resp_hdr->msg_sz,
> +                           &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
> +               return -EBADMSG;
> +
> +       return 0;
> +}
> +
> +static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req)
> +{
> +       struct snp_guest_msg *msg = &snp_dev->secret_request;
> +       struct snp_guest_msg_hdr *hdr = &msg->hdr;
> +       struct aesgcm_ctx *ctx = snp_dev->ctx;
> +       u8 iv[GCM_AES_IV_SIZE] = {};
> +
> +       memset(msg, 0, sizeof(*msg));
> +
> +       hdr->algo = SNP_AEAD_AES_256_GCM;
> +       hdr->hdr_version = MSG_HDR_VER;
> +       hdr->hdr_sz = sizeof(*hdr);
> +       hdr->msg_type = req->msg_type;
> +       hdr->msg_version = req->msg_version;
> +       hdr->msg_seqno = seqno;
> +       hdr->msg_vmpck = req->vmpck_id;
> +       hdr->msg_sz = req->req_sz;
> +
> +       /* Verify the sequence number is non-zero */
> +       if (!hdr->msg_seqno)
> +               return -ENOSR;
> +
> +       pr_debug("request [seqno %lld type %d version %d sz %d]\n",
> +                hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
> +
> +       if (WARN_ON((req->req_sz + ctx->authsize) > sizeof(msg->payload)))
> +               return -EBADMSG;
> +
> +       memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
> +       aesgcm_encrypt(ctx, msg->payload, req->req_buf, req->req_sz, &hdr->algo,
> +                      AAD_LEN, iv, hdr->authtag);
> +
> +       return 0;
> +}
> +
> +static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
> +                                 struct snp_guest_request_ioctl *rio,
> +                                 struct sev_guest_platform_data *pdata)
> +{
> +       unsigned long req_start = jiffies;
> +       unsigned int override_npages = 0;
> +       u64 override_err = 0;
> +       int rc;
> +
> +retry_request:
> +       /*
> +        * Call firmware to process the request. In this function the encrypted
> +        * message enters shared memory with the host. So after this call the
> +        * sequence number must be incremented or the VMPCK must be deleted to
> +        * prevent reuse of the IV.
> +        */
> +       rc = snp_issue_guest_request(req, &pdata->input, rio);
> +       switch (rc) {
> +       case -ENOSPC:
> +               /*
> +                * If the extended guest request fails due to having too
> +                * small of a certificate data buffer, retry the same
> +                * guest request without the extended data request in
> +                * order to increment the sequence number and thus avoid
> +                * IV reuse.
> +                */
> +               override_npages = req->data_npages;
> +               req->exit_code  = SVM_VMGEXIT_GUEST_REQUEST;
> +
> +               /*
> +                * Override the error to inform callers the given extended
> +                * request buffer size was too small and give the caller the
> +                * required buffer size.
> +                */
> +               override_err    = SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN);
> +
> +               /*
> +                * If this call to the firmware succeeds, the sequence number can
> +                * be incremented allowing for continued use of the VMPCK. If
> +                * there is an error reflected in the return value, this value
> +                * is checked further down and the result will be the deletion
> +                * of the VMPCK and the error code being propagated back to the
> +                * user as an ioctl() return code.
> +                */
> +               goto retry_request;
> +
> +       /*
> +        * The host may return SNP_GUEST_REQ_ERR_BUSY if the request has been
> +        * throttled. Retry in the driver to avoid returning and reusing the
> +        * message sequence number on a different message.
> +        */
> +       case -EAGAIN:
> +               if (jiffies - req_start > SNP_REQ_MAX_RETRY_DURATION) {
> +                       rc = -ETIMEDOUT;
> +                       break;
> +               }
> +               schedule_timeout_killable(SNP_REQ_RETRY_DELAY);
> +               goto retry_request;
> +       }
> +
> +       /*
> +        * Increment the message sequence number. There is no harm in doing
> +        * this now because decryption uses the value stored in the response
> +        * structure and any failure will wipe the VMPCK, preventing further
> +        * use anyway.
> +        */
> +       snp_inc_msg_seqno(snp_dev);
> +
> +       if (override_err) {
> +               rio->exitinfo2 = override_err;
> +
> +               /*
> +                * If an extended guest request was issued and the supplied certificate
> +                * buffer was not large enough, a standard guest request was issued to
> +                * prevent IV reuse. If the standard request was successful, return -EIO
> +                * back to the caller as would have originally been returned.
> +                */
> +               if (!rc && override_err == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
> +                       rc = -EIO;
> +       }
> +
> +       if (override_npages)
> +               req->data_npages = override_npages;
> +
> +       return rc;
> +}
> +
> +int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
> +                          struct snp_guest_request_ioctl *rio)
> +{
> +       struct sev_guest_platform_data *pdata;
> +       u64 seqno;
> +       int rc;
> +
> +       if (!snp_dev || !snp_dev->pdata || !req || !rio)
>                 return -ENODEV;
>
> -       pr_info("SNP guest platform device initialized.\n");
> +       pdata = snp_dev->pdata;
> +
> +       /* Get message sequence and verify that its a non-zero */
> +       seqno = snp_get_msg_seqno(snp_dev);
> +       if (!seqno)
> +               return -EIO;
> +
> +       /* Clear shared memory's response for the host to populate. */
> +       memset(pdata->response, 0, sizeof(struct snp_guest_msg));
> +
> +       /* Encrypt the userspace provided payload in pdata->secret_request. */
> +       rc = enc_payload(snp_dev, seqno, req);
> +       if (rc)
> +               return rc;
> +
> +       /*
> +        * Write the fully encrypted request to the shared unencrypted
> +        * request page.
> +        */
> +       memcpy(pdata->request, &snp_dev->secret_request, sizeof(snp_dev->secret_request));
> +
> +       rc = __handle_guest_request(snp_dev, req, rio, pdata);
> +       if (rc) {
> +               if (rc == -EIO &&
> +                   rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
> +                       return rc;
> +
> +               pr_alert("Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
> +                        rc, rio->exitinfo2);
> +               snp_disable_vmpck(snp_dev);
> +               return rc;
> +       }
> +
> +       rc = verify_and_dec_payload(snp_dev, req, pdata);
> +       if (rc) {
> +               pr_alert("Detected unexpected decode failure from ASP. rc: %d\n", rc);
> +               snp_disable_vmpck(snp_dev);
> +               return rc;
> +       }
> +
>         return 0;
>  }
> -device_initcall(snp_init_platform_device);
> +EXPORT_SYMBOL_GPL(snp_send_guest_request);
> +
> +bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
> +{
> +       if (WARN_ON(vmpck_id > 3))

This constant 3 should be #define'd, I believe.

> +               return false;
> +
> +       dev->vmpck_id = vmpck_id;
> +
> +       return true;
> +}
> +EXPORT_SYMBOL_GPL(snp_assign_vmpck);
> diff --git a/drivers/virt/coco/sev-guest/Kconfig b/drivers/virt/coco/sev-guest/Kconfig
> index 0b772bd921d8..a6405ab6c2c3 100644
> --- a/drivers/virt/coco/sev-guest/Kconfig
> +++ b/drivers/virt/coco/sev-guest/Kconfig
> @@ -2,7 +2,6 @@ config SEV_GUEST
>         tristate "AMD SEV Guest driver"
>         default m
>         depends on AMD_MEM_ENCRYPT
> -       select CRYPTO_LIB_AESGCM
>         select TSM_REPORTS
>         help
>           SEV-SNP firmware provides the guest a mechanism to communicate with
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 0f2134deca51..1cdf7ab04d39 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -31,130 +31,10 @@
>
>  #define DEVICE_NAME    "sev-guest"
>
> -#define SNP_REQ_MAX_RETRY_DURATION     (60*HZ)
> -#define SNP_REQ_RETRY_DELAY            (2*HZ)
> -
> -struct snp_guest_dev {
> -       struct device *dev;
> -       struct miscdevice misc;
> -
> -       /* Mutex to serialize the shared buffer access and command handling. */
> -       struct mutex cmd_mutex;
> -
> -       void *certs_data;
> -       struct aesgcm_ctx *ctx;
> -       /* request and response are in unencrypted memory */
> -       struct snp_guest_msg *request, *response;
> -
> -       /*
> -        * Avoid information leakage by double-buffering shared messages
> -        * in fields that are in regular encrypted memory.
> -        */
> -       struct snp_guest_msg secret_request, secret_response;
> -
> -       struct snp_secrets_page_layout *layout;
> -       struct snp_req_data input;
> -       union {
> -               struct snp_report_req report;
> -               struct snp_derived_key_req derived_key;
> -               struct snp_ext_report_req ext_report;
> -       } req;
> -       unsigned int vmpck_id;
> -};
> -
>  static u32 vmpck_id;
>  module_param(vmpck_id, uint, 0444);
>  MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.");
>
> -static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)
> -{
> -       return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
> -}
> -
> -static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
> -{
> -       return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
> -}
> -
> -static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
> -{
> -       char zero_key[VMPCK_KEY_LEN] = {0};
> -       u8 *key = snp_get_vmpck(snp_dev);
> -
> -       return !memcmp(key, zero_key, VMPCK_KEY_LEN);
> -}
> -
> -/*
> - * If an error is received from the host or AMD Secure Processor (ASP) there
> - * are two options. Either retry the exact same encrypted request or discontinue
> - * using the VMPCK.
> - *
> - * This is because in the current encryption scheme GHCB v2 uses AES-GCM to
> - * encrypt the requests. The IV for this scheme is the sequence number. GCM
> - * cannot tolerate IV reuse.
> - *
> - * The ASP FW v1.51 only increments the sequence numbers on a successful
> - * guest<->ASP back and forth and only accepts messages at its exact sequence
> - * number.
> - *
> - * So if the sequence number were to be reused the encryption scheme is
> - * vulnerable. If the sequence number were incremented for a fresh IV the ASP
> - * will reject the request.
> - */
> -static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
> -{
> -       u8 *key = snp_get_vmpck(snp_dev);
> -
> -       dev_alert(snp_dev->dev, "Disabling vmpck_id %u to prevent IV reuse.\n",
> -                 snp_dev->vmpck_id);
> -       memzero_explicit(key, VMPCK_KEY_LEN);
> -}
> -
> -static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
> -{
> -       u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
> -       u64 count;
> -
> -       lockdep_assert_held(&snp_dev->cmd_mutex);
> -
> -       /* Read the current message sequence counter from secrets pages */
> -       count = *os_area_msg_seqno;
> -
> -       return count + 1;
> -}
> -
> -/* Return a non-zero on success */
> -static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
> -{
> -       u64 count = __snp_get_msg_seqno(snp_dev);
> -
> -       /*
> -        * The message sequence counter for the SNP guest request is a  64-bit
> -        * value but the version 2 of GHCB specification defines a 32-bit storage
> -        * for it. If the counter exceeds the 32-bit value then return zero.
> -        * The caller should check the return value, but if the caller happens to
> -        * not check the value and use it, then the firmware treats zero as an
> -        * invalid number and will fail the  message request.
> -        */
> -       if (count >= UINT_MAX) {
> -               dev_err(snp_dev->dev, "request message sequence counter overflow\n");
> -               return 0;
> -       }
> -
> -       return count;
> -}
> -
> -static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
> -{
> -       u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
> -
> -       /*
> -        * The counter is also incremented by the PSP, so increment it by 2
> -        * and save in secrets page.
> -        */
> -       *os_area_msg_seqno += 2;
> -}
> -
>  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>  {
>         struct miscdevice *dev = file->private_data;
> @@ -162,241 +42,6 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>         return container_of(dev, struct snp_guest_dev, misc);
>  }
>
> -static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
> -{
> -       struct aesgcm_ctx *ctx;
> -       u8 *key;
> -
> -       if (snp_is_vmpck_empty(snp_dev)) {
> -               pr_err("VM communication key VMPCK%u is null\n", vmpck_id);
> -               return NULL;
> -       }
> -
> -       ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
> -       if (!ctx)
> -               return NULL;
> -
> -       key = snp_get_vmpck(snp_dev);
> -       if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
> -               pr_err("Crypto context initialization failed\n");
> -               kfree(ctx);
> -               return NULL;
> -       }
> -
> -       return ctx;
> -}
> -
> -static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *guest_req)
> -{
> -       struct snp_guest_msg *resp = &snp_dev->secret_response;
> -       struct snp_guest_msg *req = &snp_dev->secret_request;
> -       struct snp_guest_msg_hdr *req_hdr = &req->hdr;
> -       struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
> -       struct aesgcm_ctx *ctx = snp_dev->ctx;
> -       u8 iv[GCM_AES_IV_SIZE] = {};
> -
> -       pr_debug("response [seqno %lld type %d version %d sz %d]\n",
> -                resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version,
> -                resp_hdr->msg_sz);
> -
> -       /* Copy response from shared memory to encrypted memory. */
> -       memcpy(resp, snp_dev->response, sizeof(*resp));
> -
> -       /* Verify that the sequence counter is incremented by 1 */
> -       if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
> -               return -EBADMSG;
> -
> -       /* Verify response message type and version number. */
> -       if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
> -           resp_hdr->msg_version != req_hdr->msg_version)
> -               return -EBADMSG;
> -
> -       /*
> -        * If the message size is greater than our buffer length then return
> -        * an error.
> -        */
> -       if (unlikely((resp_hdr->msg_sz + ctx->authsize) > guest_req->resp_sz))
> -               return -EBADMSG;
> -
> -       /* Decrypt the payload */
> -       memcpy(iv, &resp_hdr->msg_seqno, sizeof(resp_hdr->msg_seqno));
> -       if (!aesgcm_decrypt(ctx, guest_req->resp_buf, resp->payload, resp_hdr->msg_sz,
> -                           &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
> -               return -EBADMSG;
> -
> -       return 0;
> -}
> -
> -static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req)
> -{
> -       struct snp_guest_msg *msg = &snp_dev->secret_request;
> -       struct snp_guest_msg_hdr *hdr = &msg->hdr;
> -       struct aesgcm_ctx *ctx = snp_dev->ctx;
> -       u8 iv[GCM_AES_IV_SIZE] = {};
> -
> -       memset(msg, 0, sizeof(*msg));
> -
> -       hdr->algo = SNP_AEAD_AES_256_GCM;
> -       hdr->hdr_version = MSG_HDR_VER;
> -       hdr->hdr_sz = sizeof(*hdr);
> -       hdr->msg_type = req->msg_type;
> -       hdr->msg_version = req->msg_version;
> -       hdr->msg_seqno = seqno;
> -       hdr->msg_vmpck = req->vmpck_id;
> -       hdr->msg_sz = req->req_sz;
> -
> -       /* Verify the sequence number is non-zero */
> -       if (!hdr->msg_seqno)
> -               return -ENOSR;
> -
> -       pr_debug("request [seqno %lld type %d version %d sz %d]\n",
> -                hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
> -
> -       if (WARN_ON((req->req_sz + ctx->authsize) > sizeof(msg->payload)))
> -               return -EBADMSG;
> -
> -       memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
> -       aesgcm_encrypt(ctx, msg->payload, req->req_buf, req->req_sz, &hdr->algo,
> -                      AAD_LEN, iv, hdr->authtag);
> -
> -       return 0;
> -}
> -
> -static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
> -                                 struct snp_guest_request_ioctl *rio)
> -{
> -       unsigned long req_start = jiffies;
> -       unsigned int override_npages = 0;
> -       u64 override_err = 0;
> -       int rc;
> -
> -retry_request:
> -       /*
> -        * Call firmware to process the request. In this function the encrypted
> -        * message enters shared memory with the host. So after this call the
> -        * sequence number must be incremented or the VMPCK must be deleted to
> -        * prevent reuse of the IV.
> -        */
> -       rc = snp_issue_guest_request(req, &snp_dev->input, rio);
> -       switch (rc) {
> -       case -ENOSPC:
> -               /*
> -                * If the extended guest request fails due to having too
> -                * small of a certificate data buffer, retry the same
> -                * guest request without the extended data request in
> -                * order to increment the sequence number and thus avoid
> -                * IV reuse.
> -                */
> -               override_npages = req->data_npages;
> -               req->exit_code  = SVM_VMGEXIT_GUEST_REQUEST;
> -
> -               /*
> -                * Override the error to inform callers the given extended
> -                * request buffer size was too small and give the caller the
> -                * required buffer size.
> -                */
> -               override_err = SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN);
> -
> -               /*
> -                * If this call to the firmware succeeds, the sequence number can
> -                * be incremented allowing for continued use of the VMPCK. If
> -                * there is an error reflected in the return value, this value
> -                * is checked further down and the result will be the deletion
> -                * of the VMPCK and the error code being propagated back to the
> -                * user as an ioctl() return code.
> -                */
> -               goto retry_request;
> -
> -       /*
> -        * The host may return SNP_GUEST_VMM_ERR_BUSY if the request has been
> -        * throttled. Retry in the driver to avoid returning and reusing the
> -        * message sequence number on a different message.
> -        */
> -       case -EAGAIN:
> -               if (jiffies - req_start > SNP_REQ_MAX_RETRY_DURATION) {
> -                       rc = -ETIMEDOUT;
> -                       break;
> -               }
> -               schedule_timeout_killable(SNP_REQ_RETRY_DELAY);
> -               goto retry_request;
> -       }
> -
> -       /*
> -        * Increment the message sequence number. There is no harm in doing
> -        * this now because decryption uses the value stored in the response
> -        * structure and any failure will wipe the VMPCK, preventing further
> -        * use anyway.
> -        */
> -       snp_inc_msg_seqno(snp_dev);
> -
> -       if (override_err) {
> -               rio->exitinfo2 = override_err;
> -
> -               /*
> -                * If an extended guest request was issued and the supplied certificate
> -                * buffer was not large enough, a standard guest request was issued to
> -                * prevent IV reuse. If the standard request was successful, return -EIO
> -                * back to the caller as would have originally been returned.
> -                */
> -               if (!rc && override_err == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
> -                       rc = -EIO;
> -       }
> -
> -       if (override_npages)
> -               req->data_npages = override_npages;
> -
> -       return rc;
> -}
> -
> -static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
> -                                 struct snp_guest_request_ioctl *rio)
> -{
> -       u64 seqno;
> -       int rc;
> -
> -       /* Get message sequence and verify that its a non-zero */
> -       seqno = snp_get_msg_seqno(snp_dev);
> -       if (!seqno)
> -               return -EIO;
> -
> -       /* Clear shared memory's response for the host to populate. */
> -       memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
> -
> -       /* Encrypt the userspace provided payload in snp_dev->secret_request. */
> -       rc = enc_payload(snp_dev, seqno, req);
> -       if (rc)
> -               return rc;
> -
> -       /*
> -        * Write the fully encrypted request to the shared unencrypted
> -        * request page.
> -        */
> -       memcpy(snp_dev->request, &snp_dev->secret_request,
> -              sizeof(snp_dev->secret_request));
> -
> -       rc = __handle_guest_request(snp_dev, req, rio);
> -       if (rc) {
> -               if (rc == -EIO &&
> -                   rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
> -                       return rc;
> -
> -               dev_alert(snp_dev->dev,
> -                         "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
> -                         rc, rio->exitinfo2);
> -               snp_disable_vmpck(snp_dev);
> -               return rc;
> -       }
> -
> -       rc = verify_and_dec_payload(snp_dev, req);
> -       if (rc) {
> -               dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc);
> -               snp_disable_vmpck(snp_dev);
> -               return rc;
> -       }
> -
> -       return 0;
> -}
> -
>  struct snp_req_resp {
>         sockptr_t req_data;
>         sockptr_t resp_data;
> @@ -607,7 +252,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>         mutex_lock(&snp_dev->cmd_mutex);
>
>         /* Check if the VMPCK is not empty */
> -       if (snp_is_vmpck_empty(snp_dev)) {
> +       if (snp_is_vmpck_empty(snp_dev->vmpck_id)) {
>                 dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>                 mutex_unlock(&snp_dev->cmd_mutex);
>                 return -ENOTTY;
> @@ -642,58 +287,11 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>         return ret;
>  }
>
> -static void free_shared_pages(void *buf, size_t sz)
> -{
> -       unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
> -       int ret;
> -
> -       if (!buf)
> -               return;
> -
> -       ret = set_memory_encrypted((unsigned long)buf, npages);
> -       if (ret) {
> -               WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
> -               return;
> -       }
> -
> -       __free_pages(virt_to_page(buf), get_order(sz));
> -}
> -
> -static void *alloc_shared_pages(struct device *dev, size_t sz)
> -{
> -       unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
> -       struct page *page;
> -       int ret;
> -
> -       page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(sz));
> -       if (!page)
> -               return NULL;
> -
> -       ret = set_memory_decrypted((unsigned long)page_address(page), npages);
> -       if (ret) {
> -               dev_err(dev, "failed to mark page shared, ret=%d\n", ret);
> -               __free_pages(page, get_order(sz));
> -               return NULL;
> -       }
> -
> -       return page_address(page);
> -}
> -
>  static const struct file_operations snp_guest_fops = {
>         .owner  = THIS_MODULE,
>         .unlocked_ioctl = snp_guest_ioctl,
>  };
>
> -bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
> -{
> -       if (WARN_ON(vmpck_id > 3))
> -               return false;
> -
> -       dev->vmpck_id = vmpck_id;
> -
> -       return true;
> -}
> -
>  struct snp_msg_report_resp_hdr {
>         u32 status;
>         u32 report_size;
> @@ -727,7 +325,7 @@ static int sev_report_new(struct tsm_report *report, void *data)
>         guard(mutex)(&snp_dev->cmd_mutex);
>
>         /* Check if the VMPCK is not empty */
> -       if (snp_is_vmpck_empty(snp_dev)) {
> +       if (snp_is_vmpck_empty(snp_dev->vmpck_id)) {
>                 dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>                 return -ENOTTY;
>         }
> @@ -820,76 +418,43 @@ static void unregister_sev_tsm(void *data)
>
>  static int __init sev_guest_probe(struct platform_device *pdev)
>  {
> -       struct snp_secrets_page_layout *layout;
> -       struct sev_guest_platform_data *data;
>         struct device *dev = &pdev->dev;
>         struct snp_guest_dev *snp_dev;
>         struct miscdevice *misc;
> -       void __iomem *mapping;
>         int ret;
>
>         if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
>                 return -ENODEV;
>
> -       if (!dev->platform_data)
> -               return -ENODEV;
> -
> -       data = (struct sev_guest_platform_data *)dev->platform_data;
> -       mapping = ioremap_encrypted(data->secrets_gpa, PAGE_SIZE);
> -       if (!mapping)
> -               return -ENODEV;
> -
> -       layout = (__force void *)mapping;
> -
> -       ret = -ENOMEM;
>         snp_dev = devm_kzalloc(&pdev->dev, sizeof(struct snp_guest_dev), GFP_KERNEL);
>         if (!snp_dev)
> -               goto e_unmap;
> +               return -ENOMEM;
>
> -       ret = -EINVAL;
> -       snp_dev->layout = layout;
>         if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
>                 dev_err(dev, "invalid vmpck id %u\n", vmpck_id);
> -               goto e_unmap;
> +               ret = -EINVAL;
> +               goto e_free_snpdev;
>         }
>
> -       /* Verify that VMPCK is not zero. */
> -       if (snp_is_vmpck_empty(snp_dev)) {
> -               dev_err(dev, "vmpck id %u is null\n", vmpck_id);
> -               goto e_unmap;
> +       if (snp_setup_psp_messaging(snp_dev)) {
> +               dev_err(dev, "Unable to setup PSP messaging vmpck id %u\n", snp_dev->vmpck_id);
> +               ret = -ENODEV;
> +               goto e_free_snpdev;
>         }
>
>         mutex_init(&snp_dev->cmd_mutex);
>         platform_set_drvdata(pdev, snp_dev);
>         snp_dev->dev = dev;
>
> -       /* Allocate the shared page used for the request and response message. */
> -       snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> -       if (!snp_dev->request)
> -               goto e_unmap;
> -
> -       snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> -       if (!snp_dev->response)
> -               goto e_free_request;
> -
> -       snp_dev->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE);
> +       snp_dev->certs_data = alloc_shared_pages(SEV_FW_BLOB_MAX_SIZE);
>         if (!snp_dev->certs_data)
> -               goto e_free_response;
> -
> -       ret = -EIO;
> -       snp_dev->ctx = snp_init_crypto(snp_dev);
> -       if (!snp_dev->ctx)
> -               goto e_free_cert_data;
> +               goto e_free_ctx;
>
>         misc = &snp_dev->misc;
>         misc->minor = MISC_DYNAMIC_MINOR;
>         misc->name = DEVICE_NAME;
>         misc->fops = &snp_guest_fops;
>
> -       /* initial the input address for guest request */
> -       snp_dev->input.req_gpa = __pa(snp_dev->request);
> -       snp_dev->input.resp_gpa = __pa(snp_dev->response);
> -
>         ret = tsm_register(&sev_tsm_ops, snp_dev, &tsm_report_extra_type);
>         if (ret)
>                 goto e_free_cert_data;
> @@ -900,21 +465,18 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>
>         ret =  misc_register(misc);
>         if (ret)
> -               goto e_free_ctx;
> +               goto e_free_cert_data;
> +
> +       dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", snp_dev->vmpck_id);
>
> -       dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", vmpck_id);
>         return 0;
>
> -e_free_ctx:
> -       kfree(snp_dev->ctx);
>  e_free_cert_data:
>         free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
> -e_free_response:
> -       free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
> -e_free_request:
> -       free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
> -e_unmap:
> -       iounmap(mapping);
> +e_free_ctx:
> +       kfree(snp_dev->ctx);
> +e_free_snpdev:
> +       kfree(snp_dev);
>         return ret;
>  }
>
> @@ -923,10 +485,9 @@ static int __exit sev_guest_remove(struct platform_device *pdev)
>         struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
>
>         free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
> -       free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
> -       free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
> -       kfree(snp_dev->ctx);
>         misc_deregister(&snp_dev->misc);
> +       kfree(snp_dev->ctx);
> +       kfree(snp_dev);
>
>         return 0;
>  }
> --
> 2.34.1
>
  
Nikunj A. Dadhania Dec. 6, 2023, 4:24 a.m. UTC | #4
On 12/5/2023 10:43 PM, Dionna Amalie Glaze wrote:
> On Tue, Nov 28, 2023 at 5:01 AM Nikunj A Dadhania <nikunj@amd.com> wrote:
>>
>> +int snp_setup_psp_messaging(struct snp_guest_dev *snp_dev)
>> +{
>> +       struct sev_guest_platform_data *pdata;
>> +       int ret;
>> +
>> +       if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP)) {
> 
> Note that this may be going away in favor of an
> cpu_feature_enabled(X86_FEATURE_...) check given Kirill's "[PATCH]
> x86/coco, x86/sev: Use cpu_feature_enabled() to detect SEV guest
> flavor"

I do not see a conclusion on that yet, so we should wait.

>> +bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
>> +{
>> +       if (WARN_ON(vmpck_id > 3))
> 
> This constant 3 should be #define'd, I believe.

Sure, I am working on few changes related to mutex per vmpck that Tom had suggested offline, that will also need a #define.

Thanks
Nikunj
  

Patch

diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig
index 3762f41bb092..b8f374ec5651 100644
--- a/arch/x86/Kconfig
+++ b/arch/x86/Kconfig
@@ -1534,6 +1534,7 @@  config AMD_MEM_ENCRYPT
 	select ARCH_HAS_CC_PLATFORM
 	select X86_MEM_ENCRYPT
 	select UNACCEPTED_MEMORY
+	select CRYPTO_LIB_AESGCM
 	help
 	  Say yes to enable support for the encryption of system memory.
 	  This requires an AMD processor that supports Secure Memory
diff --git a/arch/x86/include/asm/sev-guest.h b/arch/x86/include/asm/sev-guest.h
index 27cc15ad6131..16bf25c14e6f 100644
--- a/arch/x86/include/asm/sev-guest.h
+++ b/arch/x86/include/asm/sev-guest.h
@@ -11,6 +11,11 @@ 
 #define __VIRT_SEVGUEST_H__
 
 #include <linux/types.h>
+#include <linux/miscdevice.h>
+#include <asm/sev.h>
+
+#define SNP_REQ_MAX_RETRY_DURATION    (60*HZ)
+#define SNP_REQ_RETRY_DELAY           (2*HZ)
 
 #define MAX_AUTHTAG_LEN		32
 #define AUTHTAG_LEN		16
@@ -58,11 +63,52 @@  struct snp_guest_msg_hdr {
 	u8 rsvd3[35];
 } __packed;
 
+/* SNP Guest message request */
+struct snp_req_data {
+	unsigned long req_gpa;
+	unsigned long resp_gpa;
+};
+
 struct snp_guest_msg {
 	struct snp_guest_msg_hdr hdr;
 	u8 payload[4000];
 } __packed;
 
+struct sev_guest_platform_data {
+	/* request and response are in unencrypted memory */
+	struct snp_guest_msg *request;
+	struct snp_guest_msg *response;
+
+	struct snp_secrets_page_layout *layout;
+	struct snp_req_data input;
+};
+
+struct snp_guest_dev {
+	struct device *dev;
+	struct miscdevice misc;
+
+	/* Mutex to serialize the shared buffer access and command handling. */
+	struct mutex cmd_mutex;
+
+	void *certs_data;
+	struct aesgcm_ctx *ctx;
+
+	/*
+	 * Avoid information leakage by double-buffering shared messages
+	 * in fields that are in regular encrypted memory
+	 */
+	struct snp_guest_msg secret_request;
+	struct snp_guest_msg secret_response;
+
+	struct sev_guest_platform_data *pdata;
+	union {
+		struct snp_report_req report;
+		struct snp_derived_key_req derived_key;
+		struct snp_ext_report_req ext_report;
+	} req;
+	unsigned int vmpck_id;
+};
+
 struct snp_guest_req {
 	void *req_buf;
 	size_t req_sz;
@@ -79,6 +125,47 @@  struct snp_guest_req {
 	u8 msg_type;
 };
 
-int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
-			    struct snp_guest_request_ioctl *rio);
+int snp_setup_psp_messaging(struct snp_guest_dev *snp_dev);
+int snp_send_guest_request(struct snp_guest_dev *dev, struct snp_guest_req *req,
+			   struct snp_guest_request_ioctl *rio);
+bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id);
+bool snp_is_vmpck_empty(unsigned int vmpck_id);
+
+static inline void free_shared_pages(void *buf, size_t sz)
+{
+	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
+	int ret;
+
+	if (!buf)
+		return;
+
+	ret = set_memory_encrypted((unsigned long)buf, npages);
+	if (ret) {
+		WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
+		return;
+	}
+
+	__free_pages(virt_to_page(buf), get_order(sz));
+}
+
+static inline void *alloc_shared_pages(size_t sz)
+{
+	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
+	struct page *page;
+	int ret;
+
+	page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(sz));
+	if (!page)
+		return NULL;
+
+	ret = set_memory_decrypted((unsigned long)page_address(page), npages);
+	if (ret) {
+		pr_err("%s: failed to mark page shared, ret=%d\n", __func__, ret);
+		__free_pages(page, get_order(sz));
+		return NULL;
+	}
+
+	return page_address(page);
+}
+
 #endif /* __VIRT_SEVGUEST_H__ */
diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 78465a8c7dc6..783150458864 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -93,16 +93,6 @@  extern bool handle_vc_boot_ghcb(struct pt_regs *regs);
 
 #define RMPADJUST_VMSA_PAGE_BIT		BIT(16)
 
-/* SNP Guest message request */
-struct snp_req_data {
-	unsigned long req_gpa;
-	unsigned long resp_gpa;
-};
-
-struct sev_guest_platform_data {
-	u64 secrets_gpa;
-};
-
 /*
  * The secrets page contains 96-bytes of reserved field that can be used by
  * the guest OS. The guest OS uses the area to save the message sequence
diff --git a/arch/x86/kernel/sev.c b/arch/x86/kernel/sev.c
index 479ea61f40f3..a413add2fd2c 100644
--- a/arch/x86/kernel/sev.c
+++ b/arch/x86/kernel/sev.c
@@ -24,6 +24,7 @@ 
 #include <linux/io.h>
 #include <linux/psp-sev.h>
 #include <uapi/linux/sev-guest.h>
+#include <crypto/gcm.h>
 
 #include <asm/cpu_entry_area.h>
 #include <asm/stacktrace.h>
@@ -2150,8 +2151,8 @@  static int __init init_sev_config(char *str)
 }
 __setup("sev=", init_sev_config);
 
-int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
-			    struct snp_guest_request_ioctl *rio)
+static int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
+				   struct snp_guest_request_ioctl *rio)
 {
 	struct ghcb_state state;
 	struct es_em_ctxt ctxt;
@@ -2218,7 +2219,6 @@  int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *inpu
 
 	return ret;
 }
-EXPORT_SYMBOL_GPL(snp_issue_guest_request);
 
 static struct platform_device sev_guest_device = {
 	.name		= "sev-guest",
@@ -2227,22 +2227,451 @@  static struct platform_device sev_guest_device = {
 
 static int __init snp_init_platform_device(void)
 {
-	struct sev_guest_platform_data data;
-
 	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
 		return -ENODEV;
 
-	if (!secrets_pa)
+	if (platform_device_register(&sev_guest_device))
 		return -ENODEV;
 
-	data.secrets_gpa = secrets_pa;
-	if (platform_device_add_data(&sev_guest_device, &data, sizeof(data)))
+	pr_info("SNP guest platform device initialized.\n");
+	return 0;
+}
+device_initcall(snp_init_platform_device);
+
+static struct sev_guest_platform_data *platform_data;
+
+static inline u8 *snp_get_vmpck(unsigned int vmpck_id)
+{
+	if (!platform_data)
+		return NULL;
+
+	return platform_data->layout->vmpck0 + vmpck_id * VMPCK_KEY_LEN;
+}
+
+static inline u32 *snp_get_os_area_msg_seqno(unsigned int vmpck_id)
+{
+	if (!platform_data)
+		return NULL;
+
+	return &platform_data->layout->os_area.msg_seqno_0 + vmpck_id;
+}
+
+bool snp_is_vmpck_empty(unsigned int vmpck_id)
+{
+	char zero_key[VMPCK_KEY_LEN] = {0};
+	u8 *key = snp_get_vmpck(vmpck_id);
+
+	if (key)
+		return !memcmp(key, zero_key, VMPCK_KEY_LEN);
+
+	return true;
+}
+EXPORT_SYMBOL_GPL(snp_is_vmpck_empty);
+
+/*
+ * If an error is received from the host or AMD Secure Processor (ASP) there
+ * are two options. Either retry the exact same encrypted request or discontinue
+ * using the VMPCK.
+ *
+ * This is because in the current encryption scheme GHCB v2 uses AES-GCM to
+ * encrypt the requests. The IV for this scheme is the sequence number. GCM
+ * cannot tolerate IV reuse.
+ *
+ * The ASP FW v1.51 only increments the sequence numbers on a successful
+ * guest<->ASP back and forth and only accepts messages at its exact sequence
+ * number.
+ *
+ * So if the sequence number were to be reused the encryption scheme is
+ * vulnerable. If the sequence number were incremented for a fresh IV the ASP
+ * will reject the request.
+ */
+static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
+{
+	u8 *key = snp_get_vmpck(snp_dev->vmpck_id);
+
+	pr_alert("Disabling vmpck_id %u to prevent IV reuse.\n", snp_dev->vmpck_id);
+	memzero_explicit(key, VMPCK_KEY_LEN);
+}
+
+static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
+{
+	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev->vmpck_id);
+	u64 count;
+
+	if (!os_area_msg_seqno) {
+		pr_err("SNP unable to get message sequence counter\n");
+		return 0;
+	}
+
+	lockdep_assert_held(&snp_dev->cmd_mutex);
+
+	/* Read the current message sequence counter from secrets pages */
+	count = *os_area_msg_seqno;
+
+	return count + 1;
+}
+
+/* Return a non-zero on success */
+static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
+{
+	u64 count = __snp_get_msg_seqno(snp_dev);
+
+	/*
+	 * The message sequence counter for the SNP guest request is a  64-bit
+	 * value but the version 2 of GHCB specification defines a 32-bit storage
+	 * for it. If the counter exceeds the 32-bit value then return zero.
+	 * The caller should check the return value, but if the caller happens to
+	 * not check the value and use it, then the firmware treats zero as an
+	 * invalid number and will fail the  message request.
+	 */
+	if (count >= UINT_MAX) {
+		pr_err("SNP request message sequence counter overflow\n");
+		return 0;
+	}
+
+	return count;
+}
+
+static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
+{
+	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev->vmpck_id);
+
+	if (!os_area_msg_seqno) {
+		pr_err("SNP unable to get message sequence counter\n");
+		return;
+	}
+
+	lockdep_assert_held(&snp_dev->cmd_mutex);
+
+	/*
+	 * The counter is also incremented by the PSP, so increment it by 2
+	 * and save in secrets page.
+	 */
+	*os_area_msg_seqno += 2;
+}
+
+static struct aesgcm_ctx *snp_init_crypto(unsigned int vmpck_id)
+{
+	struct aesgcm_ctx *ctx;
+	u8 *key;
+
+	if (snp_is_vmpck_empty(vmpck_id)) {
+		pr_err("VM communication key VMPCK%u is null\n", vmpck_id);
+		return NULL;
+	}
+
+	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
+	if (!ctx)
+		return NULL;
+
+	key = snp_get_vmpck(vmpck_id);
+	if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
+		pr_err("Crypto context initialization failed\n");
+		kfree(ctx);
+		return NULL;
+	}
+
+	return ctx;
+}
+
+int snp_setup_psp_messaging(struct snp_guest_dev *snp_dev)
+{
+	struct sev_guest_platform_data *pdata;
+	int ret;
+
+	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP)) {
+		pr_err("SNP not supported\n");
+		return 0;
+	}
+
+	if (platform_data) {
+		pr_debug("SNP platform data already initialized.\n");
+		goto create_ctx;
+	}
+
+	if (!secrets_pa) {
+		pr_err("SNP secrets page not found\n");
 		return -ENODEV;
+	}
 
-	if (platform_device_register(&sev_guest_device))
+	pdata = kzalloc(sizeof(struct sev_guest_platform_data), GFP_KERNEL);
+	if (!pdata) {
+		pr_err("Allocation of SNP guest platform data failed\n");
+		return -ENOMEM;
+	}
+
+	pdata->layout = (__force void *)ioremap_encrypted(secrets_pa, PAGE_SIZE);
+	if (!pdata->layout) {
+		pr_err("Failed to map SNP secrets page.\n");
+		goto e_free_pdata;
+	}
+
+	ret = -ENOMEM;
+	/* Allocate the shared page used for the request and response message. */
+	pdata->request = alloc_shared_pages(sizeof(struct snp_guest_msg));
+	if (!pdata->request)
+		goto e_unmap;
+
+	pdata->response = alloc_shared_pages(sizeof(struct snp_guest_msg));
+	if (!pdata->response)
+		goto e_free_request;
+
+	/* initial the input address for guest request */
+	pdata->input.req_gpa = __pa(pdata->request);
+	pdata->input.resp_gpa = __pa(pdata->response);
+	platform_data = pdata;
+
+create_ctx:
+	ret = -EIO;
+	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck_id);
+	if (!snp_dev->ctx) {
+		pr_err("SNP crypto context initialization failed\n");
+		platform_data = NULL;
+		goto e_free_response;
+	}
+
+	snp_dev->pdata = platform_data;
+
+	return 0;
+
+e_free_response:
+	free_shared_pages(pdata->response, sizeof(struct snp_guest_msg));
+e_free_request:
+	free_shared_pages(pdata->request, sizeof(struct snp_guest_msg));
+e_unmap:
+	iounmap(pdata->layout);
+e_free_pdata:
+	kfree(pdata);
+
+	return ret;
+}
+EXPORT_SYMBOL_GPL(snp_setup_psp_messaging);
+
+static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *guest_req,
+				  struct sev_guest_platform_data *pdata)
+{
+	struct snp_guest_msg *resp = &snp_dev->secret_response;
+	struct snp_guest_msg *req = &snp_dev->secret_request;
+	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
+	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
+	struct aesgcm_ctx *ctx = snp_dev->ctx;
+	u8 iv[GCM_AES_IV_SIZE] = {};
+
+	pr_debug("response [seqno %lld type %d version %d sz %d]\n",
+		 resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version,
+		 resp_hdr->msg_sz);
+
+	/* Copy response from shared memory to encrypted memory. */
+	memcpy(resp, pdata->response, sizeof(*resp));
+
+	/* Verify that the sequence counter is incremented by 1 */
+	if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
+		return -EBADMSG;
+
+	/* Verify response message type and version number. */
+	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
+	    resp_hdr->msg_version != req_hdr->msg_version)
+		return -EBADMSG;
+
+	/*
+	 * If the message size is greater than our buffer length then return
+	 * an error.
+	 */
+	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > guest_req->resp_sz))
+		return -EBADMSG;
+
+	/* Decrypt the payload */
+	memcpy(iv, &resp_hdr->msg_seqno, sizeof(resp_hdr->msg_seqno));
+	if (!aesgcm_decrypt(ctx, guest_req->resp_buf, resp->payload, resp_hdr->msg_sz,
+			    &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
+		return -EBADMSG;
+
+	return 0;
+}
+
+static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req)
+{
+	struct snp_guest_msg *msg = &snp_dev->secret_request;
+	struct snp_guest_msg_hdr *hdr = &msg->hdr;
+	struct aesgcm_ctx *ctx = snp_dev->ctx;
+	u8 iv[GCM_AES_IV_SIZE] = {};
+
+	memset(msg, 0, sizeof(*msg));
+
+	hdr->algo = SNP_AEAD_AES_256_GCM;
+	hdr->hdr_version = MSG_HDR_VER;
+	hdr->hdr_sz = sizeof(*hdr);
+	hdr->msg_type = req->msg_type;
+	hdr->msg_version = req->msg_version;
+	hdr->msg_seqno = seqno;
+	hdr->msg_vmpck = req->vmpck_id;
+	hdr->msg_sz = req->req_sz;
+
+	/* Verify the sequence number is non-zero */
+	if (!hdr->msg_seqno)
+		return -ENOSR;
+
+	pr_debug("request [seqno %lld type %d version %d sz %d]\n",
+		 hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
+
+	if (WARN_ON((req->req_sz + ctx->authsize) > sizeof(msg->payload)))
+		return -EBADMSG;
+
+	memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
+	aesgcm_encrypt(ctx, msg->payload, req->req_buf, req->req_sz, &hdr->algo,
+		       AAD_LEN, iv, hdr->authtag);
+
+	return 0;
+}
+
+static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
+				  struct snp_guest_request_ioctl *rio,
+				  struct sev_guest_platform_data *pdata)
+{
+	unsigned long req_start = jiffies;
+	unsigned int override_npages = 0;
+	u64 override_err = 0;
+	int rc;
+
+retry_request:
+	/*
+	 * Call firmware to process the request. In this function the encrypted
+	 * message enters shared memory with the host. So after this call the
+	 * sequence number must be incremented or the VMPCK must be deleted to
+	 * prevent reuse of the IV.
+	 */
+	rc = snp_issue_guest_request(req, &pdata->input, rio);
+	switch (rc) {
+	case -ENOSPC:
+		/*
+		 * If the extended guest request fails due to having too
+		 * small of a certificate data buffer, retry the same
+		 * guest request without the extended data request in
+		 * order to increment the sequence number and thus avoid
+		 * IV reuse.
+		 */
+		override_npages = req->data_npages;
+		req->exit_code	= SVM_VMGEXIT_GUEST_REQUEST;
+
+		/*
+		 * Override the error to inform callers the given extended
+		 * request buffer size was too small and give the caller the
+		 * required buffer size.
+		 */
+		override_err	= SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN);
+
+		/*
+		 * If this call to the firmware succeeds, the sequence number can
+		 * be incremented allowing for continued use of the VMPCK. If
+		 * there is an error reflected in the return value, this value
+		 * is checked further down and the result will be the deletion
+		 * of the VMPCK and the error code being propagated back to the
+		 * user as an ioctl() return code.
+		 */
+		goto retry_request;
+
+	/*
+	 * The host may return SNP_GUEST_REQ_ERR_BUSY if the request has been
+	 * throttled. Retry in the driver to avoid returning and reusing the
+	 * message sequence number on a different message.
+	 */
+	case -EAGAIN:
+		if (jiffies - req_start > SNP_REQ_MAX_RETRY_DURATION) {
+			rc = -ETIMEDOUT;
+			break;
+		}
+		schedule_timeout_killable(SNP_REQ_RETRY_DELAY);
+		goto retry_request;
+	}
+
+	/*
+	 * Increment the message sequence number. There is no harm in doing
+	 * this now because decryption uses the value stored in the response
+	 * structure and any failure will wipe the VMPCK, preventing further
+	 * use anyway.
+	 */
+	snp_inc_msg_seqno(snp_dev);
+
+	if (override_err) {
+		rio->exitinfo2 = override_err;
+
+		/*
+		 * If an extended guest request was issued and the supplied certificate
+		 * buffer was not large enough, a standard guest request was issued to
+		 * prevent IV reuse. If the standard request was successful, return -EIO
+		 * back to the caller as would have originally been returned.
+		 */
+		if (!rc && override_err == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
+			rc = -EIO;
+	}
+
+	if (override_npages)
+		req->data_npages = override_npages;
+
+	return rc;
+}
+
+int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
+			   struct snp_guest_request_ioctl *rio)
+{
+	struct sev_guest_platform_data *pdata;
+	u64 seqno;
+	int rc;
+
+	if (!snp_dev || !snp_dev->pdata || !req || !rio)
 		return -ENODEV;
 
-	pr_info("SNP guest platform device initialized.\n");
+	pdata = snp_dev->pdata;
+
+	/* Get message sequence and verify that its a non-zero */
+	seqno = snp_get_msg_seqno(snp_dev);
+	if (!seqno)
+		return -EIO;
+
+	/* Clear shared memory's response for the host to populate. */
+	memset(pdata->response, 0, sizeof(struct snp_guest_msg));
+
+	/* Encrypt the userspace provided payload in pdata->secret_request. */
+	rc = enc_payload(snp_dev, seqno, req);
+	if (rc)
+		return rc;
+
+	/*
+	 * Write the fully encrypted request to the shared unencrypted
+	 * request page.
+	 */
+	memcpy(pdata->request, &snp_dev->secret_request, sizeof(snp_dev->secret_request));
+
+	rc = __handle_guest_request(snp_dev, req, rio, pdata);
+	if (rc) {
+		if (rc == -EIO &&
+		    rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
+			return rc;
+
+		pr_alert("Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
+			 rc, rio->exitinfo2);
+		snp_disable_vmpck(snp_dev);
+		return rc;
+	}
+
+	rc = verify_and_dec_payload(snp_dev, req, pdata);
+	if (rc) {
+		pr_alert("Detected unexpected decode failure from ASP. rc: %d\n", rc);
+		snp_disable_vmpck(snp_dev);
+		return rc;
+	}
+
 	return 0;
 }
-device_initcall(snp_init_platform_device);
+EXPORT_SYMBOL_GPL(snp_send_guest_request);
+
+bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
+{
+	if (WARN_ON(vmpck_id > 3))
+		return false;
+
+	dev->vmpck_id = vmpck_id;
+
+	return true;
+}
+EXPORT_SYMBOL_GPL(snp_assign_vmpck);
diff --git a/drivers/virt/coco/sev-guest/Kconfig b/drivers/virt/coco/sev-guest/Kconfig
index 0b772bd921d8..a6405ab6c2c3 100644
--- a/drivers/virt/coco/sev-guest/Kconfig
+++ b/drivers/virt/coco/sev-guest/Kconfig
@@ -2,7 +2,6 @@  config SEV_GUEST
 	tristate "AMD SEV Guest driver"
 	default m
 	depends on AMD_MEM_ENCRYPT
-	select CRYPTO_LIB_AESGCM
 	select TSM_REPORTS
 	help
 	  SEV-SNP firmware provides the guest a mechanism to communicate with
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 0f2134deca51..1cdf7ab04d39 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -31,130 +31,10 @@ 
 
 #define DEVICE_NAME	"sev-guest"
 
-#define SNP_REQ_MAX_RETRY_DURATION	(60*HZ)
-#define SNP_REQ_RETRY_DELAY		(2*HZ)
-
-struct snp_guest_dev {
-	struct device *dev;
-	struct miscdevice misc;
-
-	/* Mutex to serialize the shared buffer access and command handling. */
-	struct mutex cmd_mutex;
-
-	void *certs_data;
-	struct aesgcm_ctx *ctx;
-	/* request and response are in unencrypted memory */
-	struct snp_guest_msg *request, *response;
-
-	/*
-	 * Avoid information leakage by double-buffering shared messages
-	 * in fields that are in regular encrypted memory.
-	 */
-	struct snp_guest_msg secret_request, secret_response;
-
-	struct snp_secrets_page_layout *layout;
-	struct snp_req_data input;
-	union {
-		struct snp_report_req report;
-		struct snp_derived_key_req derived_key;
-		struct snp_ext_report_req ext_report;
-	} req;
-	unsigned int vmpck_id;
-};
-
 static u32 vmpck_id;
 module_param(vmpck_id, uint, 0444);
 MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.");
 
-static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)
-{
-	return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
-}
-
-static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
-{
-	return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
-}
-
-static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
-{
-	char zero_key[VMPCK_KEY_LEN] = {0};
-	u8 *key = snp_get_vmpck(snp_dev);
-
-	return !memcmp(key, zero_key, VMPCK_KEY_LEN);
-}
-
-/*
- * If an error is received from the host or AMD Secure Processor (ASP) there
- * are two options. Either retry the exact same encrypted request or discontinue
- * using the VMPCK.
- *
- * This is because in the current encryption scheme GHCB v2 uses AES-GCM to
- * encrypt the requests. The IV for this scheme is the sequence number. GCM
- * cannot tolerate IV reuse.
- *
- * The ASP FW v1.51 only increments the sequence numbers on a successful
- * guest<->ASP back and forth and only accepts messages at its exact sequence
- * number.
- *
- * So if the sequence number were to be reused the encryption scheme is
- * vulnerable. If the sequence number were incremented for a fresh IV the ASP
- * will reject the request.
- */
-static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
-{
-	u8 *key = snp_get_vmpck(snp_dev);
-
-	dev_alert(snp_dev->dev, "Disabling vmpck_id %u to prevent IV reuse.\n",
-		  snp_dev->vmpck_id);
-	memzero_explicit(key, VMPCK_KEY_LEN);
-}
-
-static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
-{
-	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
-	u64 count;
-
-	lockdep_assert_held(&snp_dev->cmd_mutex);
-
-	/* Read the current message sequence counter from secrets pages */
-	count = *os_area_msg_seqno;
-
-	return count + 1;
-}
-
-/* Return a non-zero on success */
-static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
-{
-	u64 count = __snp_get_msg_seqno(snp_dev);
-
-	/*
-	 * The message sequence counter for the SNP guest request is a  64-bit
-	 * value but the version 2 of GHCB specification defines a 32-bit storage
-	 * for it. If the counter exceeds the 32-bit value then return zero.
-	 * The caller should check the return value, but if the caller happens to
-	 * not check the value and use it, then the firmware treats zero as an
-	 * invalid number and will fail the  message request.
-	 */
-	if (count >= UINT_MAX) {
-		dev_err(snp_dev->dev, "request message sequence counter overflow\n");
-		return 0;
-	}
-
-	return count;
-}
-
-static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
-{
-	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
-
-	/*
-	 * The counter is also incremented by the PSP, so increment it by 2
-	 * and save in secrets page.
-	 */
-	*os_area_msg_seqno += 2;
-}
-
 static inline struct snp_guest_dev *to_snp_dev(struct file *file)
 {
 	struct miscdevice *dev = file->private_data;
@@ -162,241 +42,6 @@  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
 	return container_of(dev, struct snp_guest_dev, misc);
 }
 
-static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
-{
-	struct aesgcm_ctx *ctx;
-	u8 *key;
-
-	if (snp_is_vmpck_empty(snp_dev)) {
-		pr_err("VM communication key VMPCK%u is null\n", vmpck_id);
-		return NULL;
-	}
-
-	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
-	if (!ctx)
-		return NULL;
-
-	key = snp_get_vmpck(snp_dev);
-	if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
-		pr_err("Crypto context initialization failed\n");
-		kfree(ctx);
-		return NULL;
-	}
-
-	return ctx;
-}
-
-static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *guest_req)
-{
-	struct snp_guest_msg *resp = &snp_dev->secret_response;
-	struct snp_guest_msg *req = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
-	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
-	struct aesgcm_ctx *ctx = snp_dev->ctx;
-	u8 iv[GCM_AES_IV_SIZE] = {};
-
-	pr_debug("response [seqno %lld type %d version %d sz %d]\n",
-		 resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version,
-		 resp_hdr->msg_sz);
-
-	/* Copy response from shared memory to encrypted memory. */
-	memcpy(resp, snp_dev->response, sizeof(*resp));
-
-	/* Verify that the sequence counter is incremented by 1 */
-	if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
-		return -EBADMSG;
-
-	/* Verify response message type and version number. */
-	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
-	    resp_hdr->msg_version != req_hdr->msg_version)
-		return -EBADMSG;
-
-	/*
-	 * If the message size is greater than our buffer length then return
-	 * an error.
-	 */
-	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > guest_req->resp_sz))
-		return -EBADMSG;
-
-	/* Decrypt the payload */
-	memcpy(iv, &resp_hdr->msg_seqno, sizeof(resp_hdr->msg_seqno));
-	if (!aesgcm_decrypt(ctx, guest_req->resp_buf, resp->payload, resp_hdr->msg_sz,
-			    &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
-		return -EBADMSG;
-
-	return 0;
-}
-
-static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req)
-{
-	struct snp_guest_msg *msg = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *hdr = &msg->hdr;
-	struct aesgcm_ctx *ctx = snp_dev->ctx;
-	u8 iv[GCM_AES_IV_SIZE] = {};
-
-	memset(msg, 0, sizeof(*msg));
-
-	hdr->algo = SNP_AEAD_AES_256_GCM;
-	hdr->hdr_version = MSG_HDR_VER;
-	hdr->hdr_sz = sizeof(*hdr);
-	hdr->msg_type = req->msg_type;
-	hdr->msg_version = req->msg_version;
-	hdr->msg_seqno = seqno;
-	hdr->msg_vmpck = req->vmpck_id;
-	hdr->msg_sz = req->req_sz;
-
-	/* Verify the sequence number is non-zero */
-	if (!hdr->msg_seqno)
-		return -ENOSR;
-
-	pr_debug("request [seqno %lld type %d version %d sz %d]\n",
-		 hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
-
-	if (WARN_ON((req->req_sz + ctx->authsize) > sizeof(msg->payload)))
-		return -EBADMSG;
-
-	memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
-	aesgcm_encrypt(ctx, msg->payload, req->req_buf, req->req_sz, &hdr->algo,
-		       AAD_LEN, iv, hdr->authtag);
-
-	return 0;
-}
-
-static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
-				  struct snp_guest_request_ioctl *rio)
-{
-	unsigned long req_start = jiffies;
-	unsigned int override_npages = 0;
-	u64 override_err = 0;
-	int rc;
-
-retry_request:
-	/*
-	 * Call firmware to process the request. In this function the encrypted
-	 * message enters shared memory with the host. So after this call the
-	 * sequence number must be incremented or the VMPCK must be deleted to
-	 * prevent reuse of the IV.
-	 */
-	rc = snp_issue_guest_request(req, &snp_dev->input, rio);
-	switch (rc) {
-	case -ENOSPC:
-		/*
-		 * If the extended guest request fails due to having too
-		 * small of a certificate data buffer, retry the same
-		 * guest request without the extended data request in
-		 * order to increment the sequence number and thus avoid
-		 * IV reuse.
-		 */
-		override_npages = req->data_npages;
-		req->exit_code	= SVM_VMGEXIT_GUEST_REQUEST;
-
-		/*
-		 * Override the error to inform callers the given extended
-		 * request buffer size was too small and give the caller the
-		 * required buffer size.
-		 */
-		override_err = SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN);
-
-		/*
-		 * If this call to the firmware succeeds, the sequence number can
-		 * be incremented allowing for continued use of the VMPCK. If
-		 * there is an error reflected in the return value, this value
-		 * is checked further down and the result will be the deletion
-		 * of the VMPCK and the error code being propagated back to the
-		 * user as an ioctl() return code.
-		 */
-		goto retry_request;
-
-	/*
-	 * The host may return SNP_GUEST_VMM_ERR_BUSY if the request has been
-	 * throttled. Retry in the driver to avoid returning and reusing the
-	 * message sequence number on a different message.
-	 */
-	case -EAGAIN:
-		if (jiffies - req_start > SNP_REQ_MAX_RETRY_DURATION) {
-			rc = -ETIMEDOUT;
-			break;
-		}
-		schedule_timeout_killable(SNP_REQ_RETRY_DELAY);
-		goto retry_request;
-	}
-
-	/*
-	 * Increment the message sequence number. There is no harm in doing
-	 * this now because decryption uses the value stored in the response
-	 * structure and any failure will wipe the VMPCK, preventing further
-	 * use anyway.
-	 */
-	snp_inc_msg_seqno(snp_dev);
-
-	if (override_err) {
-		rio->exitinfo2 = override_err;
-
-		/*
-		 * If an extended guest request was issued and the supplied certificate
-		 * buffer was not large enough, a standard guest request was issued to
-		 * prevent IV reuse. If the standard request was successful, return -EIO
-		 * back to the caller as would have originally been returned.
-		 */
-		if (!rc && override_err == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
-			rc = -EIO;
-	}
-
-	if (override_npages)
-		req->data_npages = override_npages;
-
-	return rc;
-}
-
-static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
-				  struct snp_guest_request_ioctl *rio)
-{
-	u64 seqno;
-	int rc;
-
-	/* Get message sequence and verify that its a non-zero */
-	seqno = snp_get_msg_seqno(snp_dev);
-	if (!seqno)
-		return -EIO;
-
-	/* Clear shared memory's response for the host to populate. */
-	memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
-
-	/* Encrypt the userspace provided payload in snp_dev->secret_request. */
-	rc = enc_payload(snp_dev, seqno, req);
-	if (rc)
-		return rc;
-
-	/*
-	 * Write the fully encrypted request to the shared unencrypted
-	 * request page.
-	 */
-	memcpy(snp_dev->request, &snp_dev->secret_request,
-	       sizeof(snp_dev->secret_request));
-
-	rc = __handle_guest_request(snp_dev, req, rio);
-	if (rc) {
-		if (rc == -EIO &&
-		    rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
-			return rc;
-
-		dev_alert(snp_dev->dev,
-			  "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
-			  rc, rio->exitinfo2);
-		snp_disable_vmpck(snp_dev);
-		return rc;
-	}
-
-	rc = verify_and_dec_payload(snp_dev, req);
-	if (rc) {
-		dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc);
-		snp_disable_vmpck(snp_dev);
-		return rc;
-	}
-
-	return 0;
-}
-
 struct snp_req_resp {
 	sockptr_t req_data;
 	sockptr_t resp_data;
@@ -607,7 +252,7 @@  static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
 	mutex_lock(&snp_dev->cmd_mutex);
 
 	/* Check if the VMPCK is not empty */
-	if (snp_is_vmpck_empty(snp_dev)) {
+	if (snp_is_vmpck_empty(snp_dev->vmpck_id)) {
 		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
 		mutex_unlock(&snp_dev->cmd_mutex);
 		return -ENOTTY;
@@ -642,58 +287,11 @@  static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
 	return ret;
 }
 
-static void free_shared_pages(void *buf, size_t sz)
-{
-	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
-	int ret;
-
-	if (!buf)
-		return;
-
-	ret = set_memory_encrypted((unsigned long)buf, npages);
-	if (ret) {
-		WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
-		return;
-	}
-
-	__free_pages(virt_to_page(buf), get_order(sz));
-}
-
-static void *alloc_shared_pages(struct device *dev, size_t sz)
-{
-	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
-	struct page *page;
-	int ret;
-
-	page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(sz));
-	if (!page)
-		return NULL;
-
-	ret = set_memory_decrypted((unsigned long)page_address(page), npages);
-	if (ret) {
-		dev_err(dev, "failed to mark page shared, ret=%d\n", ret);
-		__free_pages(page, get_order(sz));
-		return NULL;
-	}
-
-	return page_address(page);
-}
-
 static const struct file_operations snp_guest_fops = {
 	.owner	= THIS_MODULE,
 	.unlocked_ioctl = snp_guest_ioctl,
 };
 
-bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
-{
-	if (WARN_ON(vmpck_id > 3))
-		return false;
-
-	dev->vmpck_id = vmpck_id;
-
-	return true;
-}
-
 struct snp_msg_report_resp_hdr {
 	u32 status;
 	u32 report_size;
@@ -727,7 +325,7 @@  static int sev_report_new(struct tsm_report *report, void *data)
 	guard(mutex)(&snp_dev->cmd_mutex);
 
 	/* Check if the VMPCK is not empty */
-	if (snp_is_vmpck_empty(snp_dev)) {
+	if (snp_is_vmpck_empty(snp_dev->vmpck_id)) {
 		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
 		return -ENOTTY;
 	}
@@ -820,76 +418,43 @@  static void unregister_sev_tsm(void *data)
 
 static int __init sev_guest_probe(struct platform_device *pdev)
 {
-	struct snp_secrets_page_layout *layout;
-	struct sev_guest_platform_data *data;
 	struct device *dev = &pdev->dev;
 	struct snp_guest_dev *snp_dev;
 	struct miscdevice *misc;
-	void __iomem *mapping;
 	int ret;
 
 	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
 		return -ENODEV;
 
-	if (!dev->platform_data)
-		return -ENODEV;
-
-	data = (struct sev_guest_platform_data *)dev->platform_data;
-	mapping = ioremap_encrypted(data->secrets_gpa, PAGE_SIZE);
-	if (!mapping)
-		return -ENODEV;
-
-	layout = (__force void *)mapping;
-
-	ret = -ENOMEM;
 	snp_dev = devm_kzalloc(&pdev->dev, sizeof(struct snp_guest_dev), GFP_KERNEL);
 	if (!snp_dev)
-		goto e_unmap;
+		return -ENOMEM;
 
-	ret = -EINVAL;
-	snp_dev->layout = layout;
 	if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
 		dev_err(dev, "invalid vmpck id %u\n", vmpck_id);
-		goto e_unmap;
+		ret = -EINVAL;
+		goto e_free_snpdev;
 	}
 
-	/* Verify that VMPCK is not zero. */
-	if (snp_is_vmpck_empty(snp_dev)) {
-		dev_err(dev, "vmpck id %u is null\n", vmpck_id);
-		goto e_unmap;
+	if (snp_setup_psp_messaging(snp_dev)) {
+		dev_err(dev, "Unable to setup PSP messaging vmpck id %u\n", snp_dev->vmpck_id);
+		ret = -ENODEV;
+		goto e_free_snpdev;
 	}
 
 	mutex_init(&snp_dev->cmd_mutex);
 	platform_set_drvdata(pdev, snp_dev);
 	snp_dev->dev = dev;
 
-	/* Allocate the shared page used for the request and response message. */
-	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
-	if (!snp_dev->request)
-		goto e_unmap;
-
-	snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
-	if (!snp_dev->response)
-		goto e_free_request;
-
-	snp_dev->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE);
+	snp_dev->certs_data = alloc_shared_pages(SEV_FW_BLOB_MAX_SIZE);
 	if (!snp_dev->certs_data)
-		goto e_free_response;
-
-	ret = -EIO;
-	snp_dev->ctx = snp_init_crypto(snp_dev);
-	if (!snp_dev->ctx)
-		goto e_free_cert_data;
+		goto e_free_ctx;
 
 	misc = &snp_dev->misc;
 	misc->minor = MISC_DYNAMIC_MINOR;
 	misc->name = DEVICE_NAME;
 	misc->fops = &snp_guest_fops;
 
-	/* initial the input address for guest request */
-	snp_dev->input.req_gpa = __pa(snp_dev->request);
-	snp_dev->input.resp_gpa = __pa(snp_dev->response);
-
 	ret = tsm_register(&sev_tsm_ops, snp_dev, &tsm_report_extra_type);
 	if (ret)
 		goto e_free_cert_data;
@@ -900,21 +465,18 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 
 	ret =  misc_register(misc);
 	if (ret)
-		goto e_free_ctx;
+		goto e_free_cert_data;
+
+	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", snp_dev->vmpck_id);
 
-	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", vmpck_id);
 	return 0;
 
-e_free_ctx:
-	kfree(snp_dev->ctx);
 e_free_cert_data:
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
-e_free_response:
-	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
-e_free_request:
-	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
-e_unmap:
-	iounmap(mapping);
+e_free_ctx:
+	kfree(snp_dev->ctx);
+e_free_snpdev:
+	kfree(snp_dev);
 	return ret;
 }
 
@@ -923,10 +485,9 @@  static int __exit sev_guest_remove(struct platform_device *pdev)
 	struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
 
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
-	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
-	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
-	kfree(snp_dev->ctx);
 	misc_deregister(&snp_dev->misc);
+	kfree(snp_dev->ctx);
+	kfree(snp_dev);
 
 	return 0;
 }