[RFC,9/9] ntsync: Introduce NTSYNC_IOC_KILL_OWNER.

Message ID 20240124004028.16826-10-zfigura@codeweavers.com
State New
Headers
Series NT synchronization primitive driver |

Commit Message

Elizabeth Figura Jan. 24, 2024, 12:40 a.m. UTC
  This does not correspond to any NT syscall, but rather should be called by the
user-space NT emulator when a thread dies. It is responsible for marking any
mutexes owned by that thread as abandoned.

Signed-off-by: Elizabeth Figura <zfigura@codeweavers.com>
---
 drivers/misc/ntsync.c       | 80 ++++++++++++++++++++++++++++++++++++-
 include/uapi/linux/ntsync.h |  1 +
 2 files changed, 79 insertions(+), 2 deletions(-)
  

Patch

diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c
index 28f43768d1c3..1173c750c106 100644
--- a/drivers/misc/ntsync.c
+++ b/drivers/misc/ntsync.c
@@ -64,6 +64,7 @@  struct ntsync_obj {
 		struct {
 			__u32 count;
 			__u32 owner;
+			bool ownerdead;
 		} mutex;
 	} u;
 };
@@ -87,6 +88,7 @@  struct ntsync_q {
 	atomic_t signaled;
 
 	bool all;
+	bool ownerdead;
 	__u32 count;
 	struct ntsync_q_entry entries[];
 };
@@ -240,6 +242,9 @@  static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q,
 				obj->u.sem.count--;
 				break;
 			case NTSYNC_TYPE_MUTEX:
+				if (obj->u.mutex.ownerdead)
+					q->ownerdead = true;
+				obj->u.mutex.ownerdead = false;
 				obj->u.mutex.count++;
 				obj->u.mutex.owner = q->owner;
 				break;
@@ -299,6 +304,9 @@  static void try_wake_any_mutex(struct ntsync_obj *mutex)
 			continue;
 
 		if (atomic_cmpxchg(&q->signaled, -1, entry->index) == -1) {
+			if (mutex->u.mutex.ownerdead)
+				q->ownerdead = true;
+			mutex->u.mutex.ownerdead = false;
 			mutex->u.mutex.count++;
 			mutex->u.mutex.owner = q->owner;
 			wake_up_process(q->task);
@@ -514,6 +522,71 @@  static int ntsync_put_mutex(struct ntsync_device *dev, void __user *argp)
 	return ret;
 }
 
+/*
+ * Actually change the mutex state to mark its owner as dead.
+ */
+static void put_mutex_ownerdead_state(struct ntsync_obj *mutex)
+{
+	lockdep_assert_held(&mutex->lock);
+
+	mutex->u.mutex.ownerdead = true;
+	mutex->u.mutex.owner = 0;
+	mutex->u.mutex.count = 0;
+}
+
+static int ntsync_kill_owner(struct ntsync_device *dev, void __user *argp)
+{
+	struct ntsync_obj *obj;
+	unsigned long id;
+	__u32 owner;
+
+	if (get_user(owner, (__u32 __user *)argp))
+		return -EFAULT;
+	if (!owner)
+		return -EINVAL;
+
+	rcu_read_lock();
+
+	xa_for_each(&dev->objects, id, obj) {
+		if (!kref_get_unless_zero(&obj->refcount))
+			continue;
+
+		if (obj->type != NTSYNC_TYPE_MUTEX) {
+			put_obj(obj);
+			continue;
+		}
+
+		if (atomic_read(&obj->all_hint) > 0) {
+			spin_lock(&dev->wait_all_lock);
+			spin_lock_nest_lock(&obj->lock, &dev->wait_all_lock);
+
+			if (obj->u.mutex.owner == owner) {
+				put_mutex_ownerdead_state(obj);
+				try_wake_all_obj(dev, obj);
+				try_wake_any_mutex(obj);
+			}
+
+			spin_unlock(&obj->lock);
+			spin_unlock(&dev->wait_all_lock);
+		} else {
+			spin_lock(&obj->lock);
+
+			if (obj->u.mutex.owner == owner) {
+				put_mutex_ownerdead_state(obj);
+				try_wake_any_mutex(obj);
+			}
+
+			spin_unlock(&obj->lock);
+		}
+
+		put_obj(obj);
+	}
+
+	rcu_read_unlock();
+
+	return 0;
+}
+
 static int ntsync_schedule(const struct ntsync_q *q, ktime_t *timeout)
 {
 	int ret = 0;
@@ -585,6 +658,7 @@  static int setup_wait(struct ntsync_device *dev,
 	q->owner = args->owner;
 	atomic_set(&q->signaled, -1);
 	q->all = all;
+	q->ownerdead = false;
 	q->count = count;
 
 	for (i = 0; i < count; i++) {
@@ -697,7 +771,7 @@  static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
 		struct ntsync_wait_args __user *user_args = argp;
 
 		/* even if we caught a signal, we need to communicate success */
-		ret = 0;
+		ret = q->ownerdead ? -EOWNERDEAD : 0;
 
 		if (put_user(signaled, &user_args->index))
 			ret = -EFAULT;
@@ -778,7 +852,7 @@  static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
 		struct ntsync_wait_args __user *user_args = argp;
 
 		/* even if we caught a signal, we need to communicate success */
-		ret = 0;
+		ret = q->ownerdead ? -EOWNERDEAD : 0;
 
 		if (put_user(signaled, &user_args->index))
 			ret = -EFAULT;
@@ -803,6 +877,8 @@  static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
 		return ntsync_create_sem(dev, argp);
 	case NTSYNC_IOC_DELETE:
 		return ntsync_delete(dev, argp);
+	case NTSYNC_IOC_KILL_OWNER:
+		return ntsync_kill_owner(dev, argp);
 	case NTSYNC_IOC_PUT_MUTEX:
 		return ntsync_put_mutex(dev, argp);
 	case NTSYNC_IOC_PUT_SEM:
diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h
index 2e44e7e77776..fec9a3993322 100644
--- a/include/uapi/linux/ntsync.h
+++ b/include/uapi/linux/ntsync.h
@@ -48,5 +48,6 @@  struct ntsync_wait_args {
 					      struct ntsync_mutex_args)
 #define NTSYNC_IOC_PUT_MUTEX		_IOWR(NTSYNC_IOC_BASE, 6, \
 					      struct ntsync_mutex_args)
+#define NTSYNC_IOC_KILL_OWNER		_IOW (NTSYNC_IOC_BASE, 7, __u32)
 
 #endif