[v7,03/16] virt: sev-guest: Add SNP guest request structure

Message ID 20231220151358.2147066-4-nikunj@amd.com
State New
Headers
Series Add Secure TSC support for SNP guests |

Commit Message

Nikunj A. Dadhania Dec. 20, 2023, 3:13 p.m. UTC
  Add a snp_guest_req structure to simplify the function arguments. The
structure will be used to call the SNP Guest message request API
instead of passing a long list of parameters.

Update snp_issue_guest_request() prototype to include the new guest request
structure and move the prototype to sev_guest.h.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
Tested-by: Peter Gonda <pgonda@google.com>
---
 .../x86/include/asm}/sev-guest.h              |  18 +++
 arch/x86/include/asm/sev.h                    |   8 --
 arch/x86/kernel/sev.c                         |  15 ++-
 drivers/virt/coco/sev-guest/sev-guest.c       | 108 +++++++++++-------
 4 files changed, 93 insertions(+), 56 deletions(-)
 rename {drivers/virt/coco/sev-guest => arch/x86/include/asm}/sev-guest.h (78%)
  

Comments

Borislav Petkov Jan. 25, 2024, 11:59 a.m. UTC | #1
On Wed, Dec 20, 2023 at 08:43:45PM +0530, Nikunj A Dadhania wrote:
> -int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
> +int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
> +			    struct snp_guest_request_ioctl *rio)
>  {
>  	struct ghcb_state state;
>  	struct es_em_ctxt ctxt;
>  	unsigned long flags;
>  	struct ghcb *ghcb;
> +	u64 exit_code;

Silly local vars. Just use req->exit_code everywhere instead.

>  	int ret;
>  
>  	rio->exitinfo2 = SEV_RET_NO_FW_CALL;
> +	if (!req)
> +		return -EINVAL;

Such tests are done under the variable which is assigned, not randomly.

Also, what's the point in testing req? Will that ever be NULL? What are
you actually protecting against here?

> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 469e10d9bf35..5cafbd1c42cb 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -27,8 +27,7 @@
>  
>  #include <asm/svm.h>
>  #include <asm/sev.h>
> -
> -#include "sev-guest.h"
> +#include <asm/sev-guest.h>
>  
>  #define DEVICE_NAME	"sev-guest"
>  
> @@ -169,7 +168,7 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
>  	return ctx;
>  }
>  
> -static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
> +static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *guest_req)

So we call the request everywhere "req". But you've called it
"guest_req" here because...

>  {
>  	struct snp_guest_msg *resp = &snp_dev->secret_response;
>  	struct snp_guest_msg *req = &snp_dev->secret_request;

.. there already is a "req" variable which is not a guest request thing
but a guest message. So why don't you call it "req_msg" instead and the
"resp" "resp_msg" so that it is clear what is what?

And then you can call the actual request var "req" and then the code
becomes more readable...

..

>  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
>  {
>  	struct snp_report_req *req = &snp_dev->req.report;
> +	struct snp_guest_req guest_req = {0};

You have the same issue here.

If we aim at calling the local vars in every function the same, the code
becomes automatically much more readable.

And so on...
  
Tom Lendacky Jan. 26, 2024, 9:16 p.m. UTC | #2
On 12/20/23 09:13, Nikunj A Dadhania wrote:
> Add a snp_guest_req structure to simplify the function arguments. The
> structure will be used to call the SNP Guest message request API
> instead of passing a long list of parameters.
> 
> Update snp_issue_guest_request() prototype to include the new guest request
> structure and move the prototype to sev_guest.h.
> 
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> Tested-by: Peter Gonda <pgonda@google.com>
> ---
>   .../x86/include/asm}/sev-guest.h              |  18 +++
>   arch/x86/include/asm/sev.h                    |   8 --
>   arch/x86/kernel/sev.c                         |  15 ++-
>   drivers/virt/coco/sev-guest/sev-guest.c       | 108 +++++++++++-------
>   4 files changed, 93 insertions(+), 56 deletions(-)
>   rename {drivers/virt/coco/sev-guest => arch/x86/include/asm}/sev-guest.h (78%)
> 
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/arch/x86/include/asm/sev-guest.h
> similarity index 78%
> rename from drivers/virt/coco/sev-guest/sev-guest.h
> rename to arch/x86/include/asm/sev-guest.h
> index ceb798a404d6..27cc15ad6131 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.h
> +++ b/arch/x86/include/asm/sev-guest.h
> @@ -63,4 +63,22 @@ struct snp_guest_msg {
>   	u8 payload[4000];
>   } __packed;
>   
> +struct snp_guest_req {
> +	void *req_buf;
> +	size_t req_sz;
> +
> +	void *resp_buf;
> +	size_t resp_sz;
> +
> +	void *data;
> +	size_t data_npages;
> +
> +	u64 exit_code;
> +	unsigned int vmpck_id;
> +	u8 msg_version;
> +	u8 msg_type;
> +};
> +
> +int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
> +			    struct snp_guest_request_ioctl *rio);

This seems odd to have in this file. It's arch/x86/kernel/sev.c that 
exports the call and so this should probably stay in 
arch/x86/include/asm/sev.h and put the struct there, too, no?

Thanks,
Tom

>   #endif /* __VIRT_SEVGUEST_H__ */
> diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
> index 5b4a1ce3d368..78465a8c7dc6 100644
> --- a/arch/x86/include/asm/sev.h
> +++ b/arch/x86/include/asm/sev.h
> @@ -97,8 +97,6 @@ extern bool handle_vc_boot_ghcb(struct pt_regs *regs);
>   struct snp_req_data {
>   	unsigned long req_gpa;
>   	unsigned long resp_gpa;
> -	unsigned long data_gpa;
> -	unsigned int data_npages;
>   };
>   
>   struct sev_guest_platform_data {
> @@ -209,7 +207,6 @@ void snp_set_memory_private(unsigned long vaddr, unsigned long npages);
>   void snp_set_wakeup_secondary_cpu(void);
>   bool snp_init(struct boot_params *bp);
>   void __init __noreturn snp_abort(void);
> -int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio);
>   void snp_accept_memory(phys_addr_t start, phys_addr_t end);
>   u64 snp_get_unsupported_features(u64 status);
>   u64 sev_get_status(void);
> @@ -233,11 +230,6 @@ static inline void snp_set_memory_private(unsigned long vaddr, unsigned long npa
>   static inline void snp_set_wakeup_secondary_cpu(void) { }
>   static inline bool snp_init(struct boot_params *bp) { return false; }
>   static inline void snp_abort(void) { }
> -static inline int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
> -{
> -	return -ENOTTY;
> -}
> -
>   static inline void snp_accept_memory(phys_addr_t start, phys_addr_t end) { }
>   static inline u64 snp_get_unsupported_features(u64 status) { return 0; }
>   static inline u64 sev_get_status(void) { return 0; }
  
Nikunj A. Dadhania Jan. 27, 2024, 4:01 a.m. UTC | #3
On 1/25/2024 5:29 PM, Borislav Petkov wrote:
> On Wed, Dec 20, 2023 at 08:43:45PM +0530, Nikunj A Dadhania wrote:
>> -int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
>> +int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
>> +			    struct snp_guest_request_ioctl *rio)
>>  {
>>  	struct ghcb_state state;
>>  	struct es_em_ctxt ctxt;
>>  	unsigned long flags;
>>  	struct ghcb *ghcb;
>> +	u64 exit_code;
> 
> Silly local vars. Just use req->exit_code everywhere instead.

Sure, will change.

> 
>>  	int ret;
>>  
>>  	rio->exitinfo2 = SEV_RET_NO_FW_CALL;
>> +	if (!req)
>> +		return -EINVAL;
> 
> Such tests are done under the variable which is assigned, not randomly.
> 
> Also, what's the point in testing req? Will that ever be NULL? What are
> you actually protecting against here?

Right, and in the later code, this is checked at snp_send_guest_request() API. So this is redundant.

>> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
>> index 469e10d9bf35..5cafbd1c42cb 100644
>> --- a/drivers/virt/coco/sev-guest/sev-guest.c
>> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
>> @@ -27,8 +27,7 @@
>>  
>>  #include <asm/svm.h>
>>  #include <asm/sev.h>
>> -
>> -#include "sev-guest.h"
>> +#include <asm/sev-guest.h>
>>  
>>  #define DEVICE_NAME	"sev-guest"
>>  
>> @@ -169,7 +168,7 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
>>  	return ctx;
>>  }
>>  
>> -static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
>> +static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *guest_req)
> 
> So we call the request everywhere "req". But you've called it
> "guest_req" here because...

Yes, I was thinking about it and came up with this.

> 
>>  {
>>  	struct snp_guest_msg *resp = &snp_dev->secret_response;
>>  	struct snp_guest_msg *req = &snp_dev->secret_request;
> 
> ... there already is a "req" variable which is not a guest request thing
> but a guest message. So why don't you call it "req_msg" instead and the
> "resp" "resp_msg" so that it is clear what is what?
> 

This naming is much better, thanks.

> And then you can call the actual request var "req" and then the code
> becomes more readable...
> 
> ...
> 
>>  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
>>  {
>>  	struct snp_report_req *req = &snp_dev->req.report;
>> +	struct snp_guest_req guest_req = {0};
> 
> You have the same issue here.
> 
> If we aim at calling the local vars in every function the same, the code
> becomes automatically much more readable.
> 
> And so on...

Will change accordingly,

Regards
Nikunj
  
Nikunj A. Dadhania Jan. 27, 2024, 4:05 a.m. UTC | #4
On 1/27/2024 2:46 AM, Tom Lendacky wrote:
> On 12/20/23 09:13, Nikunj A Dadhania wrote:
>> Add a snp_guest_req structure to simplify the function arguments. The
>> structure will be used to call the SNP Guest message request API
>> instead of passing a long list of parameters.
>>
>> Update snp_issue_guest_request() prototype to include the new guest request
>> structure and move the prototype to sev_guest.h.
>>
>> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
>> Tested-by: Peter Gonda <pgonda@google.com>
>> ---
>>   .../x86/include/asm}/sev-guest.h              |  18 +++
>>   arch/x86/include/asm/sev.h                    |   8 --
>>   arch/x86/kernel/sev.c                         |  15 ++-
>>   drivers/virt/coco/sev-guest/sev-guest.c       | 108 +++++++++++-------
>>   4 files changed, 93 insertions(+), 56 deletions(-)
>>   rename {drivers/virt/coco/sev-guest => arch/x86/include/asm}/sev-guest.h (78%)
>>
>> diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/arch/x86/include/asm/sev-guest.h
>> similarity index 78%
>> rename from drivers/virt/coco/sev-guest/sev-guest.h
>> rename to arch/x86/include/asm/sev-guest.h
>> index ceb798a404d6..27cc15ad6131 100644
>> --- a/drivers/virt/coco/sev-guest/sev-guest.h
>> +++ b/arch/x86/include/asm/sev-guest.h
>> @@ -63,4 +63,22 @@ struct snp_guest_msg {
>>       u8 payload[4000];
>>   } __packed;
>>   +struct snp_guest_req {
>> +    void *req_buf;
>> +    size_t req_sz;
>> +
>> +    void *resp_buf;
>> +    size_t resp_sz;
>> +
>> +    void *data;
>> +    size_t data_npages;
>> +
>> +    u64 exit_code;
>> +    unsigned int vmpck_id;
>> +    u8 msg_version;
>> +    u8 msg_type;
>> +};
>> +
>> +int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
>> +                struct snp_guest_request_ioctl *rio);
> 
> This seems odd to have in this file. It's arch/x86/kernel/sev.c that exports the call and so this should probably stay in arch/x86/include/asm/sev.h and put the struct there, too, no?

The prototype is removed in 7/16, I have it here to make sure that compilation does not break with minimal churn.

Regards
Nikunj
  
Nikunj A. Dadhania Jan. 31, 2024, 1:58 p.m. UTC | #5
On 1/25/2024 5:29 PM, Borislav Petkov wrote:
> On Wed, Dec 20, 2023 at 08:43:45PM +0530, Nikunj A Dadhania wrote:
>> -int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
>> +int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
>> +			    struct snp_guest_request_ioctl *rio)
>>  {
>>  	struct ghcb_state state;
>>  	struct es_em_ctxt ctxt;
>>  	unsigned long flags;
>>  	struct ghcb *ghcb;
>> +	u64 exit_code;
> 
> Silly local vars. Just use req->exit_code everywhere instead.
> 
>>  	int ret;
>>  
>>  	rio->exitinfo2 = SEV_RET_NO_FW_CALL;
>> +	if (!req)
>> +		return -EINVAL;
> 
> Such tests are done under the variable which is assigned, not randomly.
> 
> Also, what's the point in testing req? Will that ever be NULL? What are
> you actually protecting against here?
> 
>> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
>> index 469e10d9bf35..5cafbd1c42cb 100644
>> --- a/drivers/virt/coco/sev-guest/sev-guest.c
>> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
>> @@ -27,8 +27,7 @@
>>  
>>  #include <asm/svm.h>
>>  #include <asm/sev.h>
>> -
>> -#include "sev-guest.h"
>> +#include <asm/sev-guest.h>
>>  
>>  #define DEVICE_NAME	"sev-guest"
>>  
>> @@ -169,7 +168,7 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
>>  	return ctx;
>>  }
>>  
>> -static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
>> +static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *guest_req)
> 
> So we call the request everywhere "req". But you've called it
> "guest_req" here because...
> 
>>  {
>>  	struct snp_guest_msg *resp = &snp_dev->secret_response;
>>  	struct snp_guest_msg *req = &snp_dev->secret_request;
> 
> ... there already is a "req" variable which is not a guest request thing
> but a guest message. So why don't you call it "req_msg" instead and the
> "resp" "resp_msg" so that it is clear what is what?
> 
> And then you can call the actual request var "req" and then the code
> becomes more readable...
> 
> ...
> 
>>  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
>>  {
>>  	struct snp_report_req *req = &snp_dev->req.report;
>> +	struct snp_guest_req guest_req = {0};
> 
> You have the same issue here.
> 
> If we aim at calling the local vars in every function the same, the code
> becomes automatically much more readable.
> 
> And so on...
> 

Changed to "req" for all the guest request throughout the file. Other "req" 
usage are renamed appropriately.

Subject: [PATCH] virt: sev-guest: Add SNP guest request structure

Add a snp_guest_req structure to simplify the function arguments. The
structure will be used to call the SNP Guest message request API
instead of passing a long list of parameters. Use "req" as variable name
for guest req throughout the file and rename other variables appropriately.

Update snp_issue_guest_request() prototype to include the new guest request
structure and move the prototype to sev_guest.h.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
Tested-by: Peter Gonda <pgonda@google.com>
---
 .../x86/include/asm}/sev-guest.h              |  18 ++
 arch/x86/include/asm/sev.h                    |   8 -
 arch/x86/kernel/sev.c                         |  16 +-
 drivers/virt/coco/sev-guest/sev-guest.c       | 195 ++++++++++--------
 4 files changed, 135 insertions(+), 102 deletions(-)
 rename {drivers/virt/coco/sev-guest => arch/x86/include/asm}/sev-guest.h (78%)

diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/arch/x86/include/asm/sev-guest.h
similarity index 78%
rename from drivers/virt/coco/sev-guest/sev-guest.h
rename to arch/x86/include/asm/sev-guest.h
index ceb798a404d6..27cc15ad6131 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.h
+++ b/arch/x86/include/asm/sev-guest.h
@@ -63,4 +63,22 @@ struct snp_guest_msg {
        u8 payload[4000];
 } __packed;

+struct snp_guest_req {
+       void *req_buf;
+       size_t req_sz;
+
+       void *resp_buf;
+       size_t resp_sz;
+
+       void *data;
+       size_t data_npages;
+
+       u64 exit_code;
+       unsigned int vmpck_id;
+       u8 msg_version;
+       u8 msg_type;
+};
+
+int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
+                           struct snp_guest_request_ioctl *rio);
 #endif /* __VIRT_SEVGUEST_H__ */
diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 5b4a1ce3d368..78465a8c7dc6 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -97,8 +97,6 @@ extern bool handle_vc_boot_ghcb(struct pt_regs *regs);
 struct snp_req_data {
        unsigned long req_gpa;
        unsigned long resp_gpa;
-       unsigned long data_gpa;
-       unsigned int data_npages;
 };

 struct sev_guest_platform_data {
@@ -209,7 +207,6 @@ void snp_set_memory_private(unsigned long vaddr, unsigned long npages);
 void snp_set_wakeup_secondary_cpu(void);
 bool snp_init(struct boot_params *bp);
 void __init __noreturn snp_abort(void);
-int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio);
 void snp_accept_memory(phys_addr_t start, phys_addr_t end);
 u64 snp_get_unsupported_features(u64 status);
 u64 sev_get_status(void);
@@ -233,11 +230,6 @@ static inline void snp_set_memory_private(unsigned long vaddr, unsigned long npa
 static inline void snp_set_wakeup_secondary_cpu(void) { }
 static inline bool snp_init(struct boot_params *bp) { return false; }
 static inline void snp_abort(void) { }
-static inline int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
-{
-       return -ENOTTY;
-}
-
 static inline void snp_accept_memory(phys_addr_t start, phys_addr_t end) { }
 static inline u64 snp_get_unsupported_features(u64 status) { return 0; }
 static inline u64 sev_get_status(void) { return 0; }
diff --git a/arch/x86/kernel/sev.c b/arch/x86/kernel/sev.c
index c67285824e82..43ffd307731f 100644
--- a/arch/x86/kernel/sev.c
+++ b/arch/x86/kernel/sev.c
@@ -28,6 +28,7 @@
 #include <asm/cpu_entry_area.h>
 #include <asm/stacktrace.h>
 #include <asm/sev.h>
+#include <asm/sev-guest.h>
 #include <asm/insn-eval.h>
 #include <asm/fpu/xcr.h>
 #include <asm/processor.h>
@@ -2170,7 +2171,8 @@ static int __init init_sev_config(char *str)
 }
 __setup("sev=", init_sev_config);

-int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
+int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
+                           struct snp_guest_request_ioctl *rio)
 {
        struct ghcb_state state;
        struct es_em_ctxt ctxt;
@@ -2194,12 +2196,12 @@ int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn

        vc_ghcb_invalidate(ghcb);

-       if (exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
-               ghcb_set_rax(ghcb, input->data_gpa);
-               ghcb_set_rbx(ghcb, input->data_npages);
+       if (req->exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
+               ghcb_set_rax(ghcb, __pa(req->data));
+               ghcb_set_rbx(ghcb, req->data_npages);
        }

-       ret = sev_es_ghcb_hv_call(ghcb, &ctxt, exit_code, input->req_gpa, input->resp_gpa);
+       ret = sev_es_ghcb_hv_call(ghcb, &ctxt, req->exit_code, input->req_gpa, input->resp_gpa);
        if (ret)
                goto e_put;

@@ -2214,8 +2216,8 @@ int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn

        case SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN):
                /* Number of expected pages are returned in RBX */
-               if (exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
-                       input->data_npages = ghcb_get_rbx(ghcb);
+               if (req->exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
+                       req->data_npages = ghcb_get_rbx(ghcb);
                        ret = -ENOSPC;
                        break;
                }
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 0450c5383476..b6c8f70e936c 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -27,8 +27,7 @@

 #include <asm/svm.h>
 #include <asm/sev.h>
-
-#include "sev-guest.h"
+#include <asm/sev-guest.h>

 #define DEVICE_NAME    "sev-guest"

@@ -169,65 +168,64 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
        return ctx;
 }

-static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
+static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *req)
 {
-       struct snp_guest_msg *resp = &snp_dev->secret_response;
-       struct snp_guest_msg *req = &snp_dev->secret_request;
-       struct snp_guest_msg_hdr *req_hdr = &req->hdr;
-       struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
+       struct snp_guest_msg *resp_msg = &snp_dev->secret_response;
+       struct snp_guest_msg *req_msg = &snp_dev->secret_request;
+       struct snp_guest_msg_hdr *req_msg_hdr = &req_msg->hdr;
+       struct snp_guest_msg_hdr *resp_msg_hdr = &resp_msg->hdr;
        struct aesgcm_ctx *ctx = snp_dev->ctx;
        u8 iv[GCM_AES_IV_SIZE] = {};

        pr_debug("response [seqno %lld type %d version %d sz %d]\n",
-                resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version,
-                resp_hdr->msg_sz);
+                resp_msg_hdr->msg_seqno, resp_msg_hdr->msg_type, resp_msg_hdr->msg_version,
+                resp_msg_hdr->msg_sz);

        /* Copy response from shared memory to encrypted memory. */
-       memcpy(resp, snp_dev->response, sizeof(*resp));
+       memcpy(resp_msg, snp_dev->response, sizeof(*resp_msg));

        /* Verify that the sequence counter is incremented by 1 */
