[v4,09/17] iommufd/device: Add helpers to enforce/remove device reserved regions

Message ID 20230921075138.124099-10-yi.l.liu@intel.com
State New
Headers
Series iommufd: Add nesting infrastructure |

Commit Message

Yi Liu Sept. 21, 2023, 7:51 a.m. UTC
  From: Nicolin Chen <nicolinc@nvidia.com>

iopt_table_enforce_dev_resv_regions() and iopt_remove_reserved_iova()
require callers to pass in an ioas->iopt pointer. This works for a
kernel-managed hw_pagetable, whose hwpt->ioas->iopt pointer can simply be
passed in.
However, there can now be a user-managed hw_pagetable that doesn't have
an ioas pointer. Typically, most device reserved regions should be
enforced on a kernel-managed domain only, although the IOMMU_RESV_SW_MSI
region used by SMMU introduces some complications.

Add a pair of helpers, iommufd_device_enforce_rr() and
iommufd_device_remove_rr(), that call iopt_table_enforce_dev_resv_regions()
and iopt_remove_reserved_iova() respectively after some additional checks:
they warn on a NULL hw_pagetable and skip user-managed hw_pagetables. This
also eases any further extension to handle the IOMMU_RESV_SW_MSI
complication mentioned above.

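For the replace() routine, add another helper,
iommufd_hw_pagetable_compare_ioas(), that returns true when two
hw_pagetables resolve to different IOASes. A user-managed hw_pagetable is
resolved to the IOAS of its parent.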

Signed-off-by: Nicolin Chen <nicolinc@nvidia.com>
Signed-off-by: Yi Liu <yi.l.liu@intel.com>
---
 drivers/iommu/iommufd/device.c          | 42 ++++++++++++++++++-------
 drivers/iommu/iommufd/iommufd_private.h | 18 +++++++++++
 2 files changed, 48 insertions(+), 12 deletions(-)
  

Comments

Yan Zhao Oct. 7, 2023, 7:20 a.m. UTC | #1
> @@ -444,10 +465,9 @@ iommufd_device_do_replace(struct iommufd_device *idev,
>  	}
>  
>  	old_hwpt = igroup->hwpt;
> -	if (hwpt->ioas != old_hwpt->ioas) {
> +	if (iommufd_hw_pagetable_compare_ioas(old_hwpt, hwpt)) {
>  		list_for_each_entry(cur, &igroup->device_list, group_item) {
> -			rc = iopt_table_enforce_dev_resv_regions(
> -				&hwpt->ioas->iopt, cur->dev, NULL);
> +			rc = iommufd_device_enforce_rr(cur, hwpt, NULL);
>  			if (rc)
>  				goto err_unresv;
>  		}
> @@ -461,12 +481,10 @@ iommufd_device_do_replace(struct iommufd_device *idev,
>  	if (rc)
>  		goto err_unresv;
>  
> -	if (hwpt->ioas != old_hwpt->ioas) {
> +	if (iommufd_hw_pagetable_compare_ioas(old_hwpt, hwpt)) {
>  		list_for_each_entry(cur, &igroup->device_list, group_item)
> -			iopt_remove_reserved_iova(&old_hwpt->ioas->iopt,
> -						  cur->dev);
> +			iommufd_device_remove_rr(cur, hwpt);
Should be "iommufd_device_remove_rr(cur, old_hwpt);"

>  	}
> -
>  	igroup->hwpt = hwpt;
>  
>  	/*
  
Nicolin Chen Oct. 7, 2023, 9:27 a.m. UTC | #2
On Sat, Oct 07, 2023 at 03:20:41PM +0800, Yan Zhao wrote:
> > @@ -444,10 +465,9 @@ iommufd_device_do_replace(struct iommufd_device *idev,
> >       }
> >
> >       old_hwpt = igroup->hwpt;
> > -     if (hwpt->ioas != old_hwpt->ioas) {
> > +     if (iommufd_hw_pagetable_compare_ioas(old_hwpt, hwpt)) {
> >               list_for_each_entry(cur, &igroup->device_list, group_item) {
> > -                     rc = iopt_table_enforce_dev_resv_regions(
> > -                             &hwpt->ioas->iopt, cur->dev, NULL);
> > +                     rc = iommufd_device_enforce_rr(cur, hwpt, NULL);
> >                       if (rc)
> >                               goto err_unresv;
> >               }
> > @@ -461,12 +481,10 @@ iommufd_device_do_replace(struct iommufd_device *idev,
> >       if (rc)
> >               goto err_unresv;
> >
> > -     if (hwpt->ioas != old_hwpt->ioas) {
> > +     if (iommufd_hw_pagetable_compare_ioas(old_hwpt, hwpt)) {
> >               list_for_each_entry(cur, &igroup->device_list, group_item)
> > -                     iopt_remove_reserved_iova(&old_hwpt->ioas->iopt,
> > -                                               cur->dev);
> > +                     iommufd_device_remove_rr(cur, hwpt);
> Should be "iommufd_device_remove_rr(cur, old_hwpt);"

Ah, right. Should fix this.

Thanks!
Nicolin
  

Patch

diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index 104dd061a2a3..10e6ec590ede 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -329,6 +329,28 @@  static int iommufd_group_setup_msi(struct iommufd_group *igroup,
 	return 0;
 }
 
+static void iommufd_device_remove_rr(struct iommufd_device *idev,
+				     struct iommufd_hw_pagetable *hwpt)
+{
+	if (WARN_ON(!hwpt))
+		return;
+	if (hwpt->user_managed)
+		return;
+	iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
+}
+
+static int iommufd_device_enforce_rr(struct iommufd_device *idev,
+				     struct iommufd_hw_pagetable *hwpt,
+				     phys_addr_t *sw_msi_start)
+{
+	if (WARN_ON(!hwpt))
+		return -EINVAL;
+	if (hwpt->user_managed)
+		return 0;
+	return iopt_table_enforce_dev_resv_regions(&hwpt->ioas->iopt, idev->dev,
+						   sw_msi_start);
+}
+
 int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
 				struct iommufd_device *idev)
 {
@@ -348,8 +370,7 @@  int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
 			goto err_unlock;
 	}
 
-	rc = iopt_table_enforce_dev_resv_regions(&hwpt->ioas->iopt, idev->dev,
-						 &idev->igroup->sw_msi_start);
+	rc = iommufd_device_enforce_rr(idev, hwpt, &idev->igroup->sw_msi_start);
 	if (rc)
 		goto err_unlock;
 
@@ -375,7 +396,7 @@  int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
 	mutex_unlock(&idev->igroup->lock);
 	return 0;
 err_unresv:
-	iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
+	iommufd_device_remove_rr(idev, hwpt);
 err_unlock:
 	mutex_unlock(&idev->igroup->lock);
 	return rc;
@@ -392,7 +413,7 @@  iommufd_hw_pagetable_detach(struct iommufd_device *idev)
 		iommu_detach_group(hwpt->domain, idev->igroup->group);
 		idev->igroup->hwpt = NULL;
 	}
-	iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
+	iommufd_device_remove_rr(idev, hwpt);
 	mutex_unlock(&idev->igroup->lock);
 
 	/* Caller must destroy hwpt */
@@ -444,10 +465,9 @@  iommufd_device_do_replace(struct iommufd_device *idev,
 	}
 
 	old_hwpt = igroup->hwpt;
-	if (hwpt->ioas != old_hwpt->ioas) {
+	if (iommufd_hw_pagetable_compare_ioas(old_hwpt, hwpt)) {
 		list_for_each_entry(cur, &igroup->device_list, group_item) {
-			rc = iopt_table_enforce_dev_resv_regions(
-				&hwpt->ioas->iopt, cur->dev, NULL);
+			rc = iommufd_device_enforce_rr(cur, hwpt, NULL);
 			if (rc)
 				goto err_unresv;
 		}
@@ -461,12 +481,10 @@  iommufd_device_do_replace(struct iommufd_device *idev,
 	if (rc)
 		goto err_unresv;
 
-	if (hwpt->ioas != old_hwpt->ioas) {
+	if (iommufd_hw_pagetable_compare_ioas(old_hwpt, hwpt)) {
 		list_for_each_entry(cur, &igroup->device_list, group_item)
-			iopt_remove_reserved_iova(&old_hwpt->ioas->iopt,
-						  cur->dev);
+			iommufd_device_remove_rr(cur, old_hwpt);
 	}
-
 	igroup->hwpt = hwpt;
 
 	/*
@@ -483,7 +501,7 @@  iommufd_device_do_replace(struct iommufd_device *idev,
 	return old_hwpt;
 err_unresv:
 	list_for_each_entry(cur, &igroup->device_list, group_item)
-		iopt_remove_reserved_iova(&hwpt->ioas->iopt, cur->dev);
+		iommufd_device_remove_rr(cur, hwpt);
 err_unlock:
 	mutex_unlock(&idev->igroup->lock);
 	return ERR_PTR(rc);
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 34940596c2c2..b14f23d3f42e 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -281,6 +281,24 @@  static inline void iommufd_hw_pagetable_put(struct iommufd_ctx *ictx,
 		refcount_dec(&hwpt->obj.users);
 }
 
+/*
+ * True if the hwpts differ in IOAS; a user-managed hwpt uses its parent's IOAS
+ */
+static inline bool
+iommufd_hw_pagetable_compare_ioas(struct iommufd_hw_pagetable *old_hwpt,
+				  struct iommufd_hw_pagetable *new_hwpt)
+{
+	struct iommufd_ioas *old_ioas, *new_ioas;
+
+	if (WARN_ON(!old_hwpt || !new_hwpt))
+		return false;
+	old_ioas = old_hwpt->user_managed ? old_hwpt->parent->ioas :
+					    old_hwpt->ioas;
+	new_ioas = new_hwpt->user_managed ? new_hwpt->parent->ioas :
+					    new_hwpt->ioas;
+	return old_ioas != new_ioas;
+}
+
 struct iommufd_group {
 	struct kref ref;
 	struct mutex lock;