On Sun, Nov 26, 2023 at 10:34:21PM -0800, Yi Liu wrote:
> +int iommu_replace_device_pasid(struct iommu_domain *domain,
> + struct device *dev, ioasid_t pasid)
> +{
> + struct iommu_group *group = dev->iommu_group;
> + struct iommu_domain *old_domain;
> + int ret;
> +
> + if (!domain)
> + return -EINVAL;
> +
> + if (!group)
> + return -ENODEV;
> +
> + mutex_lock(&group->mutex);
> + __iommu_remove_group_pasid(group, pasid);
It is not a replace if you do a remove first.
Replace must just call set_dev_pasid and nothing much else.
Jason
@@ -19,6 +19,8 @@ static inline const struct iommu_ops *dev_iommu_ops(struct device *dev)
int iommu_group_replace_domain(struct iommu_group *group,
struct iommu_domain *new_domain);
+/* Replace the domain that @pasid of @dev is attached to with @domain. */
+int iommu_replace_device_pasid(struct iommu_domain *domain,
+ struct device *dev, ioasid_t pasid);
int iommu_device_register_bus(struct iommu_device *iommu,
const struct iommu_ops *ops, struct bus_type *bus,
@@ -3430,6 +3430,27 @@ static void __iommu_remove_group_pasid(struct iommu_group *group,
}
}
+/*
+ * Record @domain for @pasid in group->pasid_array and attach it through
+ * __iommu_set_group_pasid(). The caller must hold group->mutex.
+ *
+ * Returns 0 on success, the xarray error if the slot could not be claimed,
+ * -EBUSY if a domain is already recorded for @pasid, or the error from
+ * __iommu_set_group_pasid(). In the last case the pasid is removed from the
+ * group again and its slot in pasid_array is released.
+ */
+static int __iommu_group_attach_pasid(struct iommu_domain *domain,
+				      struct iommu_group *group, ioasid_t pasid)
+{
+	void *prev;
+	int err;
+
+	lockdep_assert_held(&group->mutex);
+
+	/* Only claim the slot if it is empty; refuse otherwise. */
+	prev = xa_cmpxchg(&group->pasid_array, pasid, NULL, domain, GFP_KERNEL);
+	if (xa_is_err(prev))
+		return xa_err(prev);
+	if (prev)
+		return -EBUSY;
+
+	err = __iommu_set_group_pasid(domain, group, pasid);
+	if (!err)
+		return 0;
+
+	/* Undo the partial attach and give the slot back. */
+	__iommu_remove_group_pasid(group, pasid);
+	xa_erase(&group->pasid_array, pasid);
+	return err;
+}
+
/*
* iommu_attach_device_pasid() - Attach a domain to pasid of device
* @domain: the iommu domain.
@@ -3453,19 +3474,9 @@ int iommu_attach_device_pasid(struct iommu_domain *domain,
return -ENODEV;
mutex_lock(&group->mutex);
- curr = xa_cmpxchg(&group->pasid_array, pasid, NULL, domain, GFP_KERNEL);
- if (curr) {
- ret = xa_err(curr) ? : -EBUSY;
- goto out_unlock;
- }
-
- ret = __iommu_set_group_pasid(domain, group, pasid);
- if (ret) {
- __iommu_remove_group_pasid(group, pasid);
- xa_erase(&group->pasid_array, pasid);
- }
-out_unlock:
+ ret = __iommu_group_attach_pasid(domain, group, pasid);
mutex_unlock(&group->mutex);
+
return ret;
}
EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
@@ -3479,8 +3490,8 @@ EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
* The @domain must have been attached to @pasid of the @dev with
* iommu_attach_device_pasid().
*/
-void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev,
- ioasid_t pasid)
+void iommu_detach_device_pasid(struct iommu_domain *domain,
+ struct device *dev, ioasid_t pasid)
{
/* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
@@ -3492,6 +3503,49 @@ void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev,
}
EXPORT_SYMBOL_GPL(iommu_detach_device_pasid);
+/**
+ * iommu_replace_device_pasid - replace the domain that a pasid is attached to
+ * @domain: new IOMMU domain to replace with
+ * @dev: the physical device
+ * @pasid: pasid that will be attached to the new domain
+ *
+ * Switch @pasid of @dev's group from its current domain to @domain by handing
+ * @domain directly to the driver's set_dev_pasid() callback, without removing
+ * the old attachment first, so the pasid is never left detached in between.
+ * Drivers must therefore support set_dev_pasid() on an already-attached pasid.
+ * If @pasid has no domain attached yet, this performs a plain attach.
+ *
+ * Return: 0 on success, or a negative errno. On failure the pasid is rolled
+ * back to its previous domain. On success the previous domain is no longer
+ * attached to @pasid, so the caller may release it without first calling
+ * iommu_detach_device_pasid() on it.
+ */
+int iommu_replace_device_pasid(struct iommu_domain *domain,
+			       struct device *dev, ioasid_t pasid)
+{
+	/* Caller must be a probed driver on dev */
+	struct iommu_group *group = dev->iommu_group;
+	struct iommu_domain *old_domain;
+	void *curr;
+	int ret = 0;
+
+	if (!domain)
+		return -EINVAL;
+
+	/* Pasid attachment goes through set_dev_pasid; refuse domains without it. */
+	if (!domain->ops->set_dev_pasid)
+		return -EOPNOTSUPP;
+
+	if (!group)
+		return -ENODEV;
+
+	mutex_lock(&group->mutex);
+	old_domain = xa_load(&group->pasid_array, pasid);
+	if (!old_domain) {
+		/* Nothing to replace: a plain attach does the job. */
+		ret = __iommu_group_attach_pasid(domain, group, pasid);
+		goto out_unlock;
+	}
+
+	/* Already attached to @domain: nothing to do. */
+	if (old_domain == domain)
+		goto out_unlock;
+
+	/*
+	 * Replace in place: give the new domain straight to set_dev_pasid
+	 * while the old one is still attached.
+	 */
+	ret = __iommu_set_group_pasid(domain, group, pasid);
+	if (ret) {
+		/*
+		 * Some devices in the group may already have switched to
+		 * @domain; point them all back at the old domain. Failing to
+		 * re-attach a domain that was attached a moment ago is a
+		 * driver bug.
+		 */
+		WARN_ON(__iommu_set_group_pasid(old_domain, group, pasid));
+		goto out_unlock;
+	}
+
+	/*
+	 * @pasid already has an entry, so this only overwrites it; with
+	 * group->mutex held it is expected to return the old domain.
+	 */
+	curr = xa_store(&group->pasid_array, pasid, domain, GFP_KERNEL);
+	WARN_ON(curr != old_domain);
+
+out_unlock:
+	mutex_unlock(&group->mutex);
+	return ret;
+}
+EXPORT_SYMBOL_NS_GPL(iommu_replace_device_pasid, IOMMUFD_INTERNAL);
+
/*
* iommu_get_domain_for_dev_pasid() - Retrieve domain for @pasid of @dev
* @dev: the queried device