[v4,3/6] virt: sevguest: Prep for kernel internal {get, get_ext}_report()

Message ID 169570183602.596431.6477217304734993370.stgit@dwillia2-xfh.jf.intel.com
State New
Headers
Series configfs-tsm: Attestation Report ABI |

Commit Message

Dan Williams Sept. 26, 2023, 4:17 a.m. UTC
In preparation for using the configfs-tsm facility to convey attestation
blobs to userspace, switch to using the 'sockptr' API for copying
payloads to provided buffers where 'sockptr' handles user vs kernel
buffers.

While configfs-tsm is meant to replace existing confidential computing
ioctl() implementations for attestation report retrieval, the old ioctl()
path needs to stick around for a deprecation period.

No behavior change intended.

Cc: Borislav Petkov <bp@alien8.de>
Cc: Tom Lendacky <thomas.lendacky@amd.com>
Cc: Dionna Glaze <dionnaglaze@google.com>
Cc: Brijesh Singh <brijesh.singh@amd.com>
Signed-off-by: Dan Williams <dan.j.williams@intel.com>
---
 drivers/virt/coco/sev-guest/sev-guest.c |   50 ++++++++++++++++++++-----------
 1 file changed, 33 insertions(+), 17 deletions(-)
  

Comments

Kuppuswamy Sathyanarayanan Sept. 26, 2023, 6:51 p.m. UTC | #1
On 9/25/2023 9:17 PM, Dan Williams wrote:
> In preparation for using the configs-tsm facility to convey attestation
> blobs to userspace, switch to using the 'sockptr' api for copying
> payloads to provided buffers where 'sockptr' handles user vs kernel
> buffers.
> 
> While configfs-tsm is meant to replace existing confidential computing
> ioctl() implementations for attestation report retrieval the old ioctl()
> path needs to stick around for a deprecation period.
> 
> No behavior change intended.
> 
> Cc: Borislav Petkov <bp@alien8.de>
> Cc: Tom Lendacky <thomas.lendacky@amd.com>
> Cc: Dionna Glaze <dionnaglaze@google.com>
> Cc: Brijesh Singh <brijesh.singh@amd.com>
> Signed-off-by: Dan Williams <dan.j.williams@intel.com>
> ---

Looks good to me.

Reviewed-by: Kuppuswamy Sathyanarayanan <sathyanarayanan.kuppuswamy@linux.intel.com>