-       if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
+       if (unlikely(resp_msg_hdr->msg_seqno != (req_msg_hdr->msg_seqno + 1)))
                return -EBADMSG;

        /* Verify response message type and version number. */
-       if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
-           resp_hdr->msg_version != req_hdr->msg_version)
+       if (resp_msg_hdr->msg_type != (req_msg_hdr->msg_type + 1) ||
+           resp_msg_hdr->msg_version != req_msg_hdr->msg_version)
                return -EBADMSG;

        /*
         * If the message size is greater than our buffer length then return
         * an error.
         */
-       if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
+       if (unlikely((resp_msg_hdr->msg_sz + ctx->authsize) > req->resp_sz))
                return -EBADMSG;

        /* Decrypt the payload */
-       memcpy(iv, &resp_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_hdr->msg_seqno)));
-       if (!aesgcm_decrypt(ctx, payload, resp->payload, resp_hdr->msg_sz,
-                           &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
+       memcpy(iv, &resp_msg_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_msg_hdr->msg_seqno)));
+       if (!aesgcm_decrypt(ctx, req->resp_buf, resp_msg->payload, resp_msg_hdr->msg_sz,
+                           &resp_msg_hdr->algo, AAD_LEN, iv, resp_msg_hdr->authtag))
                return -EBADMSG;

        return 0;
 }

-static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
-                       void *payload, size_t sz)
+static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req)
 {
-       struct snp_guest_msg *req = &snp_dev->secret_request;
-       struct snp_guest_msg_hdr *hdr = &req->hdr;
+       struct snp_guest_msg *msg = &snp_dev->secret_request;
+       struct snp_guest_msg_hdr *hdr = &msg->hdr;
        struct aesgcm_ctx *ctx = snp_dev->ctx;
        u8 iv[GCM_AES_IV_SIZE] = {};

-       memset(req, 0, sizeof(*req));
+       memset(msg, 0, sizeof(*msg));

        hdr->algo = SNP_AEAD_AES_256_GCM;
        hdr->hdr_version = MSG_HDR_VER;
        hdr->hdr_sz = sizeof(*hdr);
-       hdr->msg_type = type;
-       hdr->msg_version = version;
+       hdr->msg_type = req->msg_type;
+       hdr->msg_version = req->msg_version;
        hdr->msg_seqno = seqno;
-       hdr->msg_vmpck = vmpck_id;
-       hdr->msg_sz = sz;
+       hdr->msg_vmpck = req->vmpck_id;
+       hdr->msg_sz = req->req_sz;

        /* Verify the sequence number is non-zero */
        if (!hdr->msg_seqno)
@@ -236,17 +234,17 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
        pr_debug("request [seqno %lld type %d version %d sz %d]\n",
                 hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);

-       if (WARN_ON((sz + ctx->authsize) > sizeof(req->payload)))
+       if (WARN_ON((req->req_sz + ctx->authsize) > sizeof(msg->payload)))
                return -EBADMSG;

        memcpy(iv, &hdr->msg_seqno, min(sizeof(iv), sizeof(hdr->msg_seqno)));
-       aesgcm_encrypt(ctx, req->payload, payload, sz, &hdr->algo, AAD_LEN,
-                      iv, hdr->authtag);
+       aesgcm_encrypt(ctx, msg->payload, req->req_buf, req->req_sz, &hdr->algo,
+                      AAD_LEN, iv, hdr->authtag);

        return 0;
 }

-static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
+static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
                                  struct snp_guest_request_ioctl *rio)
 {
        unsigned long req_start = jiffies;
@@ -261,7 +259,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
         * sequence number must be incremented or the VMPCK must be deleted to
         * prevent reuse of the IV.
         */
-       rc = snp_issue_guest_request(exit_code, &snp_dev->input, rio);
+       rc = snp_issue_guest_request(req, &snp_dev->input, rio);
        switch (rc) {
        case -ENOSPC:
                /*
@@ -271,8 +269,8 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
                 * order to increment the sequence number and thus avoid
                 * IV reuse.
                 */
-               override_npages = snp_dev->input.data_npages;
-               exit_code       = SVM_VMGEXIT_GUEST_REQUEST;
+               override_npages = req->data_npages;
+               req->exit_code  = SVM_VMGEXIT_GUEST_REQUEST;

                /*
                 * Override the error to inform callers the given extended
@@ -327,15 +325,13 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
        }

        if (override_npages)
-               snp_dev->input.data_npages = override_npages;
+               req->data_npages = override_npages;

        return rc;
 }

-static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
-                               struct snp_guest_request_ioctl *rio, u8 type,
-                               void *req_buf, size_t req_sz, void *resp_buf,
-                               u32 resp_sz)
+static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
+                                 struct snp_guest_request_ioctl *rio)
 {
        u64 seqno;
        int rc;
@@ -349,7 +345,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
        memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));

        /* Encrypt the userspace provided payload in snp_dev->secret_request. */
-       rc = enc_payload(snp_dev, seqno, rio->msg_version, type, req_buf, req_sz);
+       rc = enc_payload(snp_dev, seqno, req);
        if (rc)
                return rc;

@@ -360,7 +356,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
        memcpy(snp_dev->request, &snp_dev->secret_request,
               sizeof(snp_dev->secret_request));

-       rc = __handle_guest_request(snp_dev, exit_code, rio);
+       rc = __handle_guest_request(snp_dev, req, rio);
        if (rc) {
                if (rc == -EIO &&
                    rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
@@ -369,12 +365,11 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
                dev_alert(snp_dev->dev,
                          "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
                          rc, rio->exitinfo2);
-
                snp_disable_vmpck(snp_dev);
                return rc;
        }

-       rc = verify_and_dec_payload(snp_dev, resp_buf, resp_sz);
+       rc = verify_and_dec_payload(snp_dev, req);
        if (rc) {
                dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc);
                snp_disable_vmpck(snp_dev);
@@ -391,8 +386,9 @@ struct snp_req_resp {

 static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
 {
-       struct snp_report_req *req = &snp_dev->req.report;
-       struct snp_report_resp *resp;
+       struct snp_report_req *report_req = &snp_dev->req.report;
+       struct snp_guest_req req = {0};
+       struct snp_report_resp *report_resp;
        int rc, resp_len;

        lockdep_assert_held(&snp_cmd_mutex);
@@ -400,7 +396,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
        if (!arg->req_data || !arg->resp_data)
                return -EINVAL;

-       if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
+       if (copy_from_user(report_req, (void __user *)arg->req_data, sizeof(*report_req)))
                return -EFAULT;

        /*
@@ -408,29 +404,37 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
         * response payload. Make sure that it has enough space to cover the
         * authtag.
         */
-       resp_len = sizeof(resp->data) + snp_dev->ctx->authsize;
-       resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
-       if (!resp)
+       resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize;
+       report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
+       if (!report_resp)
                return -ENOMEM;

-       rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
-                                 SNP_MSG_REPORT_REQ, req, sizeof(*req), resp->data,
-                                 resp_len);
+       req.msg_version = arg->msg_version;
+       req.msg_type = SNP_MSG_REPORT_REQ;
+       req.vmpck_id = vmpck_id;
+       req.req_buf = report_req;
+       req.req_sz = sizeof(*report_req);
+       req.resp_buf = report_resp->data;
+       req.resp_sz = resp_len;
+       req.exit_code = SVM_VMGEXIT_GUEST_REQUEST;
+
+       rc = snp_send_guest_request(snp_dev, &req, arg);
        if (rc)
                goto e_free;

-       if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
+       if (copy_to_user((void __user *)arg->resp_data, report_resp, sizeof(*report_resp)))
                rc = -EFAULT;

 e_free:
-       kfree(resp);
+       kfree(report_resp);
        return rc;
 }

 static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
 {
-       struct snp_derived_key_req *req = &snp_dev->req.derived_key;
-       struct snp_derived_key_resp resp = {0};
+       struct snp_derived_key_req *derived_key_req = &snp_dev->req.derived_key;
+       struct snp_derived_key_resp derived_key_resp = {0};
+       struct snp_guest_req req = {0};
        int rc, resp_len;
        /* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
        u8 buf[64 + 16];
@@ -445,25 +449,34 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
         * response payload. Make sure that it has enough space to cover the
         * authtag.
         */
-       resp_len = sizeof(resp.data) + snp_dev->ctx->authsize;
+       resp_len = sizeof(derived_key_resp.data) + snp_dev->ctx->authsize;
        if (sizeof(buf) < resp_len)
                return -ENOMEM;

-       if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
+       if (copy_from_user(derived_key_req, (void __user *)arg->req_data,
+                          sizeof(*derived_key_req)))
                return -EFAULT;

-       rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
-                                 SNP_MSG_KEY_REQ, req, sizeof(*req), buf, resp_len);
+       req.msg_version = arg->msg_version;
+       req.msg_type = SNP_MSG_KEY_REQ;
+       req.vmpck_id = vmpck_id;
+       req.req_buf = derived_key_req;
+       req.req_sz = sizeof(*derived_key_req);
+       req.resp_buf = buf;
+       req.resp_sz = resp_len;
+       req.exit_code = SVM_VMGEXIT_GUEST_REQUEST;
+
+       rc = snp_send_guest_request(snp_dev, &req, arg);
        if (rc)
                return rc;

-       memcpy(resp.data, buf, sizeof(resp.data));
-       if (copy_to_user((void __user *)arg->resp_data, &resp, sizeof(resp)))
+       memcpy(derived_key_resp.data, buf, sizeof(derived_key_resp.data));
+       if (copy_to_user((void __user *)arg->resp_data, &derived_key_resp, sizeof(derived_key_resp)))
                rc = -EFAULT;

        /* The response buffer contains the sensitive data, explicitly clear it. */
        memzero_explicit(buf, sizeof(buf));
-       memzero_explicit(&resp, sizeof(resp));
+       memzero_explicit(&derived_key_resp, sizeof(derived_key_resp));
        return rc;
 }

@@ -471,32 +484,33 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
                          struct snp_req_resp *io)

 {
-       struct snp_ext_report_req *req = &snp_dev->req.ext_report;
-       struct snp_report_resp *resp;
-       int ret, npages = 0, resp_len;
+       struct snp_ext_report_req *report_req = &snp_dev->req.ext_report;
+       struct snp_guest_req req = {0};
+       struct snp_report_resp *report_resp;
        sockptr_t certs_address;
+       int ret, resp_len;

        lockdep_assert_held(&snp_cmd_mutex);

        if (sockptr_is_null(io->req_data) || sockptr_is_null(io->resp_data))
                return -EINVAL;

-       if (copy_from_sockptr(req, io->req_data, sizeof(*req)))
+       if (copy_from_sockptr(report_req, io->req_data, sizeof(*report_req)))
                return -EFAULT;

        /* caller does not want certificate data */
-       if (!req->certs_len || !req->certs_address)
+       if (!report_req->certs_len || !report_req->certs_address)
                goto cmd;

-       if (req->certs_len > SEV_FW_BLOB_MAX_SIZE ||
-           !IS_ALIGNED(req->certs_len, PAGE_SIZE))
+       if (report_req->certs_len > SEV_FW_BLOB_MAX_SIZE ||
+           !IS_ALIGNED(report_req->certs_len, PAGE_SIZE))
                return -EINVAL;

        if (sockptr_is_kernel(io->resp_data)) {
-               certs_address = KERNEL_SOCKPTR((void *)req->certs_address);
+               certs_address = KERNEL_SOCKPTR((void *)report_req->certs_address);
        } else {
-               certs_address = USER_SOCKPTR((void __user *)req->certs_address);
-               if (!access_ok(certs_address.user, req->certs_len))
+               certs_address = USER_SOCKPTR((void __user *)report_req->certs_address);
+               if (!access_ok(certs_address.user, report_req->certs_len))
                        return -EFAULT;
        }

@@ -506,45 +520,53 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
         * the host. If host does not supply any certs in it, then copy
         * zeros to indicate that certificate data was not provided.
         */
-       memset(snp_dev->certs_data, 0, req->certs_len);
-       npages = req->certs_len >> PAGE_SHIFT;
+       memset(snp_dev->certs_data, 0, report_req->certs_len);
+       req.data_npages = report_req->certs_len >> PAGE_SHIFT;
 cmd:
        /*
         * The intermediate response buffer is used while decrypting the
         * response payload. Make sure that it has enough space to cover the
         * authtag.
         */
-       resp_len = sizeof(resp->data) + snp_dev->ctx->authsize;
-       resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
-       if (!resp)
+       resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize;
+       report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
+       if (!report_resp)
                return -ENOMEM;

-       snp_dev->input.data_npages = npages;
-       ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg,
-                                  SNP_MSG_REPORT_REQ, &req->data,
-                                  sizeof(req->data), resp->data, resp_len);
+       req.msg_version = arg->msg_version;
+       req.msg_type = SNP_MSG_REPORT_REQ;
+       req.vmpck_id = vmpck_id;
+       req.req_buf = &report_req->data;
+       req.req_sz = sizeof(report_req->data);
+       req.resp_buf = report_resp->data;
+       req.resp_sz = resp_len;
+       req.exit_code = SVM_VMGEXIT_EXT_GUEST_REQUEST;
+       req.data = snp_dev->certs_data;
+
+       ret = snp_send_guest_request(snp_dev, &req, arg);

        /* If certs length is invalid then copy the returned length */
        if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
-               req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
+               report_req->certs_len = req.data_npages << PAGE_SHIFT;

-               if (copy_to_sockptr(io->req_data, req, sizeof(*req)))
+               if (copy_to_sockptr(io->req_data, report_req, sizeof(*report_req)))
                        ret = -EFAULT;
        }

        if (ret)
                goto e_free;

-       if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, req->certs_len)) {
+       if (req.data_npages && report_req->certs_len &&
+           copy_to_sockptr(certs_address, snp_dev->certs_data, report_req->certs_len)) {
                ret = -EFAULT;
                goto e_free;
        }

-       if (copy_to_sockptr(io->resp_data, resp, sizeof(*resp)))
+       if (copy_to_sockptr(io->resp_data, report_resp, sizeof(*report_resp)))
                ret = -EFAULT;

 e_free:
-       kfree(resp);
+       kfree(report_resp);
        return ret;
 }

@@ -868,7 +890,6 @@ static int __init sev_guest_probe(struct platform_device *pdev)
        /* initial the input address for guest request */
        snp_dev->input.req_gpa = __pa(snp_dev->request);
        snp_dev->input.resp_gpa = __pa(snp_dev->response);
-       snp_dev->input.data_gpa = __pa(snp_dev->certs_data);

        ret = tsm_register(&sev_tsm_ops, snp_dev, &tsm_report_extra_type);
        if (ret)
--
2.34.1
  
Borislav Petkov Feb. 1, 2024, 10:29 a.m. UTC | #6
On Wed, Jan 31, 2024 at 07:28:05PM +0530, Nikunj A. Dadhania wrote:
> Changed to "req" for all the guest request throughout the file. Other "req" 
> usage are renamed appropriately.

Yes, better from what I can tell.

However, I can't apply this patch in order to have a better look, it is
mangled. Next time, before you send a patch this way, send it yourself
first and try applying it.

If it doesn't work, throw away your mailer and use a proper one:

Documentation/process/email-clients.rst

> Subject: [PATCH] virt: sev-guest: Add SNP guest request structure
> 
> Add a snp_guest_req structure to simplify the function arguments. The
> structure will be used to call the SNP Guest message request API
> instead of passing a long list of parameters. Use "req" as variable name
> for guest req throughout the file and rename other variables appropriately.
> 
> Update snp_issue_guest_request() prototype to include the new guest request
> structure and move the prototype to sev_guest.h.
> 
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> Tested-by: Peter Gonda <pgonda@google.com>

Tested-by: tags must be dropped if you change a patch in a non-trivial
way. And this change is not that trivial I'd say.

> ---
>  .../x86/include/asm}/sev-guest.h              |  18 ++
>  arch/x86/include/asm/sev.h                    |   8 -
>  arch/x86/kernel/sev.c                         |  16 +-
>  drivers/virt/coco/sev-guest/sev-guest.c       | 195 ++++++++++--------
>  4 files changed, 135 insertions(+), 102 deletions(-)
>  rename {drivers/virt/coco/sev-guest => arch/x86/include/asm}/sev-guest.h (78%)

I didn't notice this before: why am I getting a sev-guest.h header in
arch/x86/?

Lemme quote again the file paths we agreed upon:

https://lore.kernel.org/all/Yg5nh1RknPRwIrb8@zn.tnic/

Thx.
  
Nikunj A. Dadhania Feb. 1, 2024, 11:10 a.m. UTC | #7
On 2/1/2024 3:59 PM, Borislav Petkov wrote:
> On Wed, Jan 31, 2024 at 07:28:05PM +0530, Nikunj A. Dadhania wrote:
>> Changed to "req" for all the guest request throughout the file. Other "req" 
>> usage are renamed appropriately.
> 
> Yes, better from what I can tell.
> 
> However, I can't apply this patch in order to have a better look, it is
> mangled. Next time, before you send a patch this way, send it yourself
> first and try applying it.
>
> If it doesn't work, throw away your mailer and use a proper one:
> 
> Documentation/process/email-clients.rst

Sorry for that, will fix it. 

> 
>> Subject: [PATCH] virt: sev-guest: Add SNP guest request structure
>>
>> Add a snp_guest_req structure to simplify the function arguments. The
>> structure will be used to call the SNP Guest message request API
>> instead of passing a long list of parameters. Use "req" as variable name
>> for guest req throughout the file and rename other variables appropriately.
>>
>> Update snp_issue_guest_request() prototype to include the new guest request
>> structure and move the prototype to sev_guest.h.
>>
>> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
>> Tested-by: Peter Gonda <pgonda@google.com>
> 
> Tested-by: tags must be dropped if you change a patch in a non-trivial
> way. And this change is not that trivial I'd say.
> 
>> ---
>>  .../x86/include/asm}/sev-guest.h              |  18 ++
>>  arch/x86/include/asm/sev.h                    |   8 -
>>  arch/x86/kernel/sev.c                         |  16 +-
>>  drivers/virt/coco/sev-guest/sev-guest.c       | 195 ++++++++++--------
>>  4 files changed, 135 insertions(+), 102 deletions(-)
>>  rename {drivers/virt/coco/sev-guest => arch/x86/include/asm}/sev-guest.h (78%)
> 
> I didn't notice this before: why am I getting a sev-guest.h header in
> arch/x86/?
> 
> Lemme quote again the file paths we agreed upon:
> 
> https://lore.kernel.org/all/Yg5nh1RknPRwIrb8@zn.tnic/

I will move it to arch/x86/coco/sev, do we need a separate "include" directory ?

As we are doing this movement, should we move guest messaging related code to arch/x86/coco/sev/guest-msg.c ?

Regards
Nikunj
  
Borislav Petkov Feb. 1, 2024, 2:07 p.m. UTC | #8
On Thu, Feb 01, 2024 at 04:40:10PM +0530, Nikunj A. Dadhania wrote:
> I will move it to arch/x86/coco/sev, do we need a separate "include" directory ?

I still don't understand why you need to move it at all?
  
Nikunj A. Dadhania Feb. 2, 2024, 3:50 a.m. UTC | #9
On 2/1/2024 7:37 PM, Borislav Petkov wrote:
> On Thu, Feb 01, 2024 at 04:40:10PM +0530, Nikunj A. Dadhania wrote:
>> I will move it to arch/x86/coco/sev, do we need a separate "include" directory ?
> 
> I still don't understand why you need to move it at all?
> 

To support Secure TSC, SNP guest messages need to be used during the early boot.
Most of the guest messaging code is currently part of sev-guest driver
and header. I have opportunistically moved the header in this patch as I was adding 
guest request structure. Movement of rest of the functions implementation 
from sev-guest.c => kernel/sev.c is done in patch 7/16.

As per https://lore.kernel.org/all/Yg5nh1RknPRwIrb8@zn.tnic/, I can move the snp 
guest messaging code implementation to arch/x86/coco/sev/guest-msg.[ch]

Regards
Nikunj
  
Borislav Petkov Feb. 2, 2024, 4:14 p.m. UTC | #10
On Fri, Feb 02, 2024 at 09:20:22AM +0530, Nikunj A. Dadhania wrote:
> I have opportunistically moved the header in this patch as I was
> adding guest request structure. Movement of rest of the functions
> implementation from sev-guest.c => kernel/sev.c is done in patch 7/16.

And kernel/sev.c has a corresponding header arch/x86/include/asm/sev.h
which is kinda *begging* to collect all the stuff that sev.c is
using instead of introducing a sev-guest.h thing which doesn't make
a lot of sense, TU-wise.
  
