[v5,05/14] virt: sev-guest: Add vmpck_id to snp_guest_dev struct

Message ID 20231030063652.68675-6-nikunj@amd.com
State New
Headers
Series Add Secure TSC support for SNP guests |

Commit Message

Nikunj A. Dadhania Oct. 30, 2023, 6:36 a.m. UTC
  Drop vmpck and os_area_msg_seqno pointers so that secret page layout
does not need to be exposed to the sev-guest driver after the rework.
Instead, add helper APIs to access vmpck and os_area_msg_seqno when
needed.

Also, change function is_vmpck_empty() to snp_is_vmpck_empty() in
preparation for moving to sev.c.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
---
 drivers/virt/coco/sev-guest/sev-guest.c | 85 ++++++++++++-------------
 1 file changed, 42 insertions(+), 43 deletions(-)
  

Comments

Dionna Amalie Glaze Oct. 30, 2023, 4:16 p.m. UTC | #1
On Sun, Oct 29, 2023 at 11:38 PM Nikunj A Dadhania <nikunj@amd.com> wrote:
>
> Drop vmpck and os_area_msg_seqno pointers so that secret page layout
> does not need to be exposed to the sev-guest driver after the rework.
> Instead, add helper APIs to access vmpck and os_area_msg_seqno when
> needed.
>
> Also, change function is_vmpck_empty() to snp_is_vmpck_empty() in
> preparation for moving to sev.c.
>
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> ---
>  drivers/virt/coco/sev-guest/sev-guest.c | 85 ++++++++++++-------------
>  1 file changed, 42 insertions(+), 43 deletions(-)
>
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 5801dd52ffdf..4dd094c73e2f 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -50,8 +50,7 @@ struct snp_guest_dev {
>
>         struct snp_secrets_page_layout *layout;
>         struct snp_req_data input;
> -       u32 *os_area_msg_seqno;
> -       u8 *vmpck;
> +       unsigned int vmpck_id;
>  };
>
>  static u32 vmpck_id;
> @@ -61,14 +60,22 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
>  /* Mutex to serialize the shared buffer access and command handling. */
>  static DEFINE_MUTEX(snp_cmd_mutex);
>
> -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
> +static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)
>  {
> -       char zero_key[VMPCK_KEY_LEN] = {0};
> +       return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
> +}
>
> -       if (snp_dev->vmpck)
> -               return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
> +static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
> +{
> +       return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
> +}
>
> -       return true;
> +static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
> +{
> +       char zero_key[VMPCK_KEY_LEN] = {0};
> +       u8 *key = snp_get_vmpck(snp_dev);
> +
> +       return !memcmp(key, zero_key, VMPCK_KEY_LEN);
>  }
>
>  /*
> @@ -90,20 +97,22 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
>   */
>  static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
>  {
> +       u8 *key = snp_get_vmpck(snp_dev);
> +
>         dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
> -                 vmpck_id);
> -       memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
> -       snp_dev->vmpck = NULL;
> +                 snp_dev->vmpck_id);
> +       memzero_explicit(key, VMPCK_KEY_LEN);
>  }

We disable the VMPCK because we believe the guest to be under attack,
but this only clears a single key. Shouldn't we clear all VMPCK keys
in the secrets page for good measure? If at VMPCK > 0, most likely the
0..VMPCK-1 keys have been zeroed by whatever was prior in the boot
stack, but still better to be safe.