>  drivers/virt/coco/sev-guest/sev-guest.c |   50 ++++++++++++++++++++-----------
>  1 file changed, 33 insertions(+), 17 deletions(-)
> 
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 97dbe715e96a..c3c9e9ea691f 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -19,6 +19,7 @@
>  #include <crypto/aead.h>
>  #include <linux/scatterlist.h>
>  #include <linux/psp-sev.h>
> +#include <linux/sockptr.h>
>  #include <uapi/linux/sev-guest.h>
>  #include <uapi/linux/psp-sev.h>
>  
> @@ -470,7 +471,13 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
>  	return 0;
>  }
>  
> -static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
> +struct snp_req_resp {
> +	sockptr_t req_data;
> +	sockptr_t resp_data;
> +};
> +
> +static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg,
> +		      struct snp_req_resp *io)
>  {
>  	struct snp_guest_crypto *crypto = snp_dev->crypto;
>  	struct snp_report_resp *resp;
> @@ -479,10 +486,10 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
>  
>  	lockdep_assert_held(&snp_cmd_mutex);
>  
> -	if (!arg->req_data || !arg->resp_data)
> +	if (sockptr_is_null(io->req_data) || sockptr_is_null(io->resp_data))
>  		return -EINVAL;
>  
> -	if (copy_from_user(&req, (void __user *)arg->req_data, sizeof(req)))
> +	if (copy_from_sockptr(&req, io->req_data, sizeof(req)))
>  		return -EFAULT;
>  
>  	/*
> @@ -501,7 +508,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
>  	if (rc)
>  		goto e_free;
>  
> -	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
> +	if (copy_to_sockptr(io->resp_data, resp, sizeof(*resp)))
>  		rc = -EFAULT;
>  
>  e_free:
> @@ -550,22 +557,25 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
>  	return rc;
>  }
>  
> -static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
> +static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg,
> +			  struct snp_req_resp *io)
> +
>  {
>  	struct snp_guest_crypto *crypto = snp_dev->crypto;
>  	struct snp_ext_report_req req;
>  	struct snp_report_resp *resp;
>  	int ret, npages = 0, resp_len;
> +	sockptr_t certs_address;
>  
>  	lockdep_assert_held(&snp_cmd_mutex);
>  
> -	if (!arg->req_data || !arg->resp_data)
> +	if (sockptr_is_null(io->req_data) || sockptr_is_null(io->resp_data))
>  		return -EINVAL;
>  
> -	if (copy_from_user(&req, (void __user *)arg->req_data, sizeof(req)))
> +	if (copy_from_sockptr(&req, io->req_data, sizeof(req)))
>  		return -EFAULT;
>  
> -	/* userspace does not want certificate data */
> +	/* caller does not want certificate data */
>  	if (!req.certs_len || !req.certs_address)
>  		goto cmd;
>  
> @@ -573,8 +583,13 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
>  	    !IS_ALIGNED(req.certs_len, PAGE_SIZE))
>  		return -EINVAL;
>  
> -	if (!access_ok((const void __user *)req.certs_address, req.certs_len))
> -		return -EFAULT;
> +	if (sockptr_is_kernel(io->resp_data)) {
> +		certs_address = KERNEL_SOCKPTR((void *)req.certs_address);
> +	} else {
> +		certs_address = USER_SOCKPTR((void __user *)req.certs_address);
> +		if (!access_ok(certs_address.user, req.certs_len))
> +			return -EFAULT;
> +	}
>  
>  	/*
>  	 * Initialize the intermediate buffer with all zeros. This buffer
> @@ -604,21 +619,19 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
>  	if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
>  		req.certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
>  
> -		if (copy_to_user((void __user *)arg->req_data, &req, sizeof(req)))
> +		if (copy_to_sockptr(io->req_data, &req, sizeof(req)))
>  			ret = -EFAULT;
>  	}
>  
>  	if (ret)
>  		goto e_free;
>  
> -	if (npages &&
> -	    copy_to_user((void __user *)req.certs_address, snp_dev->certs_data,
> -			 req.certs_len)) {
> +	if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, req.certs_len)) {
>  		ret = -EFAULT;
>  		goto e_free;
>  	}
>  
> -	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
> +	if (copy_to_sockptr(io->resp_data, resp, sizeof(*resp)))
>  		ret = -EFAULT;
>  
>  e_free:
> @@ -631,6 +644,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>  	struct snp_guest_dev *snp_dev = to_snp_dev(file);
>  	void __user *argp = (void __user *)arg;
>  	struct snp_guest_request_ioctl input;
> +	struct snp_req_resp io;
>  	int ret = -ENOTTY;
>  
>  	if (copy_from_user(&input, argp, sizeof(input)))
> @@ -651,15 +665,17 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>  		return -ENOTTY;
>  	}
>  
> +	io.req_data = USER_SOCKPTR((void __user *)input.req_data);
> +	io.resp_data = USER_SOCKPTR((void __user *)input.resp_data);
>  	switch (ioctl) {
>  	case SNP_GET_REPORT:
> -		ret = get_report(snp_dev, &input);
> +		ret = get_report(snp_dev, &input, &io);
>  		break;
>  	case SNP_GET_DERIVED_KEY:
>  		ret = get_derived_key(snp_dev, &input);
>  		break;
>  	case SNP_GET_EXT_REPORT:
> -		ret = get_ext_report(snp_dev, &input);
> +		ret = get_ext_report(snp_dev, &input, &io);
>  		break;
>  	default:
>  		break;
> 
>
  

Patch

diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 97dbe715e96a..c3c9e9ea691f 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -19,6 +19,7 @@ 
 #include <crypto/aead.h>
 #include <linux/scatterlist.h>
 #include <linux/psp-sev.h>
+#include <linux/sockptr.h>
 #include <uapi/linux/sev-guest.h>
 #include <uapi/linux/psp-sev.h>
 
@@ -470,7 +471,13 @@  static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	return 0;
 }
 
-static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
+struct snp_req_resp {
+	sockptr_t req_data;
+	sockptr_t resp_data;
+};
+
+static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg,
+		      struct snp_req_resp *io)
 {
 	struct snp_guest_crypto *crypto = snp_dev->crypto;
 	struct snp_report_resp *resp;
@@ -479,10 +486,10 @@  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
 
 	lockdep_assert_held(&snp_cmd_mutex);
 
-	if (!arg->req_data || !arg->resp_data)
+	if (sockptr_is_null(io->req_data) || sockptr_is_null(io->resp_data))
 		return -EINVAL;
 
-	if (copy_from_user(&req, (void __user *)arg->req_data, sizeof(req)))
+	if (copy_from_sockptr(&req, io->req_data, sizeof(req)))
 		return -EFAULT;
 
 	/*
@@ -501,7 +508,7 @@  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
 	if (rc)
 		goto e_free;
 
-	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
+	if (copy_to_sockptr(io->resp_data, resp, sizeof(*resp)))
 		rc = -EFAULT;
 
 e_free:
@@ -550,22 +557,25 @@  static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
 	return rc;
 }
 
-static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
+static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg,
+			  struct snp_req_resp *io)
+
 {
 	struct snp_guest_crypto *crypto = snp_dev->crypto;
 	struct snp_ext_report_req req;
 	struct snp_report_resp *resp;
 	int ret, npages = 0, resp_len;
+	sockptr_t certs_address;
 
 	lockdep_assert_held(&snp_cmd_mutex);
 
-	if (!arg->req_data || !arg->resp_data)
+	if (sockptr_is_null(io->req_data) || sockptr_is_null(io->resp_data))
 		return -EINVAL;
 
-	if (copy_from_user(&req, (void __user *)arg->req_data, sizeof(req)))
+	if (copy_from_sockptr(&req, io->req_data, sizeof(req)))
 		return -EFAULT;
 
-	/* userspace does not want certificate data */
+	/* caller does not want certificate data */
 	if (!req.certs_len || !req.certs_address)
 		goto cmd;
 