Nikunj A. Dadhania Feb. 5, 2024, 9:23 a.m. UTC | #11
On 2/2/2024 9:44 PM, Borislav Petkov wrote:
> On Fri, Feb 02, 2024 at 09:20:22AM +0530, Nikunj A. Dadhania wrote:
>> I have opportunistically moved the header in this patch as I was
>> adding guest request structure. Movement of rest of the functions
>> implementation from sev-guest.c => kernel/sev.c is done in patch 7/16.
> 
> And kernel/sev.c has a corresponding header arch/x86/include/asm/sev.h
> which is kinda *begging* to collect all the stuff that sev.c is
> using instead of introducing a sev-guest.h thing which doesn't make
> a lot of sense, TU-wise.
> 

Sure, below is the updated patch. Complete series is pushed here 

https://github.com/AMDESE/linux-kvm/commits/sectsc-guest-latest/

Subject: virt: sev-guest: Add SNP guest request structure

Add a snp_guest_req structure to simplify the function arguments. The
structure will be used to call the SNP Guest message request API
instead of passing a long list of parameters.

Update snp_issue_guest_request() prototype to include the new guest request
structure and move the prototype to sev.h.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
---
 arch/x86/include/asm/sev.h              |  75 ++++++++-
 arch/x86/kernel/sev.c                   |  15 +-
 drivers/virt/coco/sev-guest/sev-guest.c | 194 +++++++++++++-----------
 drivers/virt/coco/sev-guest/sev-guest.h |  66 --------
 4 files changed, 186 insertions(+), 164 deletions(-)
 delete mode 100644 drivers/virt/coco/sev-guest/sev-guest.h

diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 5b4a1ce3d368..56b07c79945a 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -97,8 +97,6 @@ extern bool handle_vc_boot_ghcb(struct pt_regs *regs);
 struct snp_req_data {
 	unsigned long req_gpa;
 	unsigned long resp_gpa;
-	unsigned long data_gpa;
-	unsigned int data_npages;
 };
 
 struct sev_guest_platform_data {
@@ -140,6 +138,73 @@ struct snp_secrets_page_layout {
 	u8 rsvd3[3840];
 } __packed;
 
+#define MAX_AUTHTAG_LEN		32
+#define AUTHTAG_LEN		16
+#define AAD_LEN			48
+#define MSG_HDR_VER		1
+
+/* See SNP spec SNP_GUEST_REQUEST section for the structure */
+enum msg_type {
+	SNP_MSG_TYPE_INVALID = 0,
+	SNP_MSG_CPUID_REQ,
+	SNP_MSG_CPUID_RSP,
+	SNP_MSG_KEY_REQ,
+	SNP_MSG_KEY_RSP,
+	SNP_MSG_REPORT_REQ,
+	SNP_MSG_REPORT_RSP,
+	SNP_MSG_EXPORT_REQ,
+	SNP_MSG_EXPORT_RSP,
+	SNP_MSG_IMPORT_REQ,
+	SNP_MSG_IMPORT_RSP,
+	SNP_MSG_ABSORB_REQ,
+	SNP_MSG_ABSORB_RSP,
+	SNP_MSG_VMRK_REQ,
+	SNP_MSG_VMRK_RSP,
+
+	SNP_MSG_TYPE_MAX
+};
+
+enum aead_algo {
+	SNP_AEAD_INVALID,
+	SNP_AEAD_AES_256_GCM,
+};
+
+struct snp_guest_msg_hdr {
+	u8 authtag[MAX_AUTHTAG_LEN];
+	u64 msg_seqno;
+	u8 rsvd1[8];
+	u8 algo;
+	u8 hdr_version;
+	u16 hdr_sz;
+	u8 msg_type;
+	u8 msg_version;
+	u16 msg_sz;
+	u32 rsvd2;
+	u8 msg_vmpck;
+	u8 rsvd3[35];
+} __packed;
+
+struct snp_guest_msg {
+	struct snp_guest_msg_hdr hdr;
+	u8 payload[4000];
+} __packed;
+
+struct snp_guest_req {
+	void *req_buf;
+	size_t req_sz;
+
+	void *resp_buf;
+	size_t resp_sz;
+
+	void *data;
+	size_t data_npages;
+
+	u64 exit_code;
+	unsigned int vmpck_id;
+	u8 msg_version;
+	u8 msg_type;
+};
+
 #ifdef CONFIG_AMD_MEM_ENCRYPT
 extern void __sev_es_ist_enter(struct pt_regs *regs);
 extern void __sev_es_ist_exit(void);
@@ -209,7 +274,8 @@ void snp_set_memory_private(unsigned long vaddr, unsigned long npages);
 void snp_set_wakeup_secondary_cpu(void);
 bool snp_init(struct boot_params *bp);
 void __init __noreturn snp_abort(void);
-int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio);
+int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
+			    struct snp_guest_request_ioctl *rio);
 void snp_accept_memory(phys_addr_t start, phys_addr_t end);
 u64 snp_get_unsupported_features(u64 status);
 u64 sev_get_status(void);
@@ -233,7 +299,8 @@ static inline void snp_set_memory_private(unsigned long vaddr, unsigned long npa
 static inline void snp_set_wakeup_secondary_cpu(void) { }
 static inline bool snp_init(struct boot_params *bp) { return false; }
 static inline void snp_abort(void) { }
-static inline int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
+static inline int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
+					  struct snp_guest_request_ioctl *rio)
 {
 	return -ENOTTY;
 }
diff --git a/arch/x86/kernel/sev.c b/arch/x86/kernel/sev.c
index c67285824e82..3d6429321536 100644
--- a/arch/x86/kernel/sev.c
+++ b/arch/x86/kernel/sev.c
@@ -2170,7 +2170,8 @@ static int __init init_sev_config(char *str)
 }
 __setup("sev=", init_sev_config);
 
-int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
+int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
+			    struct snp_guest_request_ioctl *rio)
 {
 	struct ghcb_state state;
 	struct es_em_ctxt ctxt;
@@ -2194,12 +2195,12 @@ int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn
 
 	vc_ghcb_invalidate(ghcb);
 
-	if (exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
-		ghcb_set_rax(ghcb, input->data_gpa);
-		ghcb_set_rbx(ghcb, input->data_npages);
+	if (req->exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
+		ghcb_set_rax(ghcb, __pa(req->data));
+		ghcb_set_rbx(ghcb, req->data_npages);
 	}
 
-	ret = sev_es_ghcb_hv_call(ghcb, &ctxt, exit_code, input->req_gpa, input->resp_gpa);
+	ret = sev_es_ghcb_hv_call(ghcb, &ctxt, req->exit_code, input->req_gpa, input->resp_gpa);
 	if (ret)
 		goto e_put;
 
@@ -2214,8 +2215,8 @@ int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn
 
 	case SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN):
 		/* Number of expected pages are returned in RBX */
-		if (exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
-			input->data_npages = ghcb_get_rbx(ghcb);
+		if (req->exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
+			req->data_npages = ghcb_get_rbx(ghcb);
 			ret = -ENOSPC;
 			break;
 		}
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 0450c5383476..894f6974e192 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -28,8 +28,6 @@
 #include <asm/svm.h>
 #include <asm/sev.h>
 
-#include "sev-guest.h"
-
 #define DEVICE_NAME	"sev-guest"
 
 #define SNP_REQ_MAX_RETRY_DURATION	(60*HZ)
@@ -169,65 +167,64 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
 	return ctx;
 }
 
-static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
+static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *req)
 {
-	struct snp_guest_msg *resp = &snp_dev->secret_response;
-	struct snp_guest_msg *req = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
-	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
+	struct snp_guest_msg *resp_msg = &snp_dev->secret_response;
+	struct snp_guest_msg *req_msg = &snp_dev->secret_request;
+	struct snp_guest_msg_hdr *req_msg_hdr = &req_msg->hdr;
+	struct snp_guest_msg_hdr *resp_msg_hdr = &resp_msg->hdr;
 	struct aesgcm_ctx *ctx = snp_dev->ctx;
 	u8 iv[GCM_AES_IV_SIZE] = {};
 
 	pr_debug("response [seqno %lld type %d version %d sz %d]\n",
-		 resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version,
-		 resp_hdr->msg_sz);
+		 resp_msg_hdr->msg_seqno, resp_msg_hdr->msg_type, resp_msg_hdr->msg_version,
+		 resp_msg_hdr->msg_sz);
 
 	/* Copy response from shared memory to encrypted memory. */
-	memcpy(resp, snp_dev->response, sizeof(*resp));
+	memcpy(resp_msg, snp_dev->response, sizeof(*resp_msg));
 
 	/* Verify that the sequence counter is incremented by 1 */
-	if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
+	if (unlikely(resp_msg_hdr->msg_seqno != (req_msg_hdr->msg_seqno + 1)))
 		return -EBADMSG;
 
 	/* Verify response message type and version number. */
-	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
-	    resp_hdr->msg_version != req_hdr->msg_version)
+	if (resp_msg_hdr->msg_type != (req_msg_hdr->msg_type + 1) ||
+	    resp_msg_hdr->msg_version != req_msg_hdr->msg_version)
 		return -EBADMSG;
 
 	/*
 	 * If the message size is greater than our buffer length then return
 	 * an error.
 	 */
-	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
+	if (unlikely((resp_msg_hdr->msg_sz + ctx->authsize) > req->resp_sz))
 		return -EBADMSG;
 
 	/* Decrypt the payload */
-	memcpy(iv, &resp_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_hdr->msg_seqno)));
-	if (!aesgcm_decrypt(ctx, payload, resp->payload, resp_hdr->msg_sz,
-			    &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
+	memcpy(iv, &resp_msg_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_msg_hdr->msg_seqno)));
+	if (!aesgcm_decrypt(ctx, req->resp_buf, resp_msg->payload, resp_msg_hdr->msg_sz,
+			    &resp_msg_hdr->algo, AAD_LEN, iv, resp_msg_hdr->authtag))
 		return -EBADMSG;
 
 	return 0;
 }
 