>
>  static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>  {
> +       u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
>         u64 count;
>
>         lockdep_assert_held(&snp_dev->cmd_mutex);
>
>         /* Read the current message sequence counter from secrets pages */
> -       count = *snp_dev->os_area_msg_seqno;
> +       count = *os_area_msg_seqno;
>
>         return count + 1;
>  }
> @@ -131,11 +140,13 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>
>  static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
>  {
> +       u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
> +
>         /*
>          * The counter is also incremented by the PSP, so increment it by 2
>          * and save in secrets page.
>          */
> -       *snp_dev->os_area_msg_seqno += 2;
> +       *os_area_msg_seqno += 2;
>  }
>
>  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
> @@ -145,15 +156,22 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>         return container_of(dev, struct snp_guest_dev, misc);
>  }
>
> -static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
> +static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
>  {
>         struct aesgcm_ctx *ctx;
> +       u8 *key;
> +
> +       if (snp_is_vmpck_empty(snp_dev)) {
> +               pr_err("SNP: vmpck id %d is null\n", snp_dev->vmpck_id);
> +               return NULL;
> +       }
>
>         ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
>         if (!ctx)
>                 return NULL;
>
> -       if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
> +       key = snp_get_vmpck(snp_dev);
> +       if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
>                 pr_err("SNP: crypto init failed\n");
>                 kfree(ctx);
>                 return NULL;
> @@ -586,7 +604,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>         mutex_lock(&snp_dev->cmd_mutex);
>
>         /* Check if the VMPCK is not empty */
> -       if (is_vmpck_empty(snp_dev)) {
> +       if (snp_is_vmpck_empty(snp_dev)) {
>                 dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>                 mutex_unlock(&snp_dev->cmd_mutex);
>                 return -ENOTTY;
> @@ -656,32 +674,14 @@ static const struct file_operations snp_guest_fops = {
>         .unlocked_ioctl = snp_guest_ioctl,
>  };
>
> -static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
> +bool snp_assign_vmpck(struct snp_guest_dev *dev, int vmpck_id)
>  {
> -       u8 *key = NULL;
> +       if (WARN_ON(vmpck_id > 3))
> +               return false;

The vmpck_id is an int for some reason, so < 0 is also a problem. Can
we not use unsigned int?

>
> -       switch (id) {
> -       case 0:
> -               *seqno = &layout->os_area.msg_seqno_0;
> -               key = layout->vmpck0;
> -               break;
> -       case 1:
> -               *seqno = &layout->os_area.msg_seqno_1;
> -               key = layout->vmpck1;
> -               break;
> -       case 2:
> -               *seqno = &layout->os_area.msg_seqno_2;
> -               key = layout->vmpck2;
> -               break;
> -       case 3:
> -               *seqno = &layout->os_area.msg_seqno_3;
> -               key = layout->vmpck3;
> -               break;
> -       default:
> -               break;
> -       }
> +       dev->vmpck_id = vmpck_id;
>
> -       return key;
> +       return true;
>  }
>
>  static int __init sev_guest_probe(struct platform_device *pdev)
> @@ -713,14 +713,14 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>                 goto e_unmap;
>
>         ret = -EINVAL;
> -       snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
> -       if (!snp_dev->vmpck) {
> +       snp_dev->layout = layout;
> +       if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
>                 dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
>                 goto e_unmap;
>         }
>
>         /* Verify that VMPCK is not zero. */
> -       if (is_vmpck_empty(snp_dev)) {
> +       if (snp_is_vmpck_empty(snp_dev)) {
>                 dev_err(dev, "vmpck id %d is null\n", vmpck_id);
>                 goto e_unmap;
>         }
> @@ -728,7 +728,6 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>         mutex_init(&snp_dev->cmd_mutex);
>         platform_set_drvdata(pdev, snp_dev);
>         snp_dev->dev = dev;
> -       snp_dev->layout = layout;
>
>         /* Allocate the shared page used for the request and response message. */
>         snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> @@ -744,7 +743,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>                 goto e_free_response;
>
>         ret = -EIO;
> -       snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
> +       snp_dev->ctx = snp_init_crypto(snp_dev);
>         if (!snp_dev->ctx)
>                 goto e_free_cert_data;
>
> --
> 2.34.1
>
  
Tom Lendacky Oct. 30, 2023, 5:12 p.m. UTC | #2
On 10/30/23 11:16, Dionna Amalie Glaze wrote:
> On Sun, Oct 29, 2023 at 11:38 PM Nikunj A Dadhania <nikunj@amd.com> wrote:
>>
>> Drop vmpck and os_area_msg_seqno pointers so that secret page layout
>> does not need to be exposed to the sev-guest driver after the rework.
>> Instead, add helper APIs to access vmpck and os_area_msg_seqno when
>> needed.
>>
>> Also, change function is_vmpck_empty() to snp_is_vmpck_empty() in
>> preparation for moving to sev.c.
>>
>> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
>> ---
>>   drivers/virt/coco/sev-guest/sev-guest.c | 85 ++++++++++++-------------
>>   1 file changed, 42 insertions(+), 43 deletions(-)
>>
>> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
>> index 5801dd52ffdf..4dd094c73e2f 100644
>> --- a/drivers/virt/coco/sev-guest/sev-guest.c
>> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
>> @@ -50,8 +50,7 @@ struct snp_guest_dev {
>>
>>          struct snp_secrets_page_layout *layout;
>>          struct snp_req_data input;
>> -       u32 *os_area_msg_seqno;
>> -       u8 *vmpck;
>> +       unsigned int vmpck_id;
>>   };
>>
>>   static u32 vmpck_id;
>> @@ -61,14 +60,22 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
>>   /* Mutex to serialize the shared buffer access and command handling. */
>>   static DEFINE_MUTEX(snp_cmd_mutex);
>>
>> -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
>> +static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)
>>   {
>> -       char zero_key[VMPCK_KEY_LEN] = {0};
>> +       return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
>> +}
>>
>> -       if (snp_dev->vmpck)
>> -               return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
>> +static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
>> +{
>> +       return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
>> +}
>>
>> -       return true;
>> +static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
>> +{
>> +       char zero_key[VMPCK_KEY_LEN] = {0};
>> +       u8 *key = snp_get_vmpck(snp_dev);
>> +
>> +       return !memcmp(key, zero_key, VMPCK_KEY_LEN);
>>   }
>>
>>   /*
>> @@ -90,20 +97,22 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
>>    */
>>   static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
>>   {
>> +       u8 *key = snp_get_vmpck(snp_dev);
>> +
>>          dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
>> -                 vmpck_id);
>> -       memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
>> -       snp_dev->vmpck = NULL;
>> +                 snp_dev->vmpck_id);
>> +       memzero_explicit(key, VMPCK_KEY_LEN);
>>   }
> 
> We disable the VMPCK because we believe the guest to be under attack,
> but this only clears a single key. Shouldn't we clear all VMPCK keys
> in the secrets page for good measure? If at VMPCK > 0, most likely the
> 0..VMPCK-1 keys have been zeroed by whatever was prior in the boot
> stack, but still better to be safe.

Doing that would be a separate patch series and isn't appropriate here.

Thanks,
Tom

> 
>>
>>   static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>>   {
>> +       u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
>>          u64 count;
>>
>>          lockdep_assert_held(&snp_dev->cmd_mutex);
>>
>>          /* Read the current message sequence counter from secrets pages */
>> -       count = *snp_dev->os_area_msg_seqno;
>> +       count = *os_area_msg_seqno;
>>
>>          return count + 1;
>>   }
>> @@ -131,11 +140,13 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>>
>>   static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
>>   {
>> +       u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
>> +
>>          /*
>>           * The counter is also incremented by the PSP, so increment it by 2
>>           * and save in secrets page.
>>           */
>> -       *snp_dev->os_area_msg_seqno += 2;
>> +       *os_area_msg_seqno += 2;
>>   }
>>
>>   static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>> @@ -145,15 +156,22 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>>          return container_of(dev, struct snp_guest_dev, misc);
>>   }
>>
>> -static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
>> +static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
>>   {
>>          struct aesgcm_ctx *ctx;
>> +       u8 *key;
>> +
>> +       if (snp_is_vmpck_empty(snp_dev)) {
>> +               pr_err("SNP: vmpck id %d is null\n", snp_dev->vmpck_id);
>> +               return NULL;
>> +       }
>>
>>          ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
>>          if (!ctx)
>>                  return NULL;
>>
>> -       if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
>> +       key = snp_get_vmpck(snp_dev);
>> +       if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
>>                  pr_err("SNP: crypto init failed\n");
>>                  kfree(ctx);
>>                  return NULL;
>> @@ -586,7 +604,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>>          mutex_lock(&snp_dev->cmd_mutex);
>>
>>          /* Check if the VMPCK is not empty */
>> -       if (is_vmpck_empty(snp_dev)) {
>> +       if (snp_is_vmpck_empty(snp_dev)) {
>>                  dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>>                  mutex_unlock(&snp_dev->cmd_mutex);
>>                  return -ENOTTY;
>> @@ -656,32 +674,14 @@ static const struct file_operations snp_guest_fops = {
>>          .unlocked_ioctl = snp_guest_ioctl,
>>   };
>>
>> -static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
>> +bool snp_assign_vmpck(struct snp_guest_dev *dev, int vmpck_id)
>>   {
>> -       u8 *key = NULL;
>> +       if (WARN_ON(vmpck_id > 3))
>> +               return false;
> 
> The vmpck_id is an int for some reason, so < 0 is also a problem. Can
> we not use unsigned int?
> 
>>
>> -       switch (id) {
>> -       case 0:
>> -               *seqno = &layout->os_area.msg_seqno_0;
>> -               key = layout->vmpck0;
>> -               break;
>> -       case 1:
>> -               *seqno = &layout->os_area.msg_seqno_1;
>> -               key = layout->vmpck1;
>> -               break;
>> -       case 2:
>> -               *seqno = &layout->os_area.msg_seqno_2;
>> -               key = layout->vmpck2;
>> -               break;
>> -       case 3:
>> -               *seqno = &layout->os_area.msg_seqno_3;
>> -               key = layout->vmpck3;
>> -               break;
>> -       default:
>> -               break;
>> -       }
>> +       dev->vmpck_id = vmpck_id;
>>
>> -       return key;
>> +       return true;
>>   }
>>
>>   static int __init sev_guest_probe(struct platform_device *pdev)
>> @@ -713,14 +713,14 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>>                  goto e_unmap;
>>
>>          ret = -EINVAL;
>> -       snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
>> -       if (!snp_dev->vmpck) {
>> +       snp_dev->layout = layout;
>> +       if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
>>                  dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
>>                  goto e_unmap;
>>          }
>>
>>          /* Verify that VMPCK is not zero. */
>> -       if (is_vmpck_empty(snp_dev)) {
>> +       if (snp_is_vmpck_empty(snp_dev)) {
>>                  dev_err(dev, "vmpck id %d is null\n", vmpck_id);
>>                  goto e_unmap;
>>          }
>> @@ -728,7 +728,6 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>>          mutex_init(&snp_dev->cmd_mutex);
>>          platform_set_drvdata(pdev, snp_dev);
>>          snp_dev->dev = dev;
>> -       snp_dev->layout = layout;
>>
>>          /* Allocate the shared page used for the request and response message. */
>>          snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
>> @@ -744,7 +743,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>>                  goto e_free_response;
>>
>>          ret = -EIO;
>> -       snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
>> +       snp_dev->ctx = snp_init_crypto(snp_dev);
>>          if (!snp_dev->ctx)
>>                  goto e_free_cert_data;
>>
>> --
>> 2.34.1
>>
> 
>
  
Tom Lendacky Oct. 30, 2023, 6:26 p.m. UTC | #3
On 10/30/23 01:36, Nikunj A Dadhania wrote:
> Drop vmpck and os_area_msg_seqno pointers so that secret page layout
> does not need to be exposed to the sev-guest driver after the rework.
> Instead, add helper APIs to access vmpck and os_area_msg_seqno when
> needed.
> 
> Also, change function is_vmpck_empty() to snp_is_vmpck_empty() in
> preparation for moving to sev.c.
> 
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>

With the fix to the snp_assign_vmpck() to change the int to an unsigned 
int as requested by Dionna...

Reviewed-by: Tom Lendacky <thomas.lendacky@amd.com>

> ---
>   drivers/virt/coco/sev-guest/sev-guest.c | 85 ++++++++++++-------------
>   1 file changed, 42 insertions(+), 43 deletions(-)
> 
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 5801dd52ffdf..4dd094c73e2f 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -50,8 +50,7 @@ struct snp_guest_dev {
>   
>   	struct snp_secrets_page_layout *layout;
>   	struct snp_req_data input;
> -	u32 *os_area_msg_seqno;
> -	u8 *vmpck;
> +	unsigned int vmpck_id;
>   };
>   
>   static u32 vmpck_id;
> @@ -61,14 +60,22 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
>   /* Mutex to serialize the shared buffer access and command handling. */
>   static DEFINE_MUTEX(snp_cmd_mutex);
>   
> -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
> +static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)
>   {
> -	char zero_key[VMPCK_KEY_LEN] = {0};
> +	return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
> +}
>   
> -	if (snp_dev->vmpck)
> -		return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
> +static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
> +{
> +	return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
> +}
>   
> -	return true;
> +static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
> +{
> +	char zero_key[VMPCK_KEY_LEN] = {0};
> +	u8 *key = snp_get_vmpck(snp_dev);
> +
> +	return !memcmp(key, zero_key, VMPCK_KEY_LEN);
>   }
>   
>   /*
> @@ -90,20 +97,22 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
>    */
>   static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
>   {
> +	u8 *key = snp_get_vmpck(snp_dev);
> +
>   	dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
> -		  vmpck_id);
> -	memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
> -	snp_dev->vmpck = NULL;
> +		  snp_dev->vmpck_id);
> +	memzero_explicit(key, VMPCK_KEY_LEN);
>   }
>   
>   static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>   {
> +	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
>   	u64 count;
>   
>   	lockdep_assert_held(&snp_dev->cmd_mutex);
>   
>   	/* Read the current message sequence counter from secrets pages */
> -	count = *snp_dev->os_area_msg_seqno;
> +	count = *os_area_msg_seqno;
>   
>   	return count + 1;
>   }
> @@ -131,11 +140,13 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>   
>   static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
>   {
> +	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
> +
>   	/*
>   	 * The counter is also incremented by the PSP, so increment it by 2
>   	 * and save in secrets page.
>   	 */
> -	*snp_dev->os_area_msg_seqno += 2;
> +	*os_area_msg_seqno += 2;
>   }
>   
>   static inline struct snp_guest_dev *to_snp_dev(struct file *file)
> @@ -145,15 +156,22 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>   	return container_of(dev, struct snp_guest_dev, misc);
>   }
>   
> -static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
> +static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
>   {
>   	struct aesgcm_ctx *ctx;
> +	u8 *key;
> +
> +	if (snp_is_vmpck_empty(snp_dev)) {
> +		pr_err("SNP: vmpck id %d is null\n", snp_dev->vmpck_id);
> +		return NULL;
> +	}
>   
>   	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
>   	if (!ctx)
>   		return NULL;
>   
> -	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
> +	key = snp_get_vmpck(snp_dev);
> +	if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
>   		pr_err("SNP: crypto init failed\n");
>   		kfree(ctx);
>   		return NULL;
> @@ -586,7 +604,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>   	mutex_lock(&snp_dev->cmd_mutex);
>   
>   	/* Check if the VMPCK is not empty */
> -	if (is_vmpck_empty(snp_dev)) {
> +	if (snp_is_vmpck_empty(snp_dev)) {
>   		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>   		mutex_unlock(&snp_dev->cmd_mutex);
>   		return -ENOTTY;
> @@ -656,32 +674,14 @@ static const struct file_operations snp_guest_fops = {
>   	.unlocked_ioctl = snp_guest_ioctl,
>   };
>   
> -static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
> +bool snp_assign_vmpck(struct snp_guest_dev *dev, int vmpck_id)
>   {
> -	u8 *key = NULL;
> +	if (WARN_ON(vmpck_id > 3))
> +		return false;
>   
> -	switch (id) {
> -	case 0:
> -		*seqno = &layout->os_area.msg_seqno_0;
> -		key = layout->vmpck0;
> -		break;
> -	case 1:
> -		*seqno = &layout->os_area.msg_seqno_1;
> -		key = layout->vmpck1;
> -		break;
> -	case 2:
> -		*seqno = &layout->os_area.msg_seqno_2;
> -		key = layout->vmpck2;
> -		break;
> -	case 3:
> -		*seqno = &layout->os_area.msg_seqno_3;
> -		key = layout->vmpck3;
> -		break;
> -	default:
> -		break;
> -	}
> +	dev->vmpck_id = vmpck_id;
>   
> -	return key;
> +	return true;
>   }
>   
>   static int __init sev_guest_probe(struct platform_device *pdev)
> @@ -713,14 +713,14 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>   		goto e_unmap;
>   
>   	ret = -EINVAL;
> -	snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
> -	if (!snp_dev->vmpck) {
> +	snp_dev->layout = layout;
> +	if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
>   		dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
>   		goto e_unmap;
>   	}
>   
>   	/* Verify that VMPCK is not zero. */
> -	if (is_vmpck_empty(snp_dev)) {
> +	if (snp_is_vmpck_empty(snp_dev)) {
>   		dev_err(dev, "vmpck id %d is null\n", vmpck_id);
>   		goto e_unmap;
>   	}
> @@ -728,7 +728,6 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>   	mutex_init(&snp_dev->cmd_mutex);
>   	platform_set_drvdata(pdev, snp_dev);
>   	snp_dev->dev = dev;
> -	snp_dev->layout = layout;
>   
>   	/* Allocate the shared page used for the request and response message. */
>   	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> @@ -744,7 +743,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>   		goto e_free_response;
>   
>   	ret = -EIO;
> -	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
> +	snp_dev->ctx = snp_init_crypto(snp_dev);
>   	if (!snp_dev->ctx)
>   		goto e_free_cert_data;
>
  
Nikunj A. Dadhania Nov. 2, 2023, 4:03 a.m. UTC | #4
On 10/30/2023 10:42 PM, Tom Lendacky wrote:
> On 10/30/23 11:16, Dionna Amalie Glaze wrote:
>> On Sun, Oct 29, 2023 at 11:38 PM Nikunj A Dadhania <nikunj@amd.com> wrote:

>>> @@ -656,32 +674,14 @@ static const struct file_operations snp_guest_fops = {
>>>          .unlocked_ioctl = snp_guest_ioctl,
>>>   };
>>>
>>> -static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
>>> +bool snp_assign_vmpck(struct snp_guest_dev *dev, int vmpck_id)
>>>   {
>>> -       u8 *key = NULL;
>>> +       if (WARN_ON(vmpck_id > 3))
>>> +               return false;
>>
>> The vmpck_id is an int for some reason, so < 0 is also a problem. Can
>> we not use unsigned int?

Yes, I will update that in my next revision,

Thanks
Nikunj
  

Patch

diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 5801dd52ffdf..4dd094c73e2f 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -50,8 +50,7 @@  struct snp_guest_dev {
 
 	struct snp_secrets_page_layout *layout;
 	struct snp_req_data input;
-	u32 *os_area_msg_seqno;
-	u8 *vmpck;
+	unsigned int vmpck_id;
 };
 
 static u32 vmpck_id;
@@ -61,14 +60,22 @@  MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
 /* Mutex to serialize the shared buffer access and command handling. */
 static DEFINE_MUTEX(snp_cmd_mutex);
 
-static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
+static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)
 {
-	char zero_key[VMPCK_KEY_LEN] = {0};
+	return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
+}
 
-	if (snp_dev->vmpck)
-		return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
+static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
+{
+	return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
+}
 
-	return true;
+static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
+{
+	char zero_key[VMPCK_KEY_LEN] = {0};
+	u8 *key = snp_get_vmpck(snp_dev);
+
+	return !memcmp(key, zero_key, VMPCK_KEY_LEN);
 }
 
 /*
@@ -90,20 +97,22 @@  static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
  */
 static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
 {
+	u8 *key = snp_get_vmpck(snp_dev);
+
 	dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
-		  vmpck_id);
-	memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
-	snp_dev->vmpck = NULL;
+		  snp_dev->vmpck_id);
+	memzero_explicit(key, VMPCK_KEY_LEN);
 }
 
 static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
 {
+	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
 	u64 count;
 
 	lockdep_assert_held(&snp_dev->cmd_mutex);
 
 	/* Read the current message sequence counter from secrets pages */
-	count = *snp_dev->os_area_msg_seqno;
+	count = *os_area_msg_seqno;
 
 	return count + 1;
 }
@@ -131,11 +140,13 @@  static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
 
 static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
 {
+	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
+
 	/*
 	 * The counter is also incremented by the PSP, so increment it by 2
 	 * and save in secrets page.
 	 */
-	*snp_dev->os_area_msg_seqno += 2;
+	*os_area_msg_seqno += 2;
 }
 
 static inline struct snp_guest_dev *to_snp_dev(struct file *file)
@@ -145,15 +156,22 @@  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
 	return container_of(dev, struct snp_guest_dev, misc);
 }
 
-static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
+static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
 {
 	struct aesgcm_ctx *ctx;
+	u8 *key;
+
+	if (snp_is_vmpck_empty(snp_dev)) {
+		pr_err("SNP: vmpck id %d is null\n", snp_dev->vmpck_id);
+		return NULL;
+	}
 
 	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
 	if (!ctx)
 		return NULL;
 
-	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
+	key = snp_get_vmpck(snp_dev);
+	if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
 		pr_err("SNP: crypto init failed\n");
 		kfree(ctx);
 		return NULL;
@@ -586,7 +604,7 @@  static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
 	mutex_lock(&snp_dev->cmd_mutex);
 
 	/* Check if the VMPCK is not empty */
-	if (is_vmpck_empty(snp_dev)) {
+	if (snp_is_vmpck_empty(snp_dev)) {
 		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
 		mutex_unlock(&snp_dev->cmd_mutex);
 		return -ENOTTY;
@@ -656,32 +674,14 @@  static const struct file_operations snp_guest_fops = {
 	.unlocked_ioctl = snp_guest_ioctl,
 };
 
-static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
+bool snp_assign_vmpck(struct snp_guest_dev *dev, int vmpck_id)
 {
-	u8 *key = NULL;
+	if (WARN_ON(vmpck_id > 3))
+		return false;
 
-	switch (id) {
-	case 0:
-		*seqno = &layout->os_area.msg_seqno_0;
-		key = layout->vmpck0;
-		break;
-	case 1:
-		*seqno = &layout->os_area.msg_seqno_1;
-		key = layout->vmpck1;
-		break;
-	case 2:
-		*seqno = &layout->os_area.msg_seqno_2;
-		key = layout->vmpck2;
-		break;
-	case 3:
-		*seqno = &layout->os_area.msg_seqno_3;
-		key = layout->vmpck3;
-		break;
-	default:
-		break;
-	}
+	dev->vmpck_id = vmpck_id;
 
-	return key;
+	return true;
 }
 
 static int __init sev_guest_probe(struct platform_device *pdev)
@@ -713,14 +713,14 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 		goto e_unmap;
 
 	ret = -EINVAL;
-	snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
-	if (!snp_dev->vmpck) {
+	snp_dev->layout = layout;
+	if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
 		dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
 		goto e_unmap;
 	}
 
 	/* Verify that VMPCK is not zero. */
-	if (is_vmpck_empty(snp_dev)) {
+	if (snp_is_vmpck_empty(snp_dev)) {
 		dev_err(dev, "vmpck id %d is null\n", vmpck_id);
 		goto e_unmap;
 	}
@@ -728,7 +728,6 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 	mutex_init(&snp_dev->cmd_mutex);
 	platform_set_drvdata(pdev, snp_dev);
 	snp_dev->dev = dev;
-	snp_dev->layout = layout;
 
 	/* Allocate the shared page used for the request and response message. */
 	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
@@ -744,7 +743,7 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 		goto e_free_response;
 
 	ret = -EIO;
-	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
+	snp_dev->ctx = snp_init_crypto(snp_dev);
 	if (!snp_dev->ctx)
 		goto e_free_cert_data;