[v2,2/3] HID: bpf: actually free hdev memory after attaching a HID-BPF program
Commit Message
Turns out that I got my reference counts wrong and each successful
bus_find_device() actually calls get_device(), and we need to manually
call put_device().
Ensure each bus_find_device() gets a matching put_device() when releasing
the bpf programs and fix all the error paths.
Cc: stable@vger.kernel.org
Fixes: f5c27da4e3c8 ("HID: initial BPF implementation")
Signed-off-by: Benjamin Tissoires <bentiss@kernel.org>
---
new in v2
---
drivers/hid/bpf/hid_bpf_dispatch.c | 29 +++++++++++++++++++++++------
drivers/hid/bpf/hid_bpf_jmp_table.c | 19 ++++++++++++++++---
2 files changed, 39 insertions(+), 9 deletions(-)
Comments
On Wed, Jan 24, 2024 at 12:27 PM Benjamin Tissoires <bentiss@kernelorg> wrote:
>
> Turns out that I got my reference counts wrong and each successful
> bus_find_device() actually calls get_device(), and we need to manually
> call put_device().
>
> Ensure each bus_find_device() gets a matching put_device() when releasing
> the bpf programs and fix all the error paths.
>
> Cc: stable@vger.kernel.org
> Fixes: f5c27da4e3c8 ("HID: initial BPF implementation")
> Signed-off-by: Benjamin Tissoires <bentiss@kernel.org>
>
> ---
>
> new in v2
> ---
> drivers/hid/bpf/hid_bpf_dispatch.c | 29 +++++++++++++++++++++++------
> drivers/hid/bpf/hid_bpf_jmp_table.c | 19 ++++++++++++++++---
> 2 files changed, 39 insertions(+), 9 deletions(-)
>
> diff --git a/drivers/hid/bpf/hid_bpf_dispatch.c b/drivers/hid/bpf/hid_bpf_dispatch.c
> index 5111d1fef0d3..7903c8638e81 100644
> --- a/drivers/hid/bpf/hid_bpf_dispatch.c
> +++ b/drivers/hid/bpf/hid_bpf_dispatch.c
> @@ -292,7 +292,7 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
> struct hid_device *hdev;
> struct bpf_prog *prog;
> struct device *dev;
> - int fd;
> + int err, fd;
>
> if (!hid_bpf_ops)
> return -EINVAL;
> @@ -311,14 +311,24 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
> * on errors or when it'll be detached
> */
> prog = bpf_prog_get(prog_fd);
> - if (IS_ERR(prog))
> - return PTR_ERR(prog);
> + if (IS_ERR(prog)) {
> + err = PTR_ERR(prog);
> + goto out_dev_put;
> + }
>
> fd = do_hid_bpf_attach_prog(hdev, prog_fd, prog, flags);
> - if (fd < 0)
> - bpf_prog_put(prog);
> + if (fd < 0) {
> + err = fd;
> + goto out_prog_put;
> + }
>
> return fd;
> +
> + out_prog_put:
> + bpf_prog_put(prog);
> + out_dev_put:
> + put_device(dev);
> + return err;
> }
>
> /**
> @@ -345,8 +355,10 @@ hid_bpf_allocate_context(unsigned int hid_id)
> hdev = to_hid_device(dev);
>
> ctx_kern = kzalloc(sizeof(*ctx_kern), GFP_KERNEL);
> - if (!ctx_kern)
> + if (!ctx_kern) {
> + put_device(dev);
> return NULL;
> + }
>
> ctx_kern->ctx.hid = hdev;
>
> @@ -363,10 +375,15 @@ noinline void
> hid_bpf_release_context(struct hid_bpf_ctx *ctx)
> {
> struct hid_bpf_ctx_kern *ctx_kern;
> + struct hid_device *hid;
>
> ctx_kern = container_of(ctx, struct hid_bpf_ctx_kern, ctx);
> + hid = (struct hid_device *)ctx_kern->ctx.hid; /* ignore const */
>
> kfree(ctx_kern);
> +
> + /* get_device() is called by bus_find_device() */
> + put_device(&hid->dev);
> }
>
> /**
> diff --git a/drivers/hid/bpf/hid_bpf_jmp_table.c b/drivers/hid/bpf/hid_bpf_jmp_table.c
> index 12f7cebddd73..85a24bc0ea25 100644
> --- a/drivers/hid/bpf/hid_bpf_jmp_table.c
> +++ b/drivers/hid/bpf/hid_bpf_jmp_table.c
> @@ -196,6 +196,7 @@ static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
> static void hid_bpf_release_progs(struct work_struct *work)
> {
> int i, j, n, map_fd = -1;
> + bool hdev_destroyed;
>
> if (!jmp_table.map)
> return;
> @@ -220,6 +221,12 @@ static void hid_bpf_release_progs(struct work_struct *work)
> if (entry->hdev) {
> hdev = entry->hdev;
> type = entry->type;
> + /*
> + * hdev is still valid, even if we are called after hid_destroy_device():
> + * when hid_bpf_attach() gets called, it takes a ref on the dev through
> + * bus_find_device()
> + */
> + hdev_destroyed = hdev->bpf.destroyed;
>
> hid_bpf_populate_hdev(hdev, type);
>
> @@ -232,12 +239,18 @@ static void hid_bpf_release_progs(struct work_struct *work)
> if (test_bit(next->idx, jmp_table.enabled))
> continue;
>
> - if (next->hdev == hdev && next->type == type)
> + if (next->hdev == hdev && next->type == type) {
> + /*
> + * clear the hdev reference and decrement the device ref
> + * that was taken during bus_find_device() while calling
> + * hid_bpf_attach()
> + */
> next->hdev = NULL;
> + put_device(&hdev->dev);
sigh... I can't make a correct patch these days... Missing a '}' here
to match the open bracket added above :(
I had some debug information put there to check if the device was
actually freed, and the closing bracket got lost while cleaning this
up.
Cheers,
Benjamin
> }
>
> - /* if type was rdesc fixup, reconnect device */
> - if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
> + /* if type was rdesc fixup and the device is not gone, reconnect device */
> + if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP && !hdev_destroyed)
> hid_bpf_reconnect(hdev);
> }
> }
>
> --
> 2.43.0
>
@@ -292,7 +292,7 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
struct hid_device *hdev;
struct bpf_prog *prog;
struct device *dev;
- int fd;
+ int err, fd;
if (!hid_bpf_ops)
return -EINVAL;
@@ -311,14 +311,24 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
* on errors or when it'll be detached
*/
prog = bpf_prog_get(prog_fd);
- if (IS_ERR(prog))
- return PTR_ERR(prog);
+ if (IS_ERR(prog)) {
+ err = PTR_ERR(prog);
+ goto out_dev_put;
+ }
fd = do_hid_bpf_attach_prog(hdev, prog_fd, prog, flags);
- if (fd < 0)
- bpf_prog_put(prog);
+ if (fd < 0) {
+ err = fd;
+ goto out_prog_put;
+ }
return fd;
+
+ out_prog_put:
+ bpf_prog_put(prog);
+ out_dev_put:
+ put_device(dev);
+ return err;
}
/**
@@ -345,8 +355,10 @@ hid_bpf_allocate_context(unsigned int hid_id)
hdev = to_hid_device(dev);
ctx_kern = kzalloc(sizeof(*ctx_kern), GFP_KERNEL);
- if (!ctx_kern)
+ if (!ctx_kern) {
+ put_device(dev);
return NULL;
+ }
ctx_kern->ctx.hid = hdev;
@@ -363,10 +375,15 @@ noinline void
hid_bpf_release_context(struct hid_bpf_ctx *ctx)
{
struct hid_bpf_ctx_kern *ctx_kern;
+ struct hid_device *hid;
ctx_kern = container_of(ctx, struct hid_bpf_ctx_kern, ctx);
+ hid = (struct hid_device *)ctx_kern->ctx.hid; /* ignore const */
kfree(ctx_kern);
+
+ /* get_device() is called by bus_find_device() */
+ put_device(&hid->dev);
}
/**
@@ -196,6 +196,7 @@ static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
static void hid_bpf_release_progs(struct work_struct *work)
{
int i, j, n, map_fd = -1;
+ bool hdev_destroyed;
if (!jmp_table.map)
return;
@@ -220,6 +221,12 @@ static void hid_bpf_release_progs(struct work_struct *work)
if (entry->hdev) {
hdev = entry->hdev;
type = entry->type;
+ /*
+ * hdev is still valid, even if we are called after hid_destroy_device():
+ * when hid_bpf_attach() gets called, it takes a ref on the dev through
+ * bus_find_device()
+ */
+ hdev_destroyed = hdev->bpf.destroyed;
hid_bpf_populate_hdev(hdev, type);
@@ -232,12 +239,18 @@ static void hid_bpf_release_progs(struct work_struct *work)
if (test_bit(next->idx, jmp_table.enabled))
continue;
- if (next->hdev == hdev && next->type == type)
+ if (next->hdev == hdev && next->type == type) {
+ /*
+ * clear the hdev reference and decrement the device ref
+ * that was taken during bus_find_device() while calling
+ * hid_bpf_attach()
+ */
next->hdev = NULL;
+ put_device(&hdev->dev);
}
- /* if type was rdesc fixup, reconnect device */
- if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
+ /* if type was rdesc fixup and the device is not gone, reconnect device */
+ if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP && !hdev_destroyed)
hid_bpf_reconnect(hdev);
}
}