-static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
-			void *payload, size_t sz)
+static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req)
 {
-	struct snp_guest_msg *req = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *hdr = &req->hdr;
+	struct snp_guest_msg *msg = &snp_dev->secret_request;
+	struct snp_guest_msg_hdr *hdr = &msg->hdr;
 	struct aesgcm_ctx *ctx = snp_dev->ctx;
 	u8 iv[GCM_AES_IV_SIZE] = {};
 
-	memset(req, 0, sizeof(*req));
+	memset(msg, 0, sizeof(*msg));
 
 	hdr->algo = SNP_AEAD_AES_256_GCM;
 	hdr->hdr_version = MSG_HDR_VER;
 	hdr->hdr_sz = sizeof(*hdr);
-	hdr->msg_type = type;
-	hdr->msg_version = version;
+	hdr->msg_type = req->msg_type;
+	hdr->msg_version = req->msg_version;
 	hdr->msg_seqno = seqno;
-	hdr->msg_vmpck = vmpck_id;
-	hdr->msg_sz = sz;
+	hdr->msg_vmpck = req->vmpck_id;
+	hdr->msg_sz = req->req_sz;
 
 	/* Verify the sequence number is non-zero */
 	if (!hdr->msg_seqno)
@@ -236,17 +233,17 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
 	pr_debug("request [seqno %lld type %d version %d sz %d]\n",
 		 hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
 
-	if (WARN_ON((sz + ctx->authsize) > sizeof(req->payload)))
+	if (WARN_ON((req->req_sz + ctx->authsize) > sizeof(msg->payload)))
 		return -EBADMSG;
 
 	memcpy(iv, &hdr->msg_seqno, min(sizeof(iv), sizeof(hdr->msg_seqno)));
-	aesgcm_encrypt(ctx, req->payload, payload, sz, &hdr->algo, AAD_LEN,
-		       iv, hdr->authtag);
+	aesgcm_encrypt(ctx, msg->payload, req->req_buf, req->req_sz, &hdr->algo,
+		       AAD_LEN, iv, hdr->authtag);
 
 	return 0;
 }
 
-static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
+static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
 				  struct snp_guest_request_ioctl *rio)
 {
 	unsigned long req_start = jiffies;
@@ -261,7 +258,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	 * sequence number must be incremented or the VMPCK must be deleted to
 	 * prevent reuse of the IV.
 	 */
-	rc = snp_issue_guest_request(exit_code, &snp_dev->input, rio);
+	rc = snp_issue_guest_request(req, &snp_dev->input, rio);
 	switch (rc) {
 	case -ENOSPC:
 		/*
@@ -271,8 +268,8 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 		 * order to increment the sequence number and thus avoid
 		 * IV reuse.
 		 */
-		override_npages = snp_dev->input.data_npages;
-		exit_code	= SVM_VMGEXIT_GUEST_REQUEST;
+		override_npages = req->data_npages;
+		req->exit_code	= SVM_VMGEXIT_GUEST_REQUEST;
 
 		/*
 		 * Override the error to inform callers the given extended
@@ -327,15 +324,13 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	}
 
 	if (override_npages)
-		snp_dev->input.data_npages = override_npages;
+		req->data_npages = override_npages;
 
 	return rc;
 }
 
-static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
-				struct snp_guest_request_ioctl *rio, u8 type,
-				void *req_buf, size_t req_sz, void *resp_buf,
-				u32 resp_sz)
+static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
+				  struct snp_guest_request_ioctl *rio)
 {
 	u64 seqno;
 	int rc;
@@ -349,7 +344,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
 
 	/* Encrypt the userspace provided payload in snp_dev->secret_request. */
-	rc = enc_payload(snp_dev, seqno, rio->msg_version, type, req_buf, req_sz);
+	rc = enc_payload(snp_dev, seqno, req);
 	if (rc)
 		return rc;
 
@@ -360,7 +355,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	memcpy(snp_dev->request, &snp_dev->secret_request,
 	       sizeof(snp_dev->secret_request));
 
-	rc = __handle_guest_request(snp_dev, exit_code, rio);
+	rc = __handle_guest_request(snp_dev, req, rio);
 	if (rc) {
 		if (rc == -EIO &&
 		    rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
@@ -369,12 +364,11 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 		dev_alert(snp_dev->dev,
 			  "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
 			  rc, rio->exitinfo2);
-
 		snp_disable_vmpck(snp_dev);
 		return rc;
 	}
 
-	rc = verify_and_dec_payload(snp_dev, resp_buf, resp_sz);
+	rc = verify_and_dec_payload(snp_dev, req);
 	if (rc) {
 		dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc);
 		snp_disable_vmpck(snp_dev);
@@ -391,8 +385,9 @@ struct snp_req_resp {
 
 static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
 {
-	struct snp_report_req *req = &snp_dev->req.report;
-	struct snp_report_resp *resp;
+	struct snp_report_req *report_req = &snp_dev->req.report;
+	struct snp_guest_req req = {0};
+	struct snp_report_resp *report_resp;
 	int rc, resp_len;
 
 	lockdep_assert_held(&snp_cmd_mutex);
@@ -400,7 +395,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
 	if (!arg->req_data || !arg->resp_data)
 		return -EINVAL;
 
-	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
+	if (copy_from_user(report_req, (void __user *)arg->req_data, sizeof(*report_req)))
 		return -EFAULT;
 
 	/*
@@ -408,29 +403,37 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
 	 * response payload. Make sure that it has enough space to cover the
 	 * authtag.
 	 */
-	resp_len = sizeof(resp->data) + snp_dev->ctx->authsize;
-	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
-	if (!resp)
+	resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize;
+	report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
+	if (!report_resp)
 		return -ENOMEM;
 
-	rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
-				  SNP_MSG_REPORT_REQ, req, sizeof(*req), resp->data,
-				  resp_len);
+	req.msg_version = arg->msg_version;
+	req.msg_type = SNP_MSG_REPORT_REQ;
+	req.vmpck_id = vmpck_id;
+	req.req_buf = report_req;
+	req.req_sz = sizeof(*report_req);
+	req.resp_buf = report_resp->data;
+	req.resp_sz = resp_len;
+	req.exit_code = SVM_VMGEXIT_GUEST_REQUEST;
+
+	rc = snp_send_guest_request(snp_dev, &req, arg);
 	if (rc)
 		goto e_free;
 
-	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
+	if (copy_to_user((void __user *)arg->resp_data, report_resp, sizeof(*report_resp)))
 		rc = -EFAULT;
 
 e_free:
-	kfree(resp);
+	kfree(report_resp);
 	return rc;
 }
 
 static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
 {
-	struct snp_derived_key_req *req = &snp_dev->req.derived_key;
-	struct snp_derived_key_resp resp = {0};
+	struct snp_derived_key_req *derived_key_req = &snp_dev->req.derived_key;
+	struct snp_derived_key_resp derived_key_resp = {0};
+	struct snp_guest_req req = {0};
 	int rc, resp_len;
 	/* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
 	u8 buf[64 + 16];
@@ -445,25 +448,34 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
 	 * response payload. Make sure that it has enough space to cover the
 	 * authtag.
 	 */
-	resp_len = sizeof(resp.data) + snp_dev->ctx->authsize;
+	resp_len = sizeof(derived_key_resp.data) + snp_dev->ctx->authsize;
 	if (sizeof(buf) < resp_len)
 		return -ENOMEM;
 
-	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
+	if (copy_from_user(derived_key_req, (void __user *)arg->req_data,
+			   sizeof(*derived_key_req)))
 		return -EFAULT;
 
-	rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
-				  SNP_MSG_KEY_REQ, req, sizeof(*req), buf, resp_len);
+	req.msg_version = arg->msg_version;
+	req.msg_type = SNP_MSG_KEY_REQ;
+	req.vmpck_id = vmpck_id;
+	req.req_buf = derived_key_req;
+	req.req_sz = sizeof(*derived_key_req);
+	req.resp_buf = buf;
+	req.resp_sz = resp_len;
+	req.exit_code = SVM_VMGEXIT_GUEST_REQUEST;
+
+	rc = snp_send_guest_request(snp_dev, &req, arg);
 	if (rc)
 		return rc;
 
-	memcpy(resp.data, buf, sizeof(resp.data));
-	if (copy_to_user((void __user *)arg->resp_data, &resp, sizeof(resp)))
+	memcpy(derived_key_resp.data, buf, sizeof(derived_key_resp.data));
+	if (copy_to_user((void __user *)arg->resp_data, &derived_key_resp, sizeof(derived_key_resp)))
 		rc = -EFAULT;
 
 	/* The response buffer contains the sensitive data, explicitly clear it. */
 	memzero_explicit(buf, sizeof(buf));
-	memzero_explicit(&resp, sizeof(resp));
+	memzero_explicit(&derived_key_resp, sizeof(derived_key_resp));
 	return rc;
 }
 
@@ -471,32 +483,33 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 			  struct snp_req_resp *io)
 
 {
-	struct snp_ext_report_req *req = &snp_dev->req.ext_report;
-	struct snp_report_resp *resp;
-	int ret, npages = 0, resp_len;
+	struct snp_ext_report_req *report_req = &snp_dev->req.ext_report;
+	struct snp_guest_req req = {0};
+	struct snp_report_resp *report_resp;
 	sockptr_t certs_address;
+	int ret, resp_len;
 
 	lockdep_assert_held(&snp_cmd_mutex);
 
 	if (sockptr_is_null(io->req_data) || sockptr_is_null(io->resp_data))
 		return -EINVAL;
 
-	if (copy_from_sockptr(req, io->req_data, sizeof(*req)))
+	if (copy_from_sockptr(report_req, io->req_data, sizeof(*report_req)))
 		return -EFAULT;
 
 	/* caller does not want certificate data */
-	if (!req->certs_len || !req->certs_address)
+	if (!report_req->certs_len || !report_req->certs_address)
 		goto cmd;
 
-	if (req->certs_len > SEV_FW_BLOB_MAX_SIZE ||
-	    !IS_ALIGNED(req->certs_len, PAGE_SIZE))
+	if (report_req->certs_len > SEV_FW_BLOB_MAX_SIZE ||
+	    !IS_ALIGNED(report_req->certs_len, PAGE_SIZE))
 		return -EINVAL;
 
 	if (sockptr_is_kernel(io->resp_data)) {
-		certs_address = KERNEL_SOCKPTR((void *)req->certs_address);
+		certs_address = KERNEL_SOCKPTR((void *)report_req->certs_address);
 	} else {
-		certs_address = USER_SOCKPTR((void __user *)req->certs_address);
-		if (!access_ok(certs_address.user, req->certs_len))
+		certs_address = USER_SOCKPTR((void __user *)report_req->certs_address);
+		if (!access_ok(certs_address.user, report_req->certs_len))
 			return -EFAULT;
 	}
 
@@ -506,45 +519,53 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 	 * the host. If host does not supply any certs in it, then copy
 	 * zeros to indicate that certificate data was not provided.
 	 */
-	memset(snp_dev->certs_data, 0, req->certs_len);
-	npages = req->certs_len >> PAGE_SHIFT;
+	memset(snp_dev->certs_data, 0, report_req->certs_len);
+	req.data_npages = report_req->certs_len >> PAGE_SHIFT;
 cmd:
 	/*
 	 * The intermediate response buffer is used while decrypting the
 	 * response payload. Make sure that it has enough space to cover the
 	 * authtag.
 	 */
-	resp_len = sizeof(resp->data) + snp_dev->ctx->authsize;
-	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
-	if (!resp)
+	resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize;
+	report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
+	if (!report_resp)
 		return -ENOMEM;
 
-	snp_dev->input.data_npages = npages;
-	ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg,
-				   SNP_MSG_REPORT_REQ, &req->data,
-				   sizeof(req->data), resp->data, resp_len);
+	req.msg_version = arg->msg_version;
+	req.msg_type = SNP_MSG_REPORT_REQ;
+	req.vmpck_id = vmpck_id;
+	req.req_buf = &report_req->data;
+	req.req_sz = sizeof(report_req->data);
+	req.resp_buf = report_resp->data;
+	req.resp_sz = resp_len;
+	req.exit_code = SVM_VMGEXIT_EXT_GUEST_REQUEST;
+	req.data = snp_dev->certs_data;
+
+	ret = snp_send_guest_request(snp_dev, &req, arg);
 
 	/* If certs length is invalid then copy the returned length */
 	if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
-		req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
+		report_req->certs_len = req.data_npages << PAGE_SHIFT;
 
-		if (copy_to_sockptr(io->req_data, req, sizeof(*req)))
+		if (copy_to_sockptr(io->req_data, report_req, sizeof(*report_req)))
 			ret = -EFAULT;
 	}
 
 	if (ret)
 		goto e_free;
 
-	if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, req->certs_len)) {
+	if (req.data_npages && report_req->certs_len &&
+	    copy_to_sockptr(certs_address, snp_dev->certs_data, report_req->certs_len)) {
 		ret = -EFAULT;
 		goto e_free;
 	}
 
-	if (copy_to_sockptr(io->resp_data, resp, sizeof(*resp)))
+	if (copy_to_sockptr(io->resp_data, report_resp, sizeof(*report_resp)))
 		ret = -EFAULT;
 
 e_free:
-	kfree(resp);
+	kfree(report_resp);
 	return ret;
 }
 
@@ -868,7 +889,6 @@ static int __init sev_guest_probe(struct platform_device *pdev)
 	/* initial the input address for guest request */
 	snp_dev->input.req_gpa = __pa(snp_dev->request);
 	snp_dev->input.resp_gpa = __pa(snp_dev->response);
-	snp_dev->input.data_gpa = __pa(snp_dev->certs_data);
 
 	ret = tsm_register(&sev_tsm_ops, snp_dev, &tsm_report_extra_type);
 	if (ret)
diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/drivers/virt/coco/sev-guest/sev-guest.h
deleted file mode 100644
index ceb798a404d6..000000000000
--- a/drivers/virt/coco/sev-guest/sev-guest.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/* SPDX-License-Identifier: GPL-2.0-only */
-/*
- * Copyright (C) 2021 Advanced Micro Devices, Inc.
- *
- * Author: Brijesh Singh <brijesh.singh@amd.com>
- *
- * SEV-SNP API spec is available at https://developer.amd.com/sev
- */
-
-#ifndef __VIRT_SEVGUEST_H__
-#define __VIRT_SEVGUEST_H__
-
-#include <linux/types.h>
-
-#define MAX_AUTHTAG_LEN		32
-#define AUTHTAG_LEN		16
-#define AAD_LEN			48
-#define MSG_HDR_VER		1
-
-/* See SNP spec SNP_GUEST_REQUEST section for the structure */
-enum msg_type {
-	SNP_MSG_TYPE_INVALID = 0,
-	SNP_MSG_CPUID_REQ,
-	SNP_MSG_CPUID_RSP,
-	SNP_MSG_KEY_REQ,
-	SNP_MSG_KEY_RSP,
-	SNP_MSG_REPORT_REQ,
-	SNP_MSG_REPORT_RSP,
-	SNP_MSG_EXPORT_REQ,
-	SNP_MSG_EXPORT_RSP,
-	SNP_MSG_IMPORT_REQ,
-	SNP_MSG_IMPORT_RSP,
-	SNP_MSG_ABSORB_REQ,
-	SNP_MSG_ABSORB_RSP,
-	SNP_MSG_VMRK_REQ,
-	SNP_MSG_VMRK_RSP,
-
-	SNP_MSG_TYPE_MAX
-};
-
-enum aead_algo {
-	SNP_AEAD_INVALID,
-	SNP_AEAD_AES_256_GCM,
-};
-
-struct snp_guest_msg_hdr {
-	u8 authtag[MAX_AUTHTAG_LEN];
-	u64 msg_seqno;
-	u8 rsvd1[8];
-	u8 algo;
-	u8 hdr_version;
-	u16 hdr_sz;
-	u8 msg_type;
-	u8 msg_version;
-	u16 msg_sz;
-	u32 rsvd2;
-	u8 msg_vmpck;
-	u8 rsvd3[35];
-} __packed;
-
-struct snp_guest_msg {
-	struct snp_guest_msg_hdr hdr;
-	u8 payload[4000];
-} __packed;
-
-#endif /* __VIRT_SEVGUEST_H__ */
  
Borislav Petkov Feb. 6, 2024, 10:04 a.m. UTC | #12
On Mon, Feb 05, 2024 at 02:53:30PM +0530, Nikunj A. Dadhania wrote:
> Sure, below is the updated patch. Complete series is pushed here 
> 
> https://github.com/AMDESE/linux-kvm/commits/sectsc-guest-latest/

Yap, that looks more like it.

Thx.
  

Patch

diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/arch/x86/include/asm/sev-guest.h
similarity index 78%
rename from drivers/virt/coco/sev-guest/sev-guest.h
rename to arch/x86/include/asm/sev-guest.h
index ceb798a404d6..27cc15ad6131 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.h
+++ b/arch/x86/include/asm/sev-guest.h
@@ -63,4 +63,22 @@  struct snp_guest_msg {
 	u8 payload[4000];
 } __packed;
 
+struct snp_guest_req {
+	void *req_buf;
+	size_t req_sz;
+
+	void *resp_buf;
+	size_t resp_sz;
+
+	void *data;
+	size_t data_npages;
+
+	u64 exit_code;
+	unsigned int vmpck_id;
+	u8 msg_version;
+	u8 msg_type;
+};
+
+int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
+			    struct snp_guest_request_ioctl *rio);
 #endif /* __VIRT_SEVGUEST_H__ */
diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 5b4a1ce3d368..78465a8c7dc6 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -97,8 +97,6 @@  extern bool handle_vc_boot_ghcb(struct pt_regs *regs);
 struct snp_req_data {
 	unsigned long req_gpa;
 	unsigned long resp_gpa;
-	unsigned long data_gpa;
-	unsigned int data_npages;
 };
 
 struct sev_guest_platform_data {
@@ -209,7 +207,6 @@  void snp_set_memory_private(unsigned long vaddr, unsigned long npages);
 void snp_set_wakeup_secondary_cpu(void);
 bool snp_init(struct boot_params *bp);
 void __init __noreturn snp_abort(void);
-int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio);
 void snp_accept_memory(phys_addr_t start, phys_addr_t end);
 u64 snp_get_unsupported_features(u64 status);
 u64 sev_get_status(void);
@@ -233,11 +230,6 @@  static inline void snp_set_memory_private(unsigned long vaddr, unsigned long npa
 static inline void snp_set_wakeup_secondary_cpu(void) { }
 static inline bool snp_init(struct boot_params *bp) { return false; }
 static inline void snp_abort(void) { }
-static inline int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
-{
-	return -ENOTTY;
-}
-
 static inline void snp_accept_memory(phys_addr_t start, phys_addr_t end) { }
 static inline u64 snp_get_unsupported_features(u64 status) { return 0; }
 static inline u64 sev_get_status(void) { return 0; }
diff --git a/arch/x86/kernel/sev.c b/arch/x86/kernel/sev.c
index c67285824e82..fd89aca22f6a 100644
--- a/arch/x86/kernel/sev.c
+++ b/arch/x86/kernel/sev.c
@@ -28,6 +28,7 @@ 
 #include <asm/cpu_entry_area.h>
 #include <asm/stacktrace.h>
 #include <asm/sev.h>
+#include <asm/sev-guest.h>
 #include <asm/insn-eval.h>
 #include <asm/fpu/xcr.h>
 #include <asm/processor.h>
@@ -2170,15 +2171,21 @@  static int __init init_sev_config(char *str)
 }
 __setup("sev=", init_sev_config);
 
-int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct snp_guest_request_ioctl *rio)
+int snp_issue_guest_request(struct snp_guest_req *req, struct snp_req_data *input,
+			    struct snp_guest_request_ioctl *rio)
 {
 	struct ghcb_state state;
 	struct es_em_ctxt ctxt;
 	unsigned long flags;
 	struct ghcb *ghcb;
+	u64 exit_code;
 	int ret;
 
 	rio->exitinfo2 = SEV_RET_NO_FW_CALL;
+	if (!req)
+		return -EINVAL;
+
+	exit_code = req->exit_code;
 
 	/*
 	 * __sev_get_ghcb() needs to run with IRQs disabled because it is using
@@ -2195,8 +2202,8 @@  int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn
 	vc_ghcb_invalidate(ghcb);
 
 	if (exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
-		ghcb_set_rax(ghcb, input->data_gpa);
-		ghcb_set_rbx(ghcb, input->data_npages);
+		ghcb_set_rax(ghcb, __pa(req->data));
+		ghcb_set_rbx(ghcb, req->data_npages);
 	}
 
 	ret = sev_es_ghcb_hv_call(ghcb, &ctxt, exit_code, input->req_gpa, input->resp_gpa);
@@ -2215,7 +2222,7 @@  int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn
 	case SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN):
 		/* Number of expected pages are returned in RBX */
 		if (exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
-			input->data_npages = ghcb_get_rbx(ghcb);
+			req->data_npages = ghcb_get_rbx(ghcb);
 			ret = -ENOSPC;
 			break;
 		}
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 469e10d9bf35..5cafbd1c42cb 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -27,8 +27,7 @@ 
 
 #include <asm/svm.h>
 #include <asm/sev.h>
-
-#include "sev-guest.h"
+#include <asm/sev-guest.h>
 
 #define DEVICE_NAME	"sev-guest"
 
@@ -169,7 +168,7 @@  static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
 	return ctx;
 }
 
