@@ -30,15 +30,15 @@ struct ksmbd_session_rpc {
static void free_channel_list(struct ksmbd_session *sess)
{
- struct channel *chann, *tmp;
+ struct channel *chann;
+ unsigned long index;
- write_lock(&sess->chann_lock);
- list_for_each_entry_safe(chann, tmp, &sess->ksmbd_chann_list,
- chann_list) {
- list_del(&chann->chann_list);
+ xa_for_each(&sess->ksmbd_chann_list, index, chann) {
+ xa_erase(&sess->ksmbd_chann_list, index);
kfree(chann);
}
- write_unlock(&sess->chann_lock);
+
+ xa_destroy(&sess->ksmbd_chann_list);
}
static void __session_rpc_close(struct ksmbd_session *sess,
@@ -190,21 +190,15 @@ int ksmbd_session_register(struct ksmbd_conn *conn,
static int ksmbd_chann_del(struct ksmbd_conn *conn, struct ksmbd_session *sess)
{
- struct channel *chann, *tmp;
-
- write_lock(&sess->chann_lock);
- list_for_each_entry_safe(chann, tmp, &sess->ksmbd_chann_list,
- chann_list) {
- if (chann->conn == conn) {
- list_del(&chann->chann_list);
- kfree(chann);
- write_unlock(&sess->chann_lock);
- return 0;
- }
- }
- write_unlock(&sess->chann_lock);
+ struct channel *chann;
+
+ chann = xa_erase(&sess->ksmbd_chann_list, (long)conn);
+ if (!chann)
+ return -ENOENT;
- return -ENOENT;
+ kfree(chann);
+
+ return 0;
}
void ksmbd_sessions_deregister(struct ksmbd_conn *conn)
@@ -234,7 +228,7 @@ void ksmbd_sessions_deregister(struct ksmbd_conn *conn)
return;
sess_destroy:
- if (list_empty(&sess->ksmbd_chann_list)) {
+ if (xa_empty(&sess->ksmbd_chann_list)) {
xa_erase(&conn->sessions, sess->id);
ksmbd_session_destroy(sess);
}
@@ -320,6 +314,9 @@ static struct ksmbd_session *__session_create(int protocol)
struct ksmbd_session *sess;
int ret;
+ if (protocol != CIFDS_SESSION_FLAG_SMB2)
+ return NULL;
+
sess = kzalloc(sizeof(struct ksmbd_session), GFP_KERNEL);
if (!sess)
return NULL;
@@ -329,30 +326,20 @@ static struct ksmbd_session *__session_create(int protocol)
set_session_flag(sess, protocol);
xa_init(&sess->tree_conns);
- INIT_LIST_HEAD(&sess->ksmbd_chann_list);
+ xa_init(&sess->ksmbd_chann_list);
INIT_LIST_HEAD(&sess->rpc_handle_list);
sess->sequence_number = 1;
- rwlock_init(&sess->chann_lock);
-
- switch (protocol) {
- case CIFDS_SESSION_FLAG_SMB2:
- ret = __init_smb2_session(sess);
- break;
- default:
- ret = -EINVAL;
- break;
- }
+ ret = __init_smb2_session(sess);
if (ret)
goto error;
ida_init(&sess->tree_conn_ida);
- if (protocol == CIFDS_SESSION_FLAG_SMB2) {
- down_write(&sessions_table_lock);
- hash_add(sessions_table, &sess->hlist, sess->id);
- up_write(&sessions_table_lock);
- }
+ down_write(&sessions_table_lock);
+ hash_add(sessions_table, &sess->hlist, sess->id);
+ up_write(&sessions_table_lock);
+
return sess;
error:
@@ -21,7 +21,6 @@ struct ksmbd_file_table;
struct channel {
__u8 smb3signingkey[SMB3_SIGN_KEY_SIZE];
struct ksmbd_conn *conn;
- struct list_head chann_list;
};
struct preauth_session {
@@ -50,8 +49,7 @@ struct ksmbd_session {
char sess_key[CIFS_KEY_SIZE];
struct hlist_node hlist;
- rwlock_t chann_lock;
- struct list_head ksmbd_chann_list;
+ struct xarray ksmbd_chann_list;
struct xarray tree_conns;
struct ida tree_conn_ida;
struct list_head rpc_handle_list;
@@ -74,14 +74,7 @@ static inline bool check_session_id(struct ksmbd_conn *conn, u64 id)
struct channel *lookup_chann_list(struct ksmbd_session *sess, struct ksmbd_conn *conn)
{
- struct channel *chann;
-
- list_for_each_entry(chann, &sess->ksmbd_chann_list, chann_list) {
- if (chann->conn == conn)
- return chann;
- }
-
- return NULL;
+ return xa_load(&sess->ksmbd_chann_list, (long)conn);
}
/**
@@ -595,6 +588,7 @@ static void destroy_previous_session(struct ksmbd_conn *conn,
struct ksmbd_session *prev_sess = ksmbd_session_lookup_slowpath(id);
struct ksmbd_user *prev_user;
struct channel *chann;
+ long index;
if (!prev_sess)
return;
@@ -608,10 +602,8 @@ static void destroy_previous_session(struct ksmbd_conn *conn,
return;
prev_sess->state = SMB2_SESSION_EXPIRED;
- write_lock(&prev_sess->chann_lock);
- list_for_each_entry(chann, &prev_sess->ksmbd_chann_list, chann_list)
+ xa_for_each(&prev_sess->ksmbd_chann_list, index, chann)
chann->conn->status = KSMBD_SESS_EXITING;
- write_unlock(&prev_sess->chann_lock);
}
/**
@@ -1519,19 +1511,14 @@ static int ntlm_authenticate(struct ksmbd_work *work)
binding_session:
if (conn->dialect >= SMB30_PROT_ID) {
- read_lock(&sess->chann_lock);
chann = lookup_chann_list(sess, conn);
- read_unlock(&sess->chann_lock);
if (!chann) {
chann = kmalloc(sizeof(struct channel), GFP_KERNEL);
if (!chann)
return -ENOMEM;
chann->conn = conn;
- INIT_LIST_HEAD(&chann->chann_list);
- write_lock(&sess->chann_lock);
- list_add(&chann->chann_list, &sess->ksmbd_chann_list);
- write_unlock(&sess->chann_lock);
+ xa_store(&sess->ksmbd_chann_list, (long)conn, chann, GFP_KERNEL);
}
}
@@ -1606,19 +1593,14 @@ static int krb5_authenticate(struct ksmbd_work *work)
}
if (conn->dialect >= SMB30_PROT_ID) {
- read_lock(&sess->chann_lock);
chann = lookup_chann_list(sess, conn);
- read_unlock(&sess->chann_lock);
if (!chann) {
chann = kmalloc(sizeof(struct channel), GFP_KERNEL);
if (!chann)
return -ENOMEM;
chann->conn = conn;
- INIT_LIST_HEAD(&chann->chann_list);
- write_lock(&sess->chann_lock);
- list_add(&chann->chann_list, &sess->ksmbd_chann_list);
- write_unlock(&sess->chann_lock);
+ xa_store(&sess->ksmbd_chann_list, (long)conn, chann, GFP_KERNEL);
}
}
@@ -8409,14 +8391,11 @@ int smb3_check_sign_req(struct ksmbd_work *work)
if (le16_to_cpu(hdr->Command) == SMB2_SESSION_SETUP_HE) {
signing_key = work->sess->smb3signingkey;
} else {
- read_lock(&work->sess->chann_lock);
chann = lookup_chann_list(work->sess, conn);
if (!chann) {
- read_unlock(&work->sess->chann_lock);
return 0;
}
signing_key = chann->smb3signingkey;
- read_unlock(&work->sess->chann_lock);
}
if (!signing_key) {
@@ -8476,14 +8455,11 @@ void smb3_set_sign_rsp(struct ksmbd_work *work)
le16_to_cpu(hdr->Command) == SMB2_SESSION_SETUP_HE) {
signing_key = work->sess->smb3signingkey;
} else {
- read_lock(&work->sess->chann_lock);
chann = lookup_chann_list(work->sess, work->conn);
if (!chann) {
- read_unlock(&work->sess->chann_lock);
return;
}
signing_key = chann->smb3signingkey;
- read_unlock(&work->sess->chann_lock);
}
if (!signing_key)