[vhost,13/16] vdpa/mlx5: Introduce mr for vq descriptor

Message ID 20230928164550.980832-15-dtatulea@nvidia.com
State New
Headers
Series vdpa: Add support for vq descriptor mappings |

Commit Message

Dragos Tatulea Sept. 28, 2023, 4:45 p.m. UTC
  Introduce the vq descriptor group and an mr per ASID. Until now .set_map on ASID
1 was only updating the cvq iotlb. From now on it also creates an mkey
for it. MLX5_VDPA_NUM_AS is now fixed at 2 instead of following the
number of vq groups. The current patch doesn't use the new mkey, but
follow-up patches will add hardware support for mapping the vq descriptors.

Signed-off-by: Dragos Tatulea <dtatulea@nvidia.com>
---
 drivers/vdpa/mlx5/core/mlx5_vdpa.h |  5 +++--
 drivers/vdpa/mlx5/core/mr.c        | 14 +++++++++-----
 drivers/vdpa/mlx5/net/mlx5_vnet.c  | 20 +++++++++++++-------
 3 files changed, 25 insertions(+), 14 deletions(-)
  

Comments

Eugenio Perez Martin Oct. 4, 2023, 6:53 p.m. UTC | #1
On Thu, Sep 28, 2023 at 6:50 PM Dragos Tatulea <dtatulea@nvidia.com> wrote:
>
> Introduce the vq descriptor group and ASID 1. Until now .set_map on ASID

s/ASID/vq group/?

> 1 was only updating the cvq iotlb. From now on it also creates a mkey
> for it. The current patch doesn't use it but follow-up patches will
> add hardware support for mapping the vq descriptors.
>
> Signed-off-by: Dragos Tatulea <dtatulea@nvidia.com>
> ---
>  drivers/vdpa/mlx5/core/mlx5_vdpa.h |  5 +++--
>  drivers/vdpa/mlx5/core/mr.c        | 14 +++++++++-----
>  drivers/vdpa/mlx5/net/mlx5_vnet.c  | 20 +++++++++++++-------
>  3 files changed, 25 insertions(+), 14 deletions(-)
>
> diff --git a/drivers/vdpa/mlx5/core/mlx5_vdpa.h b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> index bbe4335106bd..ae09296f4270 100644
> --- a/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> +++ b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> @@ -70,11 +70,12 @@ struct mlx5_vdpa_wq_ent {
>  enum {
>         MLX5_VDPA_DATAVQ_GROUP,
>         MLX5_VDPA_CVQ_GROUP,
> +       MLX5_VDPA_DATAVQ_DESC_GROUP,
>         MLX5_VDPA_NUMVQ_GROUPS
>  };
>
>  enum {
> -       MLX5_VDPA_NUM_AS = MLX5_VDPA_NUMVQ_GROUPS
> +       MLX5_VDPA_NUM_AS = 2
>  };
>
>  struct mlx5_vdpa_dev {
> @@ -89,7 +90,7 @@ struct mlx5_vdpa_dev {
>         u16 max_idx;
>         u32 generation;
>
> -       struct mlx5_vdpa_mr *mr;
> +       struct mlx5_vdpa_mr *mr[MLX5_VDPA_NUM_AS];

I'm wondering if it makes sense to squash all of this patch with the
previous one, as I think *mr[MLX5_VDPA_NUM_AS] makes way more sense
than just *mr.

Whatever you choose, for both patches:

Acked-by: Eugenio Pérez <eperezma@redhat.com>

>         /* serialize mr access */
>         struct mutex mr_mtx;
>         struct mlx5_control_vq cvq;
> diff --git a/drivers/vdpa/mlx5/core/mr.c b/drivers/vdpa/mlx5/core/mr.c
> index 00eff5a07152..3dee6d9bed6b 100644
> --- a/drivers/vdpa/mlx5/core/mr.c
> +++ b/drivers/vdpa/mlx5/core/mr.c
> @@ -511,8 +511,10 @@ void mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev,
>
>         _mlx5_vdpa_destroy_mr(mvdev, mr);
>
> -       if (mvdev->mr == mr)
> -               mvdev->mr = NULL;
> +       for (int i = 0; i < MLX5_VDPA_NUM_AS; i++) {
> +               if (mvdev->mr[i] == mr)
> +                       mvdev->mr[i] = NULL;
> +       }
>
>         mutex_unlock(&mvdev->mr_mtx);
>
> @@ -523,11 +525,11 @@ void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
>                          struct mlx5_vdpa_mr *new_mr,
>                          unsigned int asid)
>  {
> -       struct mlx5_vdpa_mr *old_mr = mvdev->mr;
> +       struct mlx5_vdpa_mr *old_mr = mvdev->mr[asid];
>
>         mutex_lock(&mvdev->mr_mtx);
>
> -       mvdev->mr = new_mr;
> +       mvdev->mr[asid] = new_mr;
>         if (old_mr) {
>                 _mlx5_vdpa_destroy_mr(mvdev, old_mr);
>                 kfree(old_mr);
> @@ -539,7 +541,9 @@ void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
>
>  void mlx5_vdpa_destroy_mr_resources(struct mlx5_vdpa_dev *mvdev)
>  {
> -       mlx5_vdpa_destroy_mr(mvdev, mvdev->mr);
> +       for (int i = 0; i < MLX5_VDPA_NUM_AS; i++)
> +               mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[i]);
> +
>         prune_iotlb(mvdev);
>  }
>
> diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c b/drivers/vdpa/mlx5/net/mlx5_vnet.c
> index 4a87f9119fca..25bd2c324f5b 100644
> --- a/drivers/vdpa/mlx5/net/mlx5_vnet.c
> +++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c
> @@ -821,6 +821,8 @@ static int create_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtque
>  {
>         int inlen = MLX5_ST_SZ_BYTES(create_virtio_net_q_in);
>         u32 out[MLX5_ST_SZ_DW(create_virtio_net_q_out)] = {};
> +       struct mlx5_vdpa_dev *mvdev = &ndev->mvdev;
> +       struct mlx5_vdpa_mr *vq_mr;
>         void *obj_context;
>         u16 mlx_features;
>         void *cmd_hdr;
> @@ -873,7 +875,9 @@ static int create_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtque
>         MLX5_SET64(virtio_q, vq_ctx, desc_addr, mvq->desc_addr);
>         MLX5_SET64(virtio_q, vq_ctx, used_addr, mvq->device_addr);
>         MLX5_SET64(virtio_q, vq_ctx, available_addr, mvq->driver_addr);
> -       MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, ndev->mvdev.mr->mkey);
> +       vq_mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP]];
> +       if (vq_mr)
> +               MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, vq_mr->mkey);
>         MLX5_SET(virtio_q, vq_ctx, umem_1_id, mvq->umem1.id);
>         MLX5_SET(virtio_q, vq_ctx, umem_1_size, mvq->umem1.size);
>         MLX5_SET(virtio_q, vq_ctx, umem_2_id, mvq->umem2.id);
> @@ -2633,7 +2637,8 @@ static void restore_channels_info(struct mlx5_vdpa_net *ndev)
>  }
>
>  static int mlx5_vdpa_change_map(struct mlx5_vdpa_dev *mvdev,
> -                               struct mlx5_vdpa_mr *new_mr, unsigned int asid)
> +                               struct mlx5_vdpa_mr *new_mr,
> +                               unsigned int asid)
>  {
>         struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
>         int err;
> @@ -2652,8 +2657,10 @@ static int mlx5_vdpa_change_map(struct mlx5_vdpa_dev *mvdev,
>
>         restore_channels_info(ndev);
>         err = setup_driver(mvdev);
> +       if (err)
> +               return err;
>
> -       return err;
> +       return 0;
>  }
>
>  /* reslock must be held for this function */
> @@ -2869,8 +2876,8 @@ static int set_map_data(struct mlx5_vdpa_dev *mvdev, struct vhost_iotlb *iotlb,
>         struct mlx5_vdpa_mr *new_mr;
>         int err;
>
> -       if (mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP] != asid)
> -               goto end;
> +       if (asid >= MLX5_VDPA_NUM_AS)
> +               return -EINVAL;
>
>         new_mr = mlx5_vdpa_create_mr(mvdev, iotlb);
>         if (IS_ERR(new_mr)) {
> @@ -2879,7 +2886,7 @@ static int set_map_data(struct mlx5_vdpa_dev *mvdev, struct vhost_iotlb *iotlb,
>                 return err;
>         }
>
> -       if (!mvdev->mr) {
> +       if (!mvdev->mr[asid]) {
>                 mlx5_vdpa_update_mr(mvdev, new_mr, asid);
>         } else {
>                 err = mlx5_vdpa_change_map(mvdev, new_mr, asid);
> @@ -2889,7 +2896,6 @@ static int set_map_data(struct mlx5_vdpa_dev *mvdev, struct vhost_iotlb *iotlb,
>                 }
>         }
>
> -end:
>         return mlx5_vdpa_update_cvq_iotlb(mvdev, iotlb, asid);
>
>  out_err:
> --
> 2.41.0
>
  
Dragos Tatulea Oct. 5, 2023, 12:09 p.m. UTC | #2
On Wed, 2023-10-04 at 20:53 +0200, Eugenio Perez Martin wrote:
> On Thu, Sep 28, 2023 at 6:50 PM Dragos Tatulea <dtatulea@nvidia.com> wrote:
> > 
> > Introduce the vq descriptor group and ASID 1. Until now .set_map on ASID
> 
> s/ASID/vq group/?
> 
Oh, indeed.

> > 1 was only updating the cvq iotlb. From now on it also creates a mkey
> > for it. The current patch doesn't use it but follow-up patches will
> > add hardware support for mapping the vq descriptors.
> > 
> > Signed-off-by: Dragos Tatulea <dtatulea@nvidia.com>
> > ---
> >  drivers/vdpa/mlx5/core/mlx5_vdpa.h |  5 +++--
> >  drivers/vdpa/mlx5/core/mr.c        | 14 +++++++++-----
> >  drivers/vdpa/mlx5/net/mlx5_vnet.c  | 20 +++++++++++++-------
> >  3 files changed, 25 insertions(+), 14 deletions(-)
> > 
> > diff --git a/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> > b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> > index bbe4335106bd..ae09296f4270 100644
> > --- a/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> > +++ b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> > @@ -70,11 +70,12 @@ struct mlx5_vdpa_wq_ent {
> >  enum {
> >         MLX5_VDPA_DATAVQ_GROUP,
> >         MLX5_VDPA_CVQ_GROUP,
> > +       MLX5_VDPA_DATAVQ_DESC_GROUP,
> >         MLX5_VDPA_NUMVQ_GROUPS
> >  };
> > 
> >  enum {
> > -       MLX5_VDPA_NUM_AS = MLX5_VDPA_NUMVQ_GROUPS
> > +       MLX5_VDPA_NUM_AS = 2
> >  };
> > 
> >  struct mlx5_vdpa_dev {
> > @@ -89,7 +90,7 @@ struct mlx5_vdpa_dev {
> >         u16 max_idx;
> >         u32 generation;
> > 
> > -       struct mlx5_vdpa_mr *mr;
> > +       struct mlx5_vdpa_mr *mr[MLX5_VDPA_NUM_AS];
> 
> I'm wondering if it makes sense to squash all of this patch with the
> previous one, as I think *mr[MLX5_VDPA_NUM_AS] makes way more sense
> than just *mr.
> 
I've been on the fence about this one. It seemed cleaner to have two patches.

> Whatever you choose, for both patches:
> 
> Acked-by: Eugenio Pérez <eperezma@redhat.com>
> 
> >         /* serialize mr access */
> >         struct mutex mr_mtx;
> >         struct mlx5_control_vq cvq;
> > diff --git a/drivers/vdpa/mlx5/core/mr.c b/drivers/vdpa/mlx5/core/mr.c
> > index 00eff5a07152..3dee6d9bed6b 100644
> > --- a/drivers/vdpa/mlx5/core/mr.c
> > +++ b/drivers/vdpa/mlx5/core/mr.c
> > @@ -511,8 +511,10 @@ void mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev,
> > 
> >         _mlx5_vdpa_destroy_mr(mvdev, mr);
> > 
> > -       if (mvdev->mr == mr)
> > -               mvdev->mr = NULL;
> > +       for (int i = 0; i < MLX5_VDPA_NUM_AS; i++) {
> > +               if (mvdev->mr[i] == mr)
> > +                       mvdev->mr[i] = NULL;
> > +       }
> > 
> >         mutex_unlock(&mvdev->mr_mtx);
> > 
> > @@ -523,11 +525,11 @@ void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
> >                          struct mlx5_vdpa_mr *new_mr,
> >                          unsigned int asid)
> >  {
> > -       struct mlx5_vdpa_mr *old_mr = mvdev->mr;
> > +       struct mlx5_vdpa_mr *old_mr = mvdev->mr[asid];
> > 
> >         mutex_lock(&mvdev->mr_mtx);
> > 
> > -       mvdev->mr = new_mr;
> > +       mvdev->mr[asid] = new_mr;
> >         if (old_mr) {
> >                 _mlx5_vdpa_destroy_mr(mvdev, old_mr);
> >                 kfree(old_mr);
> > @@ -539,7 +541,9 @@ void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
> > 
> >  void mlx5_vdpa_destroy_mr_resources(struct mlx5_vdpa_dev *mvdev)
> >  {
> > -       mlx5_vdpa_destroy_mr(mvdev, mvdev->mr);
> > +       for (int i = 0; i < MLX5_VDPA_NUM_AS; i++)
> > +               mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[i]);
> > +
> >         prune_iotlb(mvdev);
> >  }
> > 
> > diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c
> > b/drivers/vdpa/mlx5/net/mlx5_vnet.c
> > index 4a87f9119fca..25bd2c324f5b 100644
> > --- a/drivers/vdpa/mlx5/net/mlx5_vnet.c
> > +++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c
> > @@ -821,6 +821,8 @@ static int create_virtqueue(struct mlx5_vdpa_net *ndev,
> > struct mlx5_vdpa_virtque
> >  {
> >         int inlen = MLX5_ST_SZ_BYTES(create_virtio_net_q_in);
> >         u32 out[MLX5_ST_SZ_DW(create_virtio_net_q_out)] = {};
> > +       struct mlx5_vdpa_dev *mvdev = &ndev->mvdev;
> > +       struct mlx5_vdpa_mr *vq_mr;
> >         void *obj_context;
> >         u16 mlx_features;
> >         void *cmd_hdr;
> > @@ -873,7 +875,9 @@ static int create_virtqueue(struct mlx5_vdpa_net *ndev,
> > struct mlx5_vdpa_virtque
> >         MLX5_SET64(virtio_q, vq_ctx, desc_addr, mvq->desc_addr);
> >         MLX5_SET64(virtio_q, vq_ctx, used_addr, mvq->device_addr);
> >         MLX5_SET64(virtio_q, vq_ctx, available_addr, mvq->driver_addr);
> > -       MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, ndev->mvdev.mr->mkey);
> > +       vq_mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP]];
> > +       if (vq_mr)
> > +               MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, vq_mr->mkey);
> >         MLX5_SET(virtio_q, vq_ctx, umem_1_id, mvq->umem1.id);
> >         MLX5_SET(virtio_q, vq_ctx, umem_1_size, mvq->umem1.size);
> >         MLX5_SET(virtio_q, vq_ctx, umem_2_id, mvq->umem2.id);
> > @@ -2633,7 +2637,8 @@ static void restore_channels_info(struct mlx5_vdpa_net
> > *ndev)
> >  }
> > 
> >  static int mlx5_vdpa_change_map(struct mlx5_vdpa_dev *mvdev,
> > -                               struct mlx5_vdpa_mr *new_mr, unsigned int
> > asid)
> > +                               struct mlx5_vdpa_mr *new_mr,
> > +                               unsigned int asid)
> >  {
> >         struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
> >         int err;
> > @@ -2652,8 +2657,10 @@ static int mlx5_vdpa_change_map(struct mlx5_vdpa_dev
> > *mvdev,
> > 
> >         restore_channels_info(ndev);
> >         err = setup_driver(mvdev);
> > +       if (err)
> > +               return err;
> > 
> > -       return err;
> > +       return 0;
> >  }
> > 
> >  /* reslock must be held for this function */
> > @@ -2869,8 +2876,8 @@ static int set_map_data(struct mlx5_vdpa_dev *mvdev,
> > struct vhost_iotlb *iotlb,
> >         struct mlx5_vdpa_mr *new_mr;
> >         int err;
> > 
> > -       if (mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP] != asid)
> > -               goto end;
> > +       if (asid >= MLX5_VDPA_NUM_AS)
> > +               return -EINVAL;
> > 
> >         new_mr = mlx5_vdpa_create_mr(mvdev, iotlb);
> >         if (IS_ERR(new_mr)) {
> > @@ -2879,7 +2886,7 @@ static int set_map_data(struct mlx5_vdpa_dev *mvdev,
> > struct vhost_iotlb *iotlb,
> >                 return err;
> >         }
> > 
> > -       if (!mvdev->mr) {
> > +       if (!mvdev->mr[asid]) {
> >                 mlx5_vdpa_update_mr(mvdev, new_mr, asid);
> >         } else {
> >                 err = mlx5_vdpa_change_map(mvdev, new_mr, asid);
> > @@ -2889,7 +2896,6 @@ static int set_map_data(struct mlx5_vdpa_dev *mvdev,
> > struct vhost_iotlb *iotlb,
> >                 }
> >         }
> > 
> > -end:
> >         return mlx5_vdpa_update_cvq_iotlb(mvdev, iotlb, asid);
> > 
> >  out_err:
> > --
> > 2.41.0
> > 
>
  

Patch

diff --git a/drivers/vdpa/mlx5/core/mlx5_vdpa.h b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
index bbe4335106bd..ae09296f4270 100644
--- a/drivers/vdpa/mlx5/core/mlx5_vdpa.h
+++ b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
@@ -70,11 +70,12 @@  struct mlx5_vdpa_wq_ent {
 enum {
 	MLX5_VDPA_DATAVQ_GROUP,
 	MLX5_VDPA_CVQ_GROUP,
+	MLX5_VDPA_DATAVQ_DESC_GROUP,
 	MLX5_VDPA_NUMVQ_GROUPS
 };
 
 enum {
-	MLX5_VDPA_NUM_AS = MLX5_VDPA_NUMVQ_GROUPS
+	MLX5_VDPA_NUM_AS = 2
 };
 
 struct mlx5_vdpa_dev {
@@ -89,7 +90,7 @@  struct mlx5_vdpa_dev {
 	u16 max_idx;
 	u32 generation;
 
-	struct mlx5_vdpa_mr *mr;
+	struct mlx5_vdpa_mr *mr[MLX5_VDPA_NUM_AS];
 	/* serialize mr access */
 	struct mutex mr_mtx;
 	struct mlx5_control_vq cvq;
diff --git a/drivers/vdpa/mlx5/core/mr.c b/drivers/vdpa/mlx5/core/mr.c
index 00eff5a07152..3dee6d9bed6b 100644
--- a/drivers/vdpa/mlx5/core/mr.c
+++ b/drivers/vdpa/mlx5/core/mr.c
@@ -511,8 +511,10 @@  void mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev,
 
 	_mlx5_vdpa_destroy_mr(mvdev, mr);
 
-	if (mvdev->mr == mr)
-		mvdev->mr = NULL;
+	for (int i = 0; i < MLX5_VDPA_NUM_AS; i++) {
+		if (mvdev->mr[i] == mr)
+			mvdev->mr[i] = NULL;
+	}
 
 	mutex_unlock(&mvdev->mr_mtx);
 
@@ -523,11 +525,11 @@  void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
 			 struct mlx5_vdpa_mr *new_mr,
 			 unsigned int asid)
 {
-	struct mlx5_vdpa_mr *old_mr = mvdev->mr;
+	struct mlx5_vdpa_mr *old_mr = mvdev->mr[asid];
 
 	mutex_lock(&mvdev->mr_mtx);
 
-	mvdev->mr = new_mr;
+	mvdev->mr[asid] = new_mr;
 	if (old_mr) {
 		_mlx5_vdpa_destroy_mr(mvdev, old_mr);
 		kfree(old_mr);
@@ -539,7 +541,9 @@  void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
 
 void mlx5_vdpa_destroy_mr_resources(struct mlx5_vdpa_dev *mvdev)
 {
-	mlx5_vdpa_destroy_mr(mvdev, mvdev->mr);
+	for (int i = 0; i < MLX5_VDPA_NUM_AS; i++)
+		mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[i]);
+
 	prune_iotlb(mvdev);
 }
 
diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c b/drivers/vdpa/mlx5/net/mlx5_vnet.c
index 4a87f9119fca..25bd2c324f5b 100644
--- a/drivers/vdpa/mlx5/net/mlx5_vnet.c
+++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c
@@ -821,6 +821,8 @@  static int create_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtque
 {
 	int inlen = MLX5_ST_SZ_BYTES(create_virtio_net_q_in);
 	u32 out[MLX5_ST_SZ_DW(create_virtio_net_q_out)] = {};
+	struct mlx5_vdpa_dev *mvdev = &ndev->mvdev;
+	struct mlx5_vdpa_mr *vq_mr;
 	void *obj_context;
 	u16 mlx_features;
 	void *cmd_hdr;
@@ -873,7 +875,9 @@  static int create_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtque
 	MLX5_SET64(virtio_q, vq_ctx, desc_addr, mvq->desc_addr);
 	MLX5_SET64(virtio_q, vq_ctx, used_addr, mvq->device_addr);
 	MLX5_SET64(virtio_q, vq_ctx, available_addr, mvq->driver_addr);
-	MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, ndev->mvdev.mr->mkey);
+	vq_mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP]];
+	if (vq_mr)
+		MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, vq_mr->mkey);
 	MLX5_SET(virtio_q, vq_ctx, umem_1_id, mvq->umem1.id);
 	MLX5_SET(virtio_q, vq_ctx, umem_1_size, mvq->umem1.size);
 	MLX5_SET(virtio_q, vq_ctx, umem_2_id, mvq->umem2.id);
@@ -2633,7 +2637,8 @@  static void restore_channels_info(struct mlx5_vdpa_net *ndev)
 }
 
 static int mlx5_vdpa_change_map(struct mlx5_vdpa_dev *mvdev,
-				struct mlx5_vdpa_mr *new_mr, unsigned int asid)
+				struct mlx5_vdpa_mr *new_mr,
+				unsigned int asid)
 {
 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
 	int err;
@@ -2652,8 +2657,10 @@  static int mlx5_vdpa_change_map(struct mlx5_vdpa_dev *mvdev,
 
 	restore_channels_info(ndev);
 	err = setup_driver(mvdev);
+	if (err)
+		return err;
 
-	return err;
+	return 0;
 }
 
 /* reslock must be held for this function */
@@ -2869,8 +2876,8 @@  static int set_map_data(struct mlx5_vdpa_dev *mvdev, struct vhost_iotlb *iotlb,
 	struct mlx5_vdpa_mr *new_mr;
 	int err;
 
-	if (mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP] != asid)
-		goto end;
+	if (asid >= MLX5_VDPA_NUM_AS)
+		return -EINVAL;
 
 	new_mr = mlx5_vdpa_create_mr(mvdev, iotlb);
 	if (IS_ERR(new_mr)) {
@@ -2879,7 +2886,7 @@  static int set_map_data(struct mlx5_vdpa_dev *mvdev, struct vhost_iotlb *iotlb,
 		return err;
 	}
 
-	if (!mvdev->mr) {
+	if (!mvdev->mr[asid]) {
 		mlx5_vdpa_update_mr(mvdev, new_mr, asid);
 	} else {
 		err = mlx5_vdpa_change_map(mvdev, new_mr, asid);
@@ -2889,7 +2896,6 @@  static int set_map_data(struct mlx5_vdpa_dev *mvdev, struct vhost_iotlb *iotlb,
 		}
 	}
 
-end:
 	return mlx5_vdpa_update_cvq_iotlb(mvdev, iotlb, asid);
 
 out_err: