[RFC,4/9] ntsync: Introduce NTSYNC_IOC_PUT_SEM.

Message ID 20240124004028.16826-5-zfigura@codeweavers.com
State New
Headers
Series NT synchronization primitive driver |

Commit Message

Elizabeth Figura Jan. 24, 2024, 12:40 a.m. UTC
  This corresponds to the NT syscall NtReleaseSemaphore().

Signed-off-by: Elizabeth Figura <zfigura@codeweavers.com>
---
 drivers/misc/ntsync.c       | 76 +++++++++++++++++++++++++++++++++++++
 include/uapi/linux/ntsync.h |  2 +
 2 files changed, 78 insertions(+)
  

Comments

Nikolay Borisov Jan. 25, 2024, 8:59 a.m. UTC | #1
On 24.01.24 г. 2:40 ч., Elizabeth Figura wrote:
> This corresponds to the NT syscall NtReleaseSemaphore().
> 
> Signed-off-by: Elizabeth Figura <zfigura@codeweavers.com>
> ---
>   drivers/misc/ntsync.c       | 76 +++++++++++++++++++++++++++++++++++++
>   include/uapi/linux/ntsync.h |  2 +
>   2 files changed, 78 insertions(+)
> 
> diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c
> index 3287b94be351..d1c91c2a4f1a 100644
> --- a/drivers/misc/ntsync.c
> +++ b/drivers/misc/ntsync.c
> @@ -21,9 +21,11 @@ enum ntsync_type {
>   struct ntsync_obj {
>   	struct rcu_head rhead;
>   	struct kref refcount;
> +	spinlock_t lock;
>   
>   	enum ntsync_type type;
>   
> +	/* The following fields are protected by the object lock. */
>   	union {
>   		struct {
>   			__u32 count;
> @@ -36,6 +38,19 @@ struct ntsync_device {
>   	struct xarray objects;
>   };
>   
> +static struct ntsync_obj *get_obj(struct ntsync_device *dev, __u32 id)
> +{
> +	struct ntsync_obj *obj;
> +
> +	rcu_read_lock();
> +	obj = xa_load(&dev->objects, id);
> +	if (obj && !kref_get_unless_zero(&obj->refcount))
> +		obj = NULL;
> +	rcu_read_unlock();
> +
> +	return obj;
> +}
> +
>   static void destroy_obj(struct kref *ref)
>   {
>   	struct ntsync_obj *obj = container_of(ref, struct ntsync_obj, refcount);
> @@ -48,6 +63,18 @@ static void put_obj(struct ntsync_obj *obj)
>   	kref_put(&obj->refcount, destroy_obj);
>   }
>   
> +static struct ntsync_obj *get_obj_typed(struct ntsync_device *dev, __u32 id,
> +					enum ntsync_type type)
> +{
> +	struct ntsync_obj *obj = get_obj(dev, id);
> +
> +	if (obj && obj->type != type) {
> +		put_obj(obj);
> +		return NULL;
> +	}
> +	return obj;
> +}
> +
>   static int ntsync_char_open(struct inode *inode, struct file *file)
>   {
>   	struct ntsync_device *dev;
> @@ -81,6 +108,7 @@ static int ntsync_char_release(struct inode *inode, struct file *file)
>   static void init_obj(struct ntsync_obj *obj)
>   {
>   	kref_init(&obj->refcount);
> +	spin_lock_init(&obj->lock);
>   }
>   
>   static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
> @@ -131,6 +159,52 @@ static int ntsync_delete(struct ntsync_device *dev, void __user *argp)
>   	return 0;
>   }
>   
> +/*
> + * Actually change the semaphore state, returning -EOVERFLOW if it is made
> + * invalid.
> + */
> +static int put_sem_state(struct ntsync_obj *sem, __u32 count)

nit: Just a general observation - functions that contain the specific 
type in their name could take the exact object type, i.e. struct ntsem, 
which would make the code somewhat clearer. Of course, this would mean 
that the struct definition in patch 3 would have to be changed to also 
contain a tag name.
  

Patch

diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c
index 3287b94be351..d1c91c2a4f1a 100644
--- a/drivers/misc/ntsync.c
+++ b/drivers/misc/ntsync.c
@@ -21,9 +21,11 @@  enum ntsync_type {
 struct ntsync_obj {
 	struct rcu_head rhead;
 	struct kref refcount;
+	spinlock_t lock;
 
 	enum ntsync_type type;
 
+	/* The following fields are protected by the object lock. */
 	union {
 		struct {
 			__u32 count;
@@ -36,6 +38,19 @@  struct ntsync_device {
 	struct xarray objects;
 };
 
+static struct ntsync_obj *get_obj(struct ntsync_device *dev, __u32 id)
+{
+	struct ntsync_obj *obj;
+
+	rcu_read_lock();
+	obj = xa_load(&dev->objects, id);
+	if (obj && !kref_get_unless_zero(&obj->refcount))
+		obj = NULL;
+	rcu_read_unlock();
+
+	return obj;
+}
+
 static void destroy_obj(struct kref *ref)
 {
 	struct ntsync_obj *obj = container_of(ref, struct ntsync_obj, refcount);
@@ -48,6 +63,18 @@  static void put_obj(struct ntsync_obj *obj)
 	kref_put(&obj->refcount, destroy_obj);
 }
 
+static struct ntsync_obj *get_obj_typed(struct ntsync_device *dev, __u32 id,
+					enum ntsync_type type)
+{
+	struct ntsync_obj *obj = get_obj(dev, id);
+
+	if (obj && obj->type != type) {
+		put_obj(obj);
+		return NULL;
+	}
+	return obj;
+}
+
 static int ntsync_char_open(struct inode *inode, struct file *file)
 {
 	struct ntsync_device *dev;
@@ -81,6 +108,7 @@  static int ntsync_char_release(struct inode *inode, struct file *file)
 static void init_obj(struct ntsync_obj *obj)
 {
 	kref_init(&obj->refcount);
+	spin_lock_init(&obj->lock);
 }
 
 static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
@@ -131,6 +159,52 @@  static int ntsync_delete(struct ntsync_device *dev, void __user *argp)
 	return 0;
 }
 
+/*
+ * Actually change the semaphore state, returning -EOVERFLOW if it is made
+ * invalid.
+ */
+static int put_sem_state(struct ntsync_obj *sem, __u32 count)
+{
+	lockdep_assert_held(&sem->lock);
+
+	if (sem->u.sem.count + count < sem->u.sem.count ||
+	    sem->u.sem.count + count > sem->u.sem.max)
+		return -EOVERFLOW;
+
+	sem->u.sem.count += count;
+	return 0;
+}
+
+static int ntsync_put_sem(struct ntsync_device *dev, void __user *argp)
+{
+	struct ntsync_sem_args __user *user_args = argp;
+	struct ntsync_sem_args args;
+	struct ntsync_obj *sem;
+	__u32 prev_count;
+	int ret;
+
+	if (copy_from_user(&args, argp, sizeof(args)))
+		return -EFAULT;
+
+	sem = get_obj_typed(dev, args.sem, NTSYNC_TYPE_SEM);
+	if (!sem)
+		return -EINVAL;
+
+	spin_lock(&sem->lock);
+
+	prev_count = sem->u.sem.count;
+	ret = put_sem_state(sem, args.count);
+
+	spin_unlock(&sem->lock);
+
+	put_obj(sem);
+
+	if (!ret && put_user(prev_count, &user_args->count))
+		ret = -EFAULT;
+
+	return ret;
+}
+
 static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
 			      unsigned long parm)
 {
@@ -142,6 +216,8 @@  static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
 		return ntsync_create_sem(dev, argp);
 	case NTSYNC_IOC_DELETE:
 		return ntsync_delete(dev, argp);
+	case NTSYNC_IOC_PUT_SEM:
+		return ntsync_put_sem(dev, argp);
 	default:
 		return -ENOIOCTLCMD;
 	}
diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h
index d97afc138dcc..8c610d65f8ef 100644
--- a/include/uapi/linux/ntsync.h
+++ b/include/uapi/linux/ntsync.h
@@ -21,5 +21,7 @@  struct ntsync_sem_args {
 #define NTSYNC_IOC_CREATE_SEM		_IOWR(NTSYNC_IOC_BASE, 0, \
 					      struct ntsync_sem_args)
 #define NTSYNC_IOC_DELETE		_IOW (NTSYNC_IOC_BASE, 1, __u32)
+#define NTSYNC_IOC_PUT_SEM		_IOWR(NTSYNC_IOC_BASE, 2, \
+					      struct ntsync_sem_args)
 
 #endif