-static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
+static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *guest_req)
 {
 	struct snp_guest_msg *resp = &snp_dev->secret_response;
 	struct snp_guest_msg *req = &snp_dev->secret_request;
@@ -198,36 +197,35 @@  static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
 	 * If the message size is greater than our buffer length then return
 	 * an error.
 	 */
-	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
+	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > guest_req->resp_sz))
 		return -EBADMSG;
 
 	/* Decrypt the payload */
 	memcpy(iv, &resp_hdr->msg_seqno, sizeof(resp_hdr->msg_seqno));
-	if (!aesgcm_decrypt(ctx, payload, resp->payload, resp_hdr->msg_sz,
+	if (!aesgcm_decrypt(ctx, guest_req->resp_buf, resp->payload, resp_hdr->msg_sz,
 			    &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
 		return -EBADMSG;
 
 	return 0;
 }
 
-static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
-			void *payload, size_t sz)
+static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req)
 {
-	struct snp_guest_msg *req = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *hdr = &req->hdr;
+	struct snp_guest_msg *msg = &snp_dev->secret_request;
+	struct snp_guest_msg_hdr *hdr = &msg->hdr;
 	struct aesgcm_ctx *ctx = snp_dev->ctx;
 	u8 iv[GCM_AES_IV_SIZE] = {};
 
-	memset(req, 0, sizeof(*req));
+	memset(msg, 0, sizeof(*msg));
 
 	hdr->algo = SNP_AEAD_AES_256_GCM;
 	hdr->hdr_version = MSG_HDR_VER;
 	hdr->hdr_sz = sizeof(*hdr);
-	hdr->msg_type = type;
-	hdr->msg_version = version;
+	hdr->msg_type = req->msg_type;
+	hdr->msg_version = req->msg_version;
 	hdr->msg_seqno = seqno;
-	hdr->msg_vmpck = vmpck_id;
-	hdr->msg_sz = sz;
+	hdr->msg_vmpck = req->vmpck_id;
+	hdr->msg_sz = req->req_sz;
 
 	/* Verify the sequence number is non-zero */
 	if (!hdr->msg_seqno)
@@ -236,17 +234,17 @@  static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
 	pr_debug("request [seqno %lld type %d version %d sz %d]\n",
 		 hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
 
-	if (WARN_ON((sz + ctx->authsize) > sizeof(req->payload)))
+	if (WARN_ON((req->req_sz + ctx->authsize) > sizeof(msg->payload)))
 		return -EBADMSG;
 
 	memcpy(iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
-	aesgcm_encrypt(ctx, req->payload, payload, sz, &hdr->algo, AAD_LEN,
-		       iv, hdr->authtag);
+	aesgcm_encrypt(ctx, msg->payload, req->req_buf, req->req_sz, &hdr->algo,
+		       AAD_LEN, iv, hdr->authtag);
 
 	return 0;
 }
 
-static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
+static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
 				  struct snp_guest_request_ioctl *rio)
 {
 	unsigned long req_start = jiffies;
@@ -261,7 +259,7 @@  static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	 * sequence number must be incremented or the VMPCK must be deleted to
 	 * prevent reuse of the IV.
 	 */
-	rc = snp_issue_guest_request(exit_code, &snp_dev->input, rio);
+	rc = snp_issue_guest_request(req, &snp_dev->input, rio);
 	switch (rc) {
 	case -ENOSPC:
 		/*
@@ -271,8 +269,8 @@  static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 		 * order to increment the sequence number and thus avoid
 		 * IV reuse.
 		 */
-		override_npages = snp_dev->input.data_npages;
-		exit_code	= SVM_VMGEXIT_GUEST_REQUEST;
+		override_npages = req->data_npages;
+		req->exit_code	= SVM_VMGEXIT_GUEST_REQUEST;
 
 		/*
 		 * Override the error to inform callers the given extended
@@ -327,15 +325,13 @@  static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	}
 
 	if (override_npages)
-		snp_dev->input.data_npages = override_npages;
+		req->data_npages = override_npages;
 
 	return rc;
 }
 
-static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
-				struct snp_guest_request_ioctl *rio, u8 type,
-				void *req_buf, size_t req_sz, void *resp_buf,
-				u32 resp_sz)
+static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
+				  struct snp_guest_request_ioctl *rio)
 {
 	u64 seqno;
 	int rc;
@@ -349,7 +345,7 @@  static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
 
 	/* Encrypt the userspace provided payload in snp_dev->secret_request. */
-	rc = enc_payload(snp_dev, seqno, rio->msg_version, type, req_buf, req_sz);
+	rc = enc_payload(snp_dev, seqno, req);
 	if (rc)
 		return rc;
 
@@ -360,7 +356,7 @@  static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	memcpy(snp_dev->request, &snp_dev->secret_request,
 	       sizeof(snp_dev->secret_request));
 
-	rc = __handle_guest_request(snp_dev, exit_code, rio);
+	rc = __handle_guest_request(snp_dev, req, rio);
 	if (rc) {
 		if (rc == -EIO &&
 		    rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
@@ -369,12 +365,11 @@  static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 		dev_alert(snp_dev->dev,
 			  "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
 			  rc, rio->exitinfo2);
-
 		snp_disable_vmpck(snp_dev);
 		return rc;
 	}
 
-	rc = verify_and_dec_payload(snp_dev, resp_buf, resp_sz);
+	rc = verify_and_dec_payload(snp_dev, req);
 	if (rc) {
 		dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc);
 		snp_disable_vmpck(snp_dev);
@@ -392,6 +387,7 @@  struct snp_req_resp {
 static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
 {
 	struct snp_report_req *req = &snp_dev->req.report;
+	struct snp_guest_req guest_req = {0};
 	struct snp_report_resp *resp;
 	int rc, resp_len;
 
@@ -413,9 +409,16 @@  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
 	if (!resp)
 		return -ENOMEM;
 
-	rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
-				  SNP_MSG_REPORT_REQ, req, sizeof(*req), resp->data,
-				  resp_len);
+	guest_req.msg_version = arg->msg_version;
+	guest_req.msg_type = SNP_MSG_REPORT_REQ;
+	guest_req.vmpck_id = vmpck_id;
+	guest_req.req_buf = req;
+	guest_req.req_sz = sizeof(*req);
+	guest_req.resp_buf = resp->data;
+	guest_req.resp_sz = resp_len;
+	guest_req.exit_code = SVM_VMGEXIT_GUEST_REQUEST;
+
+	rc = snp_send_guest_request(snp_dev, &guest_req, arg);
 	if (rc)
 		goto e_free;
 
@@ -431,6 +434,7 @@  static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
 {
 	struct snp_derived_key_req *req = &snp_dev->req.derived_key;
 	struct snp_derived_key_resp resp = {0};
+	struct snp_guest_req guest_req = {0};
 	int rc, resp_len;
 	/* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
 	u8 buf[64 + 16];
@@ -452,8 +456,16 @@  static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
 	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
 		return -EFAULT;
 
-	rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
-				  SNP_MSG_KEY_REQ, req, sizeof(*req), buf, resp_len);
+	guest_req.msg_version = arg->msg_version;
+	guest_req.msg_type = SNP_MSG_KEY_REQ;
+	guest_req.vmpck_id = vmpck_id;
+	guest_req.req_buf = req;
+	guest_req.req_sz = sizeof(*req);
+	guest_req.resp_buf = buf;
+	guest_req.resp_sz = resp_len;
+	guest_req.exit_code = SVM_VMGEXIT_GUEST_REQUEST;
+
+	rc = snp_send_guest_request(snp_dev, &guest_req, arg);
 	if (rc)
 		return rc;
 
@@ -472,9 +484,10 @@  static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 
 {
 	struct snp_ext_report_req *req = &snp_dev->req.ext_report;
+	struct snp_guest_req guest_req = {0};
 	struct snp_report_resp *resp;
-	int ret, npages = 0, resp_len;
 	sockptr_t certs_address;
+	int ret, resp_len;
 
 	lockdep_assert_held(&snp_cmd_mutex);
 
@@ -507,7 +520,7 @@  static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 	 * zeros to indicate that certificate data was not provided.
 	 */
 	memset(snp_dev->certs_data, 0, req->certs_len);
-	npages = req->certs_len >> PAGE_SHIFT;
+	guest_req.data_npages = req->certs_len >> PAGE_SHIFT;
 cmd:
 	/*
 	 * The intermediate response buffer is used while decrypting the
@@ -519,14 +532,21 @@  static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 	if (!resp)
 		return -ENOMEM;
 
-	snp_dev->input.data_npages = npages;
-	ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg,
-				   SNP_MSG_REPORT_REQ, &req->data,
-				   sizeof(req->data), resp->data, resp_len);
+	guest_req.msg_version = arg->msg_version;
+	guest_req.msg_type = SNP_MSG_REPORT_REQ;
+	guest_req.vmpck_id = vmpck_id;
+	guest_req.req_buf = &req->data;
+	guest_req.req_sz = sizeof(req->data);
+	guest_req.resp_buf = resp->data;
+	guest_req.resp_sz = resp_len;
+	guest_req.exit_code = SVM_VMGEXIT_EXT_GUEST_REQUEST;
+	guest_req.data = snp_dev->certs_data;
+
+	ret = snp_send_guest_request(snp_dev, &guest_req, arg);
 
 	/* If certs length is invalid then copy the returned length */
 	if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
-		req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
+		req->certs_len = guest_req.data_npages << PAGE_SHIFT;
 
 		if (copy_to_sockptr(io->req_data, req, sizeof(*req)))
 			ret = -EFAULT;
@@ -535,7 +555,8 @@  static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 	if (ret)
 		goto e_free;
 
-	if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, req->certs_len)) {
+	if (guest_req.data_npages && req->certs_len &&
+	    copy_to_sockptr(certs_address, snp_dev->certs_data, req->certs_len)) {
 		ret = -EFAULT;
 		goto e_free;
 	}
@@ -868,7 +889,6 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 	/* initial the input address for guest request */
 	snp_dev->input.req_gpa = __pa(snp_dev->request);
 	snp_dev->input.resp_gpa = __pa(snp_dev->response);
-	snp_dev->input.data_gpa = __pa(snp_dev->certs_data);
 
 	ret = tsm_register(&sev_tsm_ops, snp_dev, &tsm_report_extra_type);
 	if (ret)