
Lock-free stack implementation idea (currently broken)


I came up with an idea for a lock-free stack that does not rely on reference counting to resolve the ABA problem and that also handles memory reclamation properly, and I am trying to implement it. It is similar in concept to RCU and relies on two features: marking a list entry as removed, and tracking readers traversing the list. The former is simple: it just uses the LSB of the pointer. The latter is my "clever" attempt at implementing an unbounded lock-free stack.

Basically, when any thread attempts to traverse the list, one atomic counter (list.entries) is incremented. When the traversal is complete, a second counter (list.exits) is incremented.

Node allocation is handled by push, and deallocation is handled by pop.

The push and pop operations are fairly similar to the naive lock-free stack implementation, but the nodes marked for removal must be traversed to arrive at a non-marked entry. Push is therefore basically a linked-list insertion.

The pop operation similarly traverses the list, but it uses atomic_fetch_or to mark the nodes as removed while traversing, until it reaches a non-marked node.
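The removal mark described above is set with atomic_fetch_or_explicit on a list_node_t * _Atomic. The C11 standard only defines atomic_fetch_or for atomic integer types; GCC accepts the pointer form because its __atomic builtins treat pointer operands as uintptr_t, but other compilers (Clang, for example) may reject it. A minimal sketch of a portable variant, assuming the link is stored as _Atomic uintptr_t and nodes are at least 2-byte aligned (malloc guarantees this); node_t, MARK_REMOVED, link_ptr, link_is_removed and mark_removed are hypothetical names. It only changes how the mark bit is stored and does not by itself make the algorithm correct.

```c
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical node layout: the link is an integer so that
   atomic_fetch_or is well-defined in standard C11. */
typedef struct node {
    _Atomic uintptr_t next;  /* next-node address, removed flag in bit 0 */
    uint64_t data;
} node_t;

#define MARK_REMOVED ((uintptr_t)1)

/* Strip the flag bit to recover the node pointer. */
static inline node_t *link_ptr(uintptr_t link)
{
    return (node_t *)(link & ~MARK_REMOVED);
}

static inline bool link_is_removed(uintptr_t link)
{
    return (link & MARK_REMOVED) != 0;
}

/* Atomically set the removed bit and return the previous link value,
   mirroring the atomic_fetch_or_explicit calls in list_remove(). */
static inline uintptr_t mark_removed(_Atomic uintptr_t *link)
{
    return atomic_fetch_or_explicit(link, MARK_REMOVED, memory_order_seq_cst);
}
```

The previous value returned by mark_removed() tells the caller whether it was the first to mark that link, exactly as the pointer-typed fetch_or does in the original code.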

After traversing the list of 0 or more marked nodes, a thread that is popping will attempt to CAS the head of the stack. At least one thread concurrently popping will succeed, and after this point all readers entering the stack will no longer see the formerly marked nodes.

The thread that successfully updates the list then loads the atomic list.entries, and basically spin-loads atomic.exits until that counter finally exceeds list.entries. This should imply that all readers of the "old" version of the list have completed. The thread then simply frees the list of marked nodes that it swapped off the top of the list.

So the implication of the pop operation should be (I think) that there can be no ABA problem, because the nodes that are freed are not returned to the usable pool of pointers until all concurrent readers using them have completed; obviously the memory reclamation issue is handled as well, for the same reason.

So anyhow, that is the theory, but I'm still scratching my head over the implementation, because it currently does not work in the multithreaded case. I seem to be getting some write-after-free issues, among other things, but I'm having trouble spotting the problem, or maybe my assumptions are flawed and it just won't work.

Any insights would be greatly appreciated, both on the concept, and on approaches to debugging the code.

Here is my current (broken) code (compile with gcc -D_GNU_SOURCE -std=c11 -Wall -O0 -g -pthread -o list list.c):

#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

#include <sys/resource.h>

#include <stdio.h>
#include <unistd.h>

#define NUM_THREADS 8
#define NUM_OPS (1024 * 1024)

typedef uint64_t list_data_t;

typedef struct list_node_t {
    struct list_node_t * _Atomic next;
    list_data_t data;
} list_node_t;

typedef struct {
    list_node_t * _Atomic head;
    int64_t _Atomic size;
    uint64_t _Atomic entries;
    uint64_t _Atomic exits;
} list_t;

enum {
    NODE_IDLE    = (0x0),
    NODE_REMOVED = (0x1 << 0),
    NODE_FREED   = (0x1 << 1),
    NODE_FLAGS    = (0x3),
};

static __thread struct {
    uint64_t add_count;
    uint64_t remove_count;
    uint64_t added;
    uint64_t removed;
    uint64_t mallocd;
    uint64_t freed;
} stats;

#define NODE_IS_SET(p, f) (((uintptr_t)p & f) == f)
#define NODE_SET_FLAG(p, f) ((void *)((uintptr_t)p | f))
#define NODE_CLR_FLAG(p, f) ((void *)((uintptr_t)p & ~f))
#define NODE_POINTER(p) ((void *)((uintptr_t)p & ~NODE_FLAGS))

list_node_t * list_node_new(list_data_t data)
{
    list_node_t * new = malloc(sizeof(*new));
    new->data = data;
    stats.mallocd++;

    return new;
}

void list_node_free(list_node_t * node)
{
    free(node);
    stats.freed++;
}

static void list_add(list_t * list, list_data_t data)
{
    atomic_fetch_add_explicit(&list->entries, 1, memory_order_seq_cst);

    list_node_t * new = list_node_new(data);
    list_node_t * _Atomic * next = &list->head;
    list_node_t * current = atomic_load_explicit(next,  memory_order_seq_cst);
    do
    {
        stats.add_count++;
        while ((NODE_POINTER(current) != NULL) &&
                NODE_IS_SET(current, NODE_REMOVED))
        {
                stats.add_count++;
                current = NODE_POINTER(current);
                next = &current->next;
                current = atomic_load_explicit(next, memory_order_seq_cst);
        }
        atomic_store_explicit(&new->next, current, memory_order_seq_cst);
    }
    while(!atomic_compare_exchange_weak_explicit(
            next, &current, new,
            memory_order_seq_cst, memory_order_seq_cst));

    atomic_fetch_add_explicit(&list->exits, 1, memory_order_seq_cst);
    atomic_fetch_add_explicit(&list->size, 1, memory_order_seq_cst);
    stats.added++;
}

static bool list_remove(list_t * list, list_data_t * pData)
{
    uint64_t entries = atomic_fetch_add_explicit(
            &list->entries, 1, memory_order_seq_cst);

    list_node_t * start = atomic_fetch_or_explicit(
            &list->head, NODE_REMOVED, memory_order_seq_cst);
    list_node_t * current = start;

    stats.remove_count++;
    while ((NODE_POINTER(current) != NULL) &&
            NODE_IS_SET(current, NODE_REMOVED))
    {
        stats.remove_count++;
        current = NODE_POINTER(current);
        current = atomic_fetch_or_explicit(&current->next,
                NODE_REMOVED, memory_order_seq_cst);
    }

    uint64_t exits = atomic_fetch_add_explicit(
            &list->exits, 1, memory_order_seq_cst) + 1;

    bool result = false;
    current = NODE_POINTER(current);
    if (current != NULL)
    {
        result = true;
        *pData = current->data;

        current = atomic_load_explicit(
                &current->next, memory_order_seq_cst);

        atomic_fetch_add_explicit(&list->size,
                -1, memory_order_seq_cst);

        stats.removed++;
    }

    start = NODE_SET_FLAG(start, NODE_REMOVED);
    if (atomic_compare_exchange_strong_explicit(
            &list->head, &start, current,
            memory_order_seq_cst, memory_order_seq_cst))
    {
        entries = atomic_load_explicit(&list->entries, memory_order_seq_cst);
        while ((int64_t)(entries - exits) > 0)
        {
            pthread_yield();
            exits = atomic_load_explicit(&list->exits, memory_order_seq_cst);
        }

        list_node_t * end = NODE_POINTER(current);
        list_node_t * current = NODE_POINTER(start);
        while (current != end)
        {
            list_node_t * tmp = current;
            current = atomic_load_explicit(&current->next, memory_order_seq_cst);
            list_node_free(tmp);
            current = NODE_POINTER(current);
        }
    }

    return result;
}

static list_t list;

pthread_mutex_t ioLock = PTHREAD_MUTEX_INITIALIZER;

void * thread_entry(void * arg)
{
    sleep(2);
    int id = *(int *)arg;

    for (int i = 0; i < NUM_OPS; i++)
    {
        bool insert = random() % 2;

        if (insert)
        {
            list_add(&list, i);
        }
        else
        {
            list_data_t data;
            list_remove(&list, &data);
        }
    }

    struct rusage u;
    getrusage(RUSAGE_THREAD, &u);

    pthread_mutex_lock(&ioLock);
    printf("Thread %d stats:\n", id);
    printf("\tadded = %lu\n", stats.added);
    printf("\tremoved = %lu\n", stats.removed);
    printf("\ttotal added = %ld\n", (int64_t)(stats.added - stats.removed));
    printf("\tadded count = %lu\n", stats.add_count);
    printf("\tremoved count = %lu\n", stats.remove_count);
    printf("\tadd average = %f\n", (float)stats.add_count / stats.added);
    printf("\tremove average = %f\n", (float)stats.remove_count / stats.removed);
    printf("\tmallocd = %lu\n", stats.mallocd);
    printf("\tfreed = %lu\n", stats.freed);
    printf("\ttotal mallocd = %ld\n", (int64_t)(stats.mallocd - stats.freed));
    printf("\tutime = %f\n", u.ru_utime.tv_sec
            + u.ru_utime.tv_usec / 1000000.0f);
    printf("\tstime = %f\n", u.ru_stime.tv_sec
                    + u.ru_stime.tv_usec / 1000000.0f);
    pthread_mutex_unlock(&ioLock);

    return NULL;
}

int main(int argc, char ** argv)
{
    struct {
            pthread_t thread;
            int id;
    }
    threads[NUM_THREADS];
    for (int i = 0; i < NUM_THREADS; i++)
    {
        threads[i].id = i;
        pthread_create(&threads[i].thread, NULL, thread_entry, &threads[i].id);
    }

    for (int i = 0; i < NUM_THREADS; i++)
    {
        pthread_join(threads[i].thread, NULL);
    }

    printf("Size = %ld\n", atomic_load(&list.size));

    uint32_t count = 0;

    list_data_t data;
    while(list_remove(&list, &data))
    {
        count++;
    }
    printf("Removed %u\n", count);
}

Solution

  • You mention you are trying to solve the ABA problem, but the description and code are actually an attempt to solve a harder problem: the memory reclamation problem.

    This problem typically arises in the "deletion" functionality of lock-free collections implemented in languages without garbage collection. The core issue is that a thread removing a node from a shared structure often doesn't know when it is safe to free the removed node, because other readers may still hold a reference to it. Solving this problem often, as a side effect, also solves the ABA problem, which is specifically about a CAS operation succeeding even though the underlying pointer (and the state of the object) has been changed at least twice in the meantime, ending up with the original value but presenting a totally different state.
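    To make the ABA definition concrete, here is a deterministic, single-threaded replay of the classic interleaving on a naive Treiber-style stack. The aba_node_t type, aba_top and replay_aba are hypothetical names, and T2's pops and push are written as plain stores because only one thread actually runs. T1's CAS succeeds because the top pointer is A again, even though B has since been popped:

    ```c
    #include <stdatomic.h>
    #include <stddef.h>

    typedef struct aba_node {
        struct aba_node *next;
        int value;
    } aba_node_t;

    static aba_node_t *_Atomic aba_top;

    /* Replays the ABA interleaving on one thread and returns the node left
       on top after T1's "successful" CAS. */
    static aba_node_t *replay_aba(aba_node_t *a, aba_node_t *b, aba_node_t *c)
    {
        /* Initial stack: A -> B -> C */
        c->next = NULL;
        b->next = c;
        a->next = b;
        atomic_store(&aba_top, a);

        /* T1 starts a pop: reads top (A) and A->next (B), then is preempted. */
        aba_node_t *t1_expected = atomic_load(&aba_top);
        aba_node_t *t1_new_top = t1_expected->next;

        /* T2 pops A, pops B, then pushes A back (e.g. A's memory was freed
           and reused by a new push). The stack is now A -> C. */
        atomic_store(&aba_top, b);
        atomic_store(&aba_top, c);
        a->next = c;
        atomic_store(&aba_top, a);

        /* T1 resumes. The top is still A, so the CAS succeeds and installs
           the stale B as the new top. */
        atomic_compare_exchange_strong(&aba_top, &t1_expected, t1_new_top);
        return atomic_load(&aba_top);
    }
    ```

    After the replay, the stack top is B, a node that T2 already popped and, in a real program, might already have freed.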

    The ABA problem is easier in the sense that there are several straightforward solutions to the ABA problem specifically that don't lead to a solution to the "memory reclamation" problem. It is also easier in the sense that hardware that can detect the modification of the location, e.g., with LL/SC or transactional memory primitives, might not exhibit the problem at all.
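    One of those straightforward ABA fixes is a version tag: a counter bumped on every successful update and compared together with the top pointer in a single CAS. A minimal sketch, assuming nodes live in a fixed static pool addressed by 32-bit indices so that index and tag fit in one 64-bit word; pool, tagged_top, tagged_push, tagged_pop and the 32/32-bit split are all assumptions for illustration. It also shows why this does not solve memory reclamation: it is only safe because pool nodes are recycled and never freed, so a stale read of pool[idx].next still reads valid memory and the tag check discards it.

    ```c
    #include <stdatomic.h>
    #include <stdbool.h>
    #include <stdint.h>

    #define POOL_SIZE 16
    #define NIL UINT32_MAX  /* "no node" index: empty stack / end of list */

    typedef struct {
        _Atomic uint32_t next;  /* index of the next node, or NIL */
        uint64_t data;
    } pool_node_t;

    static pool_node_t pool[POOL_SIZE];  /* never freed, only recycled */
    static _Atomic uint64_t tagged_top;  /* high 32 bits: tag, low 32: index */

    static uint64_t pack(uint32_t tag, uint32_t idx) { return ((uint64_t)tag << 32) | idx; }
    static uint32_t idx_of(uint64_t w) { return (uint32_t)w; }
    static uint32_t tag_of(uint64_t w) { return (uint32_t)(w >> 32); }

    static void tagged_init(void) { atomic_store(&tagged_top, pack(0, NIL)); }

    static void tagged_push(uint32_t idx)
    {
        uint64_t old = atomic_load(&tagged_top);
        do {
            atomic_store(&pool[idx].next, idx_of(old));
        } while (!atomic_compare_exchange_weak(&tagged_top, &old,
                                               pack(tag_of(old) + 1, idx)));
    }

    /* Returns the popped index, or NIL if the stack is empty. */
    static uint32_t tagged_pop(void)
    {
        uint64_t old = atomic_load(&tagged_top);
        for (;;) {
            uint32_t idx = idx_of(old);
            if (idx == NIL)
                return NIL;
            /* May be stale if idx was popped and re-pushed meanwhile; the
               tag then differs and the CAS below fails and retries. */
            uint32_t next = atomic_load(&pool[idx].next);
            if (atomic_compare_exchange_weak(&tagged_top, &old,
                                             pack(tag_of(old) + 1, next)))
                return idx;
        }
    }
    ```

    A 32-bit tag can in principle wrap around and reintroduce ABA if a thread stalls for 2^32 updates between its load and its CAS; wider tags (or a double-width CAS on full pointers) push that further out.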

    That said, what you are hunting for is a solution to the memory reclamation problem, and that solution will also avoid the ABA problem.

    The core of your issue is this statement:

    The thread that successfully updates the list then loads the atomic list.entries, and basically spin-loads atomic.exits until that counter finally exceeds list.entries. This should imply that all readers of the "old" version of the list have completed. The thread then simply frees the list of marked nodes that it swapped off the top of the list.

    This logic doesn't hold. Waiting for list.exits (you say atomic.exits but I think it's a typo as you only talk about list.exits elsewhere) to be greater than list.entries only tells you there have now been more total exits than there were entries at the point the mutating thread captured the entry count. However, these exits may have been generated by new readers coming and going: it doesn't at all imply that all the old readers have finished as you claim!

    Here's a simple example. First, a writing thread T1 and a reading thread T2 access the list around the same time, so list.entries is 2 and list.exits is 0. The writing thread pops a node, saves the current value (2) of list.entries and waits for list.exits to be greater than 2. Now three more reading threads, T3, T4 and T5, arrive, do a quick read of the list and leave. Now list.exits is 3, your condition is met, and T1 frees the node. T2 hasn't gone anywhere, though, and blows up since it is reading a freed node!
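    The counter arithmetic in this example can be replayed step by step. A minimal sketch that models list.entries and list.exits as plain counters on a single thread (counters_t and replay_counterexample are hypothetical names), following the simplified model in which T1 saves list.entries right after its pop and then compares list.exits against it:

    ```c
    #include <stdbool.h>
    #include <stdint.h>

    /* Plain counters standing in for list.entries and list.exits. */
    typedef struct {
        uint64_t entries;
        uint64_t exits;
    } counters_t;

    /* Replays the interleaving on one thread. Returns true if the writer's
       wait condition is met while reader T2 is still inside the list. */
    static bool replay_counterexample(void)
    {
        counters_t c = {0, 0};
        bool t2_inside = false;

        c.entries++;                 /* T1 (writer) enters */
        c.entries++;                 /* T2 (reader) enters... */
        t2_inside = true;            /* ...and stays for a long time */

        uint64_t saved_entries = c.entries;  /* T1 pops a node, saves 2 */

        for (int t = 3; t <= 5; t++) {       /* T3, T4, T5 enter and leave */
            c.entries++;
            c.exits++;
        }

        /* T1's check: exits (3) > saved entries (2), so it frees the node. */
        bool writer_frees = c.exits > saved_entries;
        return writer_frees && t2_inside;
    }
    ```

    The check has no way to tell that T3, T4 and T5, rather than T2, produced the exits: the counters only record totals.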

    The basic idea you have can work, but your two counter approach in particular definitely doesn't work.

    This is a well-studied problem, so you don't have to invent your own algorithm (see the link above), or even write your own code since things like librcu and concurrencykit already exist.

    For Educational Purposes

    If you wanted to make this work for educational purposes, though, one approach would be to ensure that threads arriving after a modification has started use a different set of list.entry/exit counters. One way to do this would be a generation counter: when the writer wants to modify the list, it increments the generation counter, which causes new readers to switch to a different set of list.entry/exit counters.

    Now the writer just has to wait for list.entry[old] == list.exits[old], which means all the old readers have left. You could also get away with a single counter per generation: you don't really need two entry/exit counters (although they might help reduce contention).

    Of course, you now have a new problem of managing this list of separate counters per generation... which kind of looks like the original problem of building a lock-free list! This problem is a bit easier, though, because you might put some reasonable bound on the number of generations "in flight" and just allocate them all up front, or you might implement a limited type of lock-free list that is easier to reason about because additions and deletions only occur at the head or tail.
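    A minimal sketch of the generation idea, taking both simplifications suggested here: a single counter per generation, and only two generations preallocated, selected by the parity of a generation number. The names gen, readers, gp_lock, reader_enter, reader_exit and wait_for_old_readers are hypothetical; the mutex that serializes writers' waits is a simplifying assumption that keeps at most two generations in flight, and correctness leans on the default seq_cst ordering of these atomic calls. Treat it as a sketch rather than a vetted implementation:

    ```c
    #include <pthread.h>
    #include <sched.h>
    #include <stdatomic.h>

    static _Atomic unsigned gen;       /* generation number; parity picks the slot */
    static _Atomic long readers[2];    /* one reader counter per generation */
    static pthread_mutex_t gp_lock = PTHREAD_MUTEX_INITIALIZER;

    /* Takes the place of the list.entries increment. Returns the slot the
       reader registered in, to be passed to reader_exit(). */
    static unsigned reader_enter(void)
    {
        for (;;) {
            unsigned g = atomic_load(&gen);
            atomic_fetch_add(&readers[g & 1], 1);
            /* If a writer flipped the generation meanwhile, back out and retry
               in the new slot, so newcomers cannot keep the old slot busy. */
            if (atomic_load(&gen) == g)
                return g & 1;
            atomic_fetch_sub(&readers[g & 1], 1);
        }
    }

    /* Takes the place of the list.exits increment. */
    static void reader_exit(unsigned slot)
    {
        atomic_fetch_sub(&readers[slot], 1);
    }

    /* Called after the CAS that detached the marked nodes and before freeing
       them: returns once every reader of the old generation has left. */
    static void wait_for_old_readers(void)
    {
        pthread_mutex_lock(&gp_lock);
        unsigned old = atomic_fetch_add(&gen, 1) & 1;  /* new readers use the other slot */
        while (atomic_load(&readers[old]) != 0)
            sched_yield();
        pthread_mutex_unlock(&gp_lock);
    }
    ```

    In list_remove(), reader_enter()/reader_exit() would replace the list.entries/list.exits increments, and the thread whose CAS succeeded would call wait_for_old_readers() before freeing the detached nodes. That thread must call reader_exit() for its own traversal first, or it will wait on itself.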