@@ -573,8 +583,13 @@  static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 	    !IS_ALIGNED(req.certs_len, PAGE_SIZE))
 		return -EINVAL;
 
-	if (!access_ok((const void __user *)req.certs_address, req.certs_len))
-		return -EFAULT;
+	if (sockptr_is_kernel(io->resp_data)) {
+		certs_address = KERNEL_SOCKPTR((void *)req.certs_address);
+	} else {
+		certs_address = USER_SOCKPTR((void __user *)req.certs_address);
+		if (!access_ok(certs_address.user, req.certs_len))
+			return -EFAULT;
+	}
 
 	/*
 	 * Initialize the intermediate buffer with all zeros. This buffer
@@ -604,21 +619,19 @@  static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 	if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
 		req.certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
 
-		if (copy_to_user((void __user *)arg->req_data, &req, sizeof(req)))
+		if (copy_to_sockptr(io->req_data, &req, sizeof(req)))
 			ret = -EFAULT;
 	}
 
 	if (ret)
 		goto e_free;
 
-	if (npages &&
-	    copy_to_user((void __user *)req.certs_address, snp_dev->certs_data,
-			 req.certs_len)) {
+	if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, req.certs_len)) {
 		ret = -EFAULT;
 		goto e_free;
 	}
 
-	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
+	if (copy_to_sockptr(io->resp_data, resp, sizeof(*resp)))
 		ret = -EFAULT;
 
 e_free:
@@ -631,6 +644,7 @@  static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
 	struct snp_guest_dev *snp_dev = to_snp_dev(file);
 	void __user *argp = (void __user *)arg;
 	struct snp_guest_request_ioctl input;
+	struct snp_req_resp io;
 	int ret = -ENOTTY;
 
 	if (copy_from_user(&input, argp, sizeof(input)))
@@ -651,15 +665,17 @@  static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
 		return -ENOTTY;
 	}
 
+	io.req_data = USER_SOCKPTR((void __user *)input.req_data);
+	io.resp_data = USER_SOCKPTR((void __user *)input.resp_data);
 	switch (ioctl) {
 	case SNP_GET_REPORT:
-		ret = get_report(snp_dev, &input);
+		ret = get_report(snp_dev, &input, &io);
 		break;
 	case SNP_GET_DERIVED_KEY:
 		ret = get_derived_key(snp_dev, &input);
 		break;
 	case SNP_GET_EXT_REPORT:
-		ret = get_ext_report(snp_dev, &input);
+		ret = get_ext_report(snp_dev, &input, &io);
 		break;
 	default:
 		break;