C++ 并发编程

274 天前
 YuanJiwei

有朋友熟悉 C++ 编程吗?能帮我看看这里有什么 bug 吗(更新两个 atomic)

目的是实现一个 Multiset(Non-blocking)

template <typename T>
class CMSet {
    bool contains(T element);
    int count(T element);
    void add(T element);
    bool remove(T element);
};

template <typename T> struct ANode {
    T data;
    std::atomic<int> count{};
    std::atomic<ANode<T>*> next;
    size_t key{};

    ANode() : next(nullptr) {};
    explicit ANode(T data) : data(data), next(nullptr) {};
    explicit ANode(T data, size_t key) : data(data), key(key), next(nullptr) {
        this->count.store(1);
    };
};


template<typename T>
class CMSetWithNonBlocking : public CMSet<T> {
private:
    atomic<ANode<T>*> head;
public:
    CMSetWithNonBlocking () {
        auto* h = new ANode<T>;
        h->key = std::numeric_limits<size_t>::min();
        auto* t = new ANode<T>;
        t->key = std::numeric_limits<size_t>::max();
        head.store(h);
        head.load()->next.store(t);
    }

    bool contains(T data) {
        size_t key = std::hash<T>{}(data);
        ANode<T>* curr = head.load()->next.load();
        while (curr->key < key) {
            curr = curr->next.load();
        }
        return curr->key == key && curr->count.load() > 0;
    }

    int count(T data) {
        size_t key = std::hash<T>{}(data);
        ANode<T>* curr = head.load()->next.load();
        while (curr->key < key) {
            curr = curr->next.load();
        }
        if (curr->key == key) {
            return curr->count.load();
        }
        return 0;
    }

    void add(T data) {
        size_t key = std::hash<T>{}(data);
        auto* newNode = new ANode<T>(data, key);
        while (true) {
            ANode<T>* prev = head.load();
            ANode<T>* curr = prev->next.load();

            while (curr->key < key) {
                prev = curr;
                curr = curr->next.load();
            }

            if (curr->key == key) {
                // If node with key exists, increment count atomically
                int oldCount = curr->count.load();
                if (curr->count.compare_exchange_weak(oldCount, oldCount + 1)) {
                    delete newNode; // newNode not needed, delete it
                    return;
                }
            } else {
                // Insert new node
                newNode->next.store(curr);
                if (prev->next.compare_exchange_weak(curr, newNode)) {
                    return; // Successfully added
                }
            }
        }
    }

    bool remove(T data) {
        size_t key = std::hash<T>{}(data);
        while (true) {
            ANode<T>* prev = head.load();
            ANode<T>* curr = prev->next.load();

            while (curr->key < key) {
                prev = curr;
                curr = curr->next.load();
            }

            if (curr->key == key) {
                int oldCount = curr->count.load();
                if (oldCount == 0) {
                    return false; // Node already removed or never added
                }
                if (curr->count.compare_exchange_weak(oldCount, oldCount - 1)) {
                    if (oldCount - 1 == 0) {
                        // If count decremented to 0, remove node
                        ANode<T>* next = curr->next.load();
                        if (!prev->next.compare_exchange_weak(curr, next)) {
                            continue; // CAS failed, retry
                        } else {
                            delete curr; // Node removed, delete it
                        }
                    }
                    return true; // Count decremented successfully
                }
                // If CAS fails, loop will retry
            } else {
                return false; // Node with key not found
            }
        }
    }
};
1029 次点击
所在节点    问与答
3 条回复
leonshaw
274 天前
key 是 data 的 hash ,那一个 key 只对应一个 data ?
delete node 的时候,另一个线程可能正在遍历?
remove() 里更新指针 CAS failed ,但是 curr->count 已经减了。

感觉这种场景要么锁,要么用类似 RCU 的机制,单独 atomic 做不了。
GeruzoniAnsasu
274 天前
这…… 要吐槽的地方恐怕有点多

首先你的 CMSet 看起来是作为接口用的,但它并不是个抽象类
然后在 derived 类中,由于你要使用多态特性且试图自己管理裸指针,那么你必然要实现虚析构并满足 rule of three

然后你在计算 key 时试图直接使用 std::hash<T> ,但没有对 T 做任何 constraints/traits ,这会导致实际使用这个类时几乎必然失败(你只用这个 set 来存整数?)

构造和 add 函数缺乏模板转发 ( T&& t; std::forward<T>(t) ) 不过考虑到上一条,也许这个类确实不需要转发

再然后你试图用 atomic 来「以某种想象的方式保证一致性」—— 也是错的,在你的查找函数里有一个链表遍历,在并发条件下无锁遍历===不同步(你有没有考虑……正在附加上去的节点被脱链了?)另外你也忘了读写 count 和 delete 节是两个不同时点,很可能发生 A 读 count==1; B 增加 count 到 2; A 将节点删除 这样的时序。( count 不可交换,只能 atomic 增减)

析构函数完全没考虑——你是想让使用者保证这个对象一定不会在还有元素的时候销毁?



我的建议是
- 忘记 c++有 new 这个关键字,只有 make_shared 和 make_unique
- 只要函数参数是个模板,就记得想想能不能套 std::forward
- 重修 E 系列的书
- 数据结构中包含链表时,老老实实加锁
zouzou0208
273 天前
@GeruzoniAnsasu 学习了,谢谢,感谢这么认真的回答。

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://www.v2ex.com/t/1016352

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX