[v2,2/5] connector/cn_proc: Add filtering to fix some bugs

Message ID 20230315021850.2788946-3-anjali.k.kulkarni@oracle.com
State New
Series Process connector bug fixes & enhancements

Commit Message

Anjali Kulkarni March 15, 2023, 2:18 a.m. UTC
The current proc connector code has the following bug: if there is more
than one listener for proc connector messages and one of them
deregisters using PROC_CN_MCAST_IGNORE, it will still get all proc
connector messages as long as there is another listener. That's because
there is no client-based filtering capability, and we only check whether
proc_event_num_listeners is less than 1 to decide whether to skip
sending multicasts to listeners.

Another issue is that if one client calls PROC_CN_MCAST_LISTEN and
another calls PROC_CN_MCAST_IGNORE, then both will end up not getting
any messages. What we need is the ability to filter messages based on
whether the client sent PROC_CN_MCAST_IGNORE.

This patch checks whether a client has sent PROC_CN_MCAST_IGNORE before
sending any message to that particular client. The client's most recent
mcast op (LISTEN or IGNORE) is stored in the client socket's
sk_user_data. In addition, proc_event_num_listeners is incremented or
decremented only once per client.
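
The per-client accounting described above can be modelled in user space.
The following is a minimal sketch, not the kernel code: struct
client_state, client_set_op() and client_filtered() are hypothetical
names, num_listeners stands in for proc_event_num_listeners, and
treating a client with no recorded state as "not filtered" is an
assumption of this sketch.

```c
#include <assert.h>
#include <stddef.h>

/* Values mirror enum proc_cn_mcast_op in include/uapi/linux/cn_proc.h. */
enum mcast_op { MCAST_LISTEN = 1, MCAST_IGNORE = 2 };

/* Hypothetical stand-in for the per-socket state kept in sk_user_data. */
struct client_state {
	int has_state;         /* 0 until the client sends its first op */
	enum mcast_op last_op; /* most recent op sent by this client */
};

static int num_listeners; /* models proc_event_num_listeners */

/* Record the op; change the count only when this client's state flips. */
static void client_set_op(struct client_state *c, enum mcast_op op)
{
	int initial;
	enum mcast_op prev;

	assert(c != NULL);
	initial = !c->has_state;
	prev = c->last_op;
	c->has_state = 1;
	c->last_op = op;

	if (op == MCAST_LISTEN && (initial || prev != MCAST_LISTEN))
		num_listeners++;
	else if (op == MCAST_IGNORE && !initial && prev != MCAST_IGNORE)
		num_listeners--;
}

/* Returns 1 if a message should be skipped for this client, else 0.
 * Assumption: a client that never sent an op is not filtered.
 */
static int client_filtered(const struct client_state *c)
{
	return c->has_state && c->last_op == MCAST_IGNORE;
}
```

With two zero-initialised clients, a repeated LISTEN from the same
client leaves the count unchanged, and an IGNORE from a client that was
never listening does not decrement it.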

A new function pointer, netlink_release, is added to struct netlink_sock
to store the protocol's release function. It is called when the socket
is released, so that sk_user_data can be freed at that time. The
protocol supplies it via the new release member of netlink_kernel_cfg.
cn_release is the release function added for NETLINK_CONNECTOR.
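
The release hook can be pictured as an optional per-protocol callback
that runs when a socket goes away. This is only a user-space sketch
under stated assumptions: struct model_sock, model_cn_release() and
model_sock_close() are hypothetical stand-ins, and malloc/free take the
place of the kernel allocators.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical user-space stand-in; not the kernel's struct sock. */
struct model_sock {
	void *user_data;                      /* models sk_user_data */
	void (*release)(struct model_sock *); /* models netlink_release */
};

/* Models cn_release: free the per-socket state and clear the pointer. */
static void model_cn_release(struct model_sock *sk)
{
	free(sk->user_data);
	sk->user_data = NULL;
}

/* Models the socket teardown path: call the hook only if one was set. */
static void model_sock_close(struct model_sock *sk)
{
	assert(sk != NULL);
	if (sk->release)
		sk->release(sk);
}
```

A protocol that does not register a hook leaves release as NULL, and
teardown then skips the call.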

Signed-off-by: Anjali Kulkarni <anjali.k.kulkarni@oracle.com>
---
 drivers/connector/cn_proc.c   | 54 ++++++++++++++++++++++++++++-------
 drivers/connector/connector.c | 21 +++++++++++---
 drivers/w1/w1_netlink.c       |  6 ++--
 include/linux/connector.h     |  8 +++++-
 include/linux/netlink.h       |  1 +
 include/uapi/linux/cn_proc.h  | 43 ++++++++++++++++------------
 net/netlink/af_netlink.c      | 10 +++++--
 net/netlink/af_netlink.h      |  4 +++
 8 files changed, 110 insertions(+), 37 deletions(-)
  

Comments

Jakub Kicinski March 15, 2023, 5:16 a.m. UTC | #1
On Tue, 14 Mar 2023 19:18:47 -0700 Anjali Kulkarni wrote:
> diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
> index 003c7e6ec9be..ad8ec18152cd 100644
> --- a/net/netlink/af_netlink.c
> +++ b/net/netlink/af_netlink.c
> @@ -63,6 +63,7 @@
>  #include <linux/net_namespace.h>
>  #include <linux/nospec.h>
>  #include <linux/btf_ids.h>
> +#include <linux/connector.h>

Not needed any more.

>  	/* must not acquire netlink_table_lock in any way again before unbind
>  	 * and notifying genetlink is done as otherwise it might deadlock
>  	 */
> -	if (nlk->netlink_unbind) {
> +	if (nlk->netlink_unbind && nlk->groups) {

Why?

>  		int i;
> -
>  		for (i = 0; i < nlk->ngroups; i++)
>  			if (test_bit(i, nlk->groups))
>  				nlk->netlink_unbind(sock_net(sk), i + 1);

Please separate the netlink core changes from the connector
changes.

Please slow down with new versions, we have 300 patches in the queue,
replying to one version just to notice you posted a new one is
frustrating. Give reviewers 24h to reply.
  
Anjali Kulkarni March 15, 2023, 5:41 p.m. UTC | #2
> On Mar 14, 2023, at 10:16 PM, Jakub Kicinski <kuba@kernel.org> wrote:
> 
> On Tue, 14 Mar 2023 19:18:47 -0700 Anjali Kulkarni wrote:
>> diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
>> index 003c7e6ec9be..ad8ec18152cd 100644
>> --- a/net/netlink/af_netlink.c
>> +++ b/net/netlink/af_netlink.c
>> @@ -63,6 +63,7 @@
>> #include <linux/net_namespace.h>
>> #include <linux/nospec.h>
>> #include <linux/btf_ids.h>
>> +#include <linux/connector.h>
> 
> Not needed any more.
Will remove
> 
>> 	/* must not acquire netlink_table_lock in any way again before unbind
>> 	 * and notifying genetlink is done as otherwise it might deadlock
>> 	 */
>> -	if (nlk->netlink_unbind) {
>> +	if (nlk->netlink_unbind && nlk->groups) {
> 
> Why?
Just an additional check, will remove
> 
>> 		int i;
>> -
>> 		for (i = 0; i < nlk->ngroups; i++)
>> 			if (test_bit(i, nlk->groups))
>> 				nlk->netlink_unbind(sock_net(sk), i + 1);
> 
> Please separate the netlink core changes from the connector
> changes.
Will do in the next revision
> 
> Please slow down with new versions, we have 300 patches in the queue,
> replying to one version just to notice you posted a new one is
> frustrating. Give reviewers 24h to reply.
Ok, sure!

Anjali
  
kernel test robot March 16, 2023, 7:51 a.m. UTC | #3
Hi Anjali,

Thank you for the patch! Perhaps something to improve:

[auto build test WARNING on net-next/master]
[also build test WARNING on net/master vfs-idmapping/for-next linus/master v6.3-rc2]
[cannot apply to next-20230316]
[If your patch is applied to the wrong git tree, kindly drop us a note.
And when submitting patch, we suggest to use '--base' as documented in
https://git-scm.com/docs/git-format-patch#_base_tree_information]

url:    https://github.com/intel-lab-lkp/linux/commits/Anjali-Kulkarni/netlink-Reverse-the-patch-which-removed-filtering/20230315-102137
patch link:    https://lore.kernel.org/r/20230315021850.2788946-3-anjali.k.kulkarni%40oracle.com
patch subject: [PATCH v2 2/5] connector/cn_proc: Add filtering to fix some bugs
config: hexagon-randconfig-r045-20230313 (https://download.01.org/0day-ci/archive/20230316/202303161507.MJc3Twl0-lkp@intel.com/config)
compiler: clang version 17.0.0 (https://github.com/llvm/llvm-project 67409911353323ca5edf2049ef0df54132fa1ca7)
reproduce (this is a W=1 build):
        wget https://raw.githubusercontent.com/intel/lkp-tests/master/sbin/make.cross -O ~/bin/make.cross
        chmod +x ~/bin/make.cross
        # https://github.com/intel-lab-lkp/linux/commit/3d67dd0fedbcbd247d8a8f925547dc57c5ad83e4
        git remote add linux-review https://github.com/intel-lab-lkp/linux
        git fetch --no-tags linux-review Anjali-Kulkarni/netlink-Reverse-the-patch-which-removed-filtering/20230315-102137
        git checkout 3d67dd0fedbcbd247d8a8f925547dc57c5ad83e4
        # save the config file
        mkdir build_dir && cp config build_dir/.config
        COMPILER_INSTALL_PATH=$HOME/0day COMPILER=clang make.cross W=1 O=build_dir ARCH=hexagon olddefconfig
        COMPILER_INSTALL_PATH=$HOME/0day COMPILER=clang make.cross W=1 O=build_dir ARCH=hexagon SHELL=/bin/bash drivers/connector/

If you fix the issue, kindly add following tag where applicable
| Reported-by: kernel test robot <lkp@intel.com>
| Link: https://lore.kernel.org/oe-kbuild-all/202303161507.MJc3Twl0-lkp@intel.com/

All warnings (new ones prefixed by >>):

   In file included from drivers/connector/cn_proc.c:14:
   In file included from include/linux/connector.h:17:
   In file included from include/net/sock.h:38:
   In file included from include/linux/hardirq.h:11:
   In file included from ./arch/hexagon/include/generated/asm/hardirq.h:1:
   In file included from include/asm-generic/hardirq.h:17:
   In file included from include/linux/irq.h:20:
   In file included from include/linux/io.h:13:
   In file included from arch/hexagon/include/asm/io.h:334:
   include/asm-generic/io.h:547:31: warning: performing pointer arithmetic on a null pointer has undefined behavior [-Wnull-pointer-arithmetic]
           val = __raw_readb(PCI_IOBASE + addr);
                             ~~~~~~~~~~ ^
   include/asm-generic/io.h:560:61: warning: performing pointer arithmetic on a null pointer has undefined behavior [-Wnull-pointer-arithmetic]
           val = __le16_to_cpu((__le16 __force)__raw_readw(PCI_IOBASE + addr));
                                                           ~~~~~~~~~~ ^
   include/uapi/linux/byteorder/little_endian.h:37:51: note: expanded from macro '__le16_to_cpu'
   #define __le16_to_cpu(x) ((__force __u16)(__le16)(x))
                                                     ^
   In file included from drivers/connector/cn_proc.c:14:
   In file included from include/linux/connector.h:17:
   In file included from include/net/sock.h:38:
   In file included from include/linux/hardirq.h:11:
   In file included from ./arch/hexagon/include/generated/asm/hardirq.h:1:
   In file included from include/asm-generic/hardirq.h:17:
   In file included from include/linux/irq.h:20:
   In file included from include/linux/io.h:13:
   In file included from arch/hexagon/include/asm/io.h:334:
   include/asm-generic/io.h:573:61: warning: performing pointer arithmetic on a null pointer has undefined behavior [-Wnull-pointer-arithmetic]
           val = __le32_to_cpu((__le32 __force)__raw_readl(PCI_IOBASE + addr));
                                                           ~~~~~~~~~~ ^
   include/uapi/linux/byteorder/little_endian.h:35:51: note: expanded from macro '__le32_to_cpu'
   #define __le32_to_cpu(x) ((__force __u32)(__le32)(x))
                                                     ^
   In file included from drivers/connector/cn_proc.c:14:
   In file included from include/linux/connector.h:17:
   In file included from include/net/sock.h:38:
   In file included from include/linux/hardirq.h:11:
   In file included from ./arch/hexagon/include/generated/asm/hardirq.h:1:
   In file included from include/asm-generic/hardirq.h:17:
   In file included from include/linux/irq.h:20:
   In file included from include/linux/io.h:13:
   In file included from arch/hexagon/include/asm/io.h:334:
   include/asm-generic/io.h:584:33: warning: performing pointer arithmetic on a null pointer has undefined behavior [-Wnull-pointer-arithmetic]
           __raw_writeb(value, PCI_IOBASE + addr);
                               ~~~~~~~~~~ ^
   include/asm-generic/io.h:594:59: warning: performing pointer arithmetic on a null pointer has undefined behavior [-Wnull-pointer-arithmetic]
           __raw_writew((u16 __force)cpu_to_le16(value), PCI_IOBASE + addr);
                                                         ~~~~~~~~~~ ^
   include/asm-generic/io.h:604:59: warning: performing pointer arithmetic on a null pointer has undefined behavior [-Wnull-pointer-arithmetic]
           __raw_writel((u32 __force)cpu_to_le32(value), PCI_IOBASE + addr);
                                                         ~~~~~~~~~~ ^
>> drivers/connector/cn_proc.c:51:5: warning: no previous prototype for function 'cn_filter' [-Wmissing-prototypes]
   int cn_filter(struct sock *dsk, struct sk_buff *skb, void *data)
       ^
   drivers/connector/cn_proc.c:51:1: note: declare 'static' if the function is not intended to be used outside of this translation unit
   int cn_filter(struct sock *dsk, struct sk_buff *skb, void *data)
   ^
   static 
   7 warnings generated.


vim +/cn_filter +51 drivers/connector/cn_proc.c

    50	
  > 51	int cn_filter(struct sock *dsk, struct sk_buff *skb, void *data)
    52	{
    53		enum proc_cn_mcast_op mc_op;
    54	
    55		if (!dsk)
    56			return 0;
    57	
    58		mc_op = ((struct proc_input *)(dsk->sk_user_data))->mcast_op;
    59	
    60		if (mc_op == PROC_CN_MCAST_IGNORE)
    61			return 1;
    62	
    63		return 0;
    64	}
    65	EXPORT_SYMBOL_GPL(cn_filter);
    66
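
For a function that has to stay non-static (cn_filter is exported in
this patch), one common way to address -Wmissing-prototypes is to make
a prototype visible before the definition, typically through a shared
header. The sketch below shows the pattern in plain C; model_sock and
model_filter are hypothetical names, and where the real declaration
should live is not decided here.

```c
#include <stddef.h>

/* What a shared header might contain (hypothetical names). */
struct model_sock;
int model_filter(struct model_sock *dsk, void *data);

/* Illustrative definition of the type for this sketch. */
struct model_sock {
	int ignore; /* nonzero if the client asked to ignore events */
};

/* The definition now has a previous prototype, so the warning about a
 * missing prototype does not apply. Returns 1 to skip delivery.
 */
int model_filter(struct model_sock *dsk, void *data)
{
	(void)data; /* unused in this sketch */

	if (dsk == NULL)
		return 0;

	return dsk->ignore ? 1 : 0;
}
```

If the function were only used inside one file, marking it static, as
the compiler note suggests, would be the simpler option.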
  

Patch

diff --git a/drivers/connector/cn_proc.c b/drivers/connector/cn_proc.c
index ccac1c453080..ef3820b43b5c 100644
--- a/drivers/connector/cn_proc.c
+++ b/drivers/connector/cn_proc.c
@@ -48,6 +48,22 @@  static DEFINE_PER_CPU(struct local_event, local_event) = {
 	.lock = INIT_LOCAL_LOCK(lock),
 };
 
+int cn_filter(struct sock *dsk, struct sk_buff *skb, void *data)
+{
+	enum proc_cn_mcast_op mc_op;
+
+	if (!dsk)
+		return 0;
+
+	mc_op = ((struct proc_input *)(dsk->sk_user_data))->mcast_op;
+
+	if (mc_op == PROC_CN_MCAST_IGNORE)
+		return 1;
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(cn_filter);
+
 static inline void send_msg(struct cn_msg *msg)
 {
 	local_lock(&local_event.lock);
@@ -61,7 +77,8 @@  static inline void send_msg(struct cn_msg *msg)
 	 *
 	 * If cn_netlink_send() fails, the data is not sent.
 	 */
-	cn_netlink_send(msg, 0, CN_IDX_PROC, GFP_NOWAIT);
+	cn_netlink_send_mult(msg, msg->len, 0, CN_IDX_PROC, GFP_NOWAIT,
+			     cn_filter, NULL);
 
 	local_unlock(&local_event.lock);
 }
@@ -346,11 +363,9 @@  static void cn_proc_ack(int err, int rcvd_seq, int rcvd_ack)
 static void cn_proc_mcast_ctl(struct cn_msg *msg,
 			      struct netlink_skb_parms *nsp)
 {
-	enum proc_cn_mcast_op *mc_op = NULL;
-	int err = 0;
-
-	if (msg->len != sizeof(*mc_op))
-		return;
+	enum proc_cn_mcast_op mc_op = 0, prev_mc_op = 0;
+	int err = 0, initial = 0;
+	struct sock *sk = NULL;
 
 	/* 
 	 * Events are reported with respect to the initial pid
@@ -367,13 +382,32 @@  static void cn_proc_mcast_ctl(struct cn_msg *msg,
 		goto out;
 	}
 
-	mc_op = (enum proc_cn_mcast_op *)msg->data;
-	switch (*mc_op) {
+	if (msg->len == sizeof(mc_op))
+		mc_op = *((enum proc_cn_mcast_op *)msg->data);
+	else
+		return;
+
+	if (nsp->sk) {
+		sk = nsp->sk;
+		if (sk->sk_user_data == NULL) {
+			sk->sk_user_data = kzalloc(sizeof(struct proc_input),
+						   GFP_KERNEL);
+			initial = 1;
+		} else {
+			prev_mc_op =
+			((struct proc_input *)(sk->sk_user_data))->mcast_op;
+		}
+		((struct proc_input *)(sk->sk_user_data))->mcast_op = mc_op;
+	}
+
+	switch (mc_op) {
 	case PROC_CN_MCAST_LISTEN:
-		atomic_inc(&proc_event_num_listeners);
+		if (initial || (prev_mc_op != PROC_CN_MCAST_LISTEN))
+			atomic_inc(&proc_event_num_listeners);
 		break;
 	case PROC_CN_MCAST_IGNORE:
-		atomic_dec(&proc_event_num_listeners);
+		if (!initial && (prev_mc_op != PROC_CN_MCAST_IGNORE))
+			atomic_dec(&proc_event_num_listeners);
 		break;
 	default:
 		err = EINVAL;
diff --git a/drivers/connector/connector.c b/drivers/connector/connector.c
index 48ec7ce6ecac..d1179df2b0ba 100644
--- a/drivers/connector/connector.c
+++ b/drivers/connector/connector.c
@@ -59,7 +59,9 @@  static int cn_already_initialized;
  * both, or if both are zero then the group is looked up and sent there.
  */
 int cn_netlink_send_mult(struct cn_msg *msg, u16 len, u32 portid, u32 __group,
-	gfp_t gfp_mask)
+	gfp_t gfp_mask,
+	int (*filter)(struct sock *dsk, struct sk_buff *skb, void *data),
+	void *filter_data)
 {
 	struct cn_callback_entry *__cbq;
 	unsigned int size;
@@ -110,8 +112,9 @@  int cn_netlink_send_mult(struct cn_msg *msg, u16 len, u32 portid, u32 __group,
 	NETLINK_CB(skb).dst_group = group;
 
 	if (group)
-		return netlink_broadcast(dev->nls, skb, portid, group,
-					 gfp_mask);
+		return netlink_broadcast_filtered(dev->nls, skb, portid, group,
+						  gfp_mask, filter,
+						  (void *)filter_data);
 	return netlink_unicast(dev->nls, skb, portid,
 			!gfpflags_allow_blocking(gfp_mask));
 }
@@ -121,7 +124,8 @@  EXPORT_SYMBOL_GPL(cn_netlink_send_mult);
 int cn_netlink_send(struct cn_msg *msg, u32 portid, u32 __group,
 	gfp_t gfp_mask)
 {
-	return cn_netlink_send_mult(msg, msg->len, portid, __group, gfp_mask);
+	return cn_netlink_send_mult(msg, msg->len, portid, __group, gfp_mask,
+				    NULL, NULL);
 }
 EXPORT_SYMBOL_GPL(cn_netlink_send);
 
@@ -162,6 +166,14 @@  static int cn_call_callback(struct sk_buff *skb)
 	return err;
 }
 
+static void cn_release(struct sock *sk, unsigned long *groups)
+{
+	if (groups && test_bit(CN_IDX_PROC - 1, groups)) {
+		kfree(sk->sk_user_data);
+		sk->sk_user_data = NULL;
+	}
+}
+
 /*
  * Main netlink receiving function.
  *
@@ -249,6 +261,7 @@  static int cn_init(void)
 	struct netlink_kernel_cfg cfg = {
 		.groups	= CN_NETLINK_USERS + 0xf,
 		.input	= cn_rx_skb,
+		.release = cn_release,
 	};
 
 	dev->nls = netlink_kernel_create(&init_net, NETLINK_CONNECTOR, &cfg);
diff --git a/drivers/w1/w1_netlink.c b/drivers/w1/w1_netlink.c
index db110cc442b1..691978cddab7 100644
--- a/drivers/w1/w1_netlink.c
+++ b/drivers/w1/w1_netlink.c
@@ -65,7 +65,8 @@  static void w1_unref_block(struct w1_cb_block *block)
 		u16 len = w1_reply_len(block);
 		if (len) {
 			cn_netlink_send_mult(block->first_cn, len,
-				block->portid, 0, GFP_KERNEL);
+					     block->portid, 0,
+					     GFP_KERNEL, NULL, NULL);
 		}
 		kfree(block);
 	}
@@ -83,7 +84,8 @@  static void w1_reply_make_space(struct w1_cb_block *block, u16 space)
 {
 	u16 len = w1_reply_len(block);
 	if (len + space >= block->maxlen) {
-		cn_netlink_send_mult(block->first_cn, len, block->portid, 0, GFP_KERNEL);
+		cn_netlink_send_mult(block->first_cn, len, block->portid,
+				     0, GFP_KERNEL, NULL, NULL);
 		block->first_cn->len = 0;
 		block->cn = NULL;
 		block->msg = NULL;
diff --git a/include/linux/connector.h b/include/linux/connector.h
index 487350bb19c3..cec2d99ae902 100644
--- a/include/linux/connector.h
+++ b/include/linux/connector.h
@@ -90,13 +90,19 @@  void cn_del_callback(const struct cb_id *id);
  *		If @group is not zero, then message will be delivered
  *		to the specified group.
  * @gfp_mask:	GFP mask.
+ * @filter:     Filter function to be used at netlink layer.
+ * @filter_data:Filter data to be supplied to the filter function
  *
  * It can be safely called from softirq context, but may silently
  * fail under strong memory pressure.
  *
  * If there are no listeners for given group %-ESRCH can be returned.
  */
-int cn_netlink_send_mult(struct cn_msg *msg, u16 len, u32 portid, u32 group, gfp_t gfp_mask);
+int cn_netlink_send_mult(struct cn_msg *msg, u16 len, u32 portid,
+			 u32 group, gfp_t gfp_mask,
+			 int (*filter)(struct sock *dsk, struct sk_buff *skb,
+				       void *data),
+			 void *filter_data);
 
 /**
  * cn_netlink_send - Sends message to the specified groups.
diff --git a/include/linux/netlink.h b/include/linux/netlink.h
index 866bbc5a4c8d..05a316aa93b4 100644
--- a/include/linux/netlink.h
+++ b/include/linux/netlink.h
@@ -51,6 +51,7 @@  struct netlink_kernel_cfg {
 	int		(*bind)(struct net *net, int group);
 	void		(*unbind)(struct net *net, int group);
 	bool		(*compare)(struct net *net, struct sock *sk);
+	void		(*release) (struct sock *sk, unsigned long *groups);
 };
 
 struct sock *__netlink_kernel_create(struct net *net, int unit,
diff --git a/include/uapi/linux/cn_proc.h b/include/uapi/linux/cn_proc.h
index db210625cee8..6a06fb424313 100644
--- a/include/uapi/linux/cn_proc.h
+++ b/include/uapi/linux/cn_proc.h
@@ -30,6 +30,30 @@  enum proc_cn_mcast_op {
 	PROC_CN_MCAST_IGNORE = 2
 };
 
+enum proc_cn_event {
+	/* Use successive bits so the enums can be used to record
+	 * sets of events as well
+	 */
+	PROC_EVENT_NONE = 0x00000000,
+	PROC_EVENT_FORK = 0x00000001,
+	PROC_EVENT_EXEC = 0x00000002,
+	PROC_EVENT_UID  = 0x00000004,
+	PROC_EVENT_GID  = 0x00000040,
+	PROC_EVENT_SID  = 0x00000080,
+	PROC_EVENT_PTRACE = 0x00000100,
+	PROC_EVENT_COMM = 0x00000200,
+	/* "next" should be 0x00000400 */
+	/* "last" is the last process event: exit,
+	 * while "next to last" is coredumping event
+	 */
+	PROC_EVENT_COREDUMP = 0x40000000,
+	PROC_EVENT_EXIT = 0x80000000
+};
+
+struct proc_input {
+	enum proc_cn_mcast_op mcast_op;
+};
+
 /*
  * From the user's point of view, the process
  * ID is the thread group ID and thread ID is the internal
@@ -44,24 +68,7 @@  enum proc_cn_mcast_op {
  */
 
 struct proc_event {
-	enum what {
-		/* Use successive bits so the enums can be used to record
-		 * sets of events as well
-		 */
-		PROC_EVENT_NONE = 0x00000000,
-		PROC_EVENT_FORK = 0x00000001,
-		PROC_EVENT_EXEC = 0x00000002,
-		PROC_EVENT_UID  = 0x00000004,
-		PROC_EVENT_GID  = 0x00000040,
-		PROC_EVENT_SID  = 0x00000080,
-		PROC_EVENT_PTRACE = 0x00000100,
-		PROC_EVENT_COMM = 0x00000200,
-		/* "next" should be 0x00000400 */
-		/* "last" is the last process event: exit,
-		 * while "next to last" is coredumping event */
-		PROC_EVENT_COREDUMP = 0x40000000,
-		PROC_EVENT_EXIT = 0x80000000
-	} what;
+	enum proc_cn_event what;
 	__u32 cpu;
 	__u64 __attribute__((aligned(8))) timestamp_ns;
 		/* Number of nano seconds since system boot */
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 003c7e6ec9be..ad8ec18152cd 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -63,6 +63,7 @@ 
 #include <linux/net_namespace.h>
 #include <linux/nospec.h>
 #include <linux/btf_ids.h>
+#include <linux/connector.h>
 
 #include <net/net_namespace.h>
 #include <net/netns/generic.h>
@@ -677,6 +678,7 @@  static int netlink_create(struct net *net, struct socket *sock, int protocol,
 	struct netlink_sock *nlk;
 	int (*bind)(struct net *net, int group);
 	void (*unbind)(struct net *net, int group);
+	void (*release)(struct sock *sock, unsigned long *groups);
 	int err = 0;
 
 	sock->state = SS_UNCONNECTED;
@@ -704,6 +706,7 @@  static int netlink_create(struct net *net, struct socket *sock, int protocol,
 	cb_mutex = nl_table[protocol].cb_mutex;
 	bind = nl_table[protocol].bind;
 	unbind = nl_table[protocol].unbind;
+	release = nl_table[protocol].release;
 	netlink_unlock_table();
 
 	if (err < 0)
@@ -719,6 +722,7 @@  static int netlink_create(struct net *net, struct socket *sock, int protocol,
 	nlk->module = module;
 	nlk->netlink_bind = bind;
 	nlk->netlink_unbind = unbind;
+	nlk->netlink_release = release;
 out:
 	return err;
 
@@ -763,13 +767,14 @@  static int netlink_release(struct socket *sock)
 	 * OK. Socket is unlinked, any packets that arrive now
 	 * will be purged.
 	 */
+	if (nlk->netlink_release)
+		nlk->netlink_release(sk, nlk->groups);
 
 	/* must not acquire netlink_table_lock in any way again before unbind
 	 * and notifying genetlink is done as otherwise it might deadlock
 	 */
-	if (nlk->netlink_unbind) {
+	if (nlk->netlink_unbind && nlk->groups) {
 		int i;
-
 		for (i = 0; i < nlk->ngroups; i++)
 			if (test_bit(i, nlk->groups))
 				nlk->netlink_unbind(sock_net(sk), i + 1);
@@ -2117,6 +2122,7 @@  __netlink_kernel_create(struct net *net, int unit, struct module *module,
 		if (cfg) {
 			nl_table[unit].bind = cfg->bind;
 			nl_table[unit].unbind = cfg->unbind;
+			nl_table[unit].release = cfg->release;
 			nl_table[unit].flags = cfg->flags;
 			if (cfg->compare)
 				nl_table[unit].compare = cfg->compare;
diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h
index 5f454c8de6a4..054335a34804 100644
--- a/net/netlink/af_netlink.h
+++ b/net/netlink/af_netlink.h
@@ -42,6 +42,8 @@  struct netlink_sock {
 	void			(*netlink_rcv)(struct sk_buff *skb);
 	int			(*netlink_bind)(struct net *net, int group);
 	void			(*netlink_unbind)(struct net *net, int group);
+	void			(*netlink_release)(struct sock *sk,
+						   unsigned long *groups);
 	struct module		*module;
 
 	struct rhash_head	node;
@@ -65,6 +67,8 @@  struct netlink_table {
 	int			(*bind)(struct net *net, int group);
 	void			(*unbind)(struct net *net, int group);
 	bool			(*compare)(struct net *net, struct sock *sock);
+	void			(*release)(struct sock *sk,
+					   unsigned long *groups);
 	int			registered;
 };