🗂

スレッドプールの作り方

2022/05/10に公開

「あっ、スレッドプール作んなきゃ」って思うときありますよね。
そんな時のために、C++でスレッドプールを作るレシピを紹介します。

材料

品目 数量 備考
スレッド 1本(*) std::threadでも良いですが、プラットフォーム固有のAPIを利用できる場合は、その方が細かな制御ができて良いでしょう。
ミューテックス 1個(*) 条件変数と組み合わせて使うので必須です。
条件変数 1個(*) 条件変数なしにワーカースレッドを作るのは、三つ葉がないのに親子丼を作るようなものです。
コンテナ 1個 操作の特性上std::dequeがベストでしょう。
ファンクタ 1種 ここでは気軽にstd::functionを使っていますが、メモリ効率などの観点から、別途自分で定義したインタフェースを持つオブジェクトでも構いません。
アトミック変数 あれば うまく使うとロックが減らせて高速化に役立ちます。

(*)が付いている材料は、コア数分用意できると性能がアップします。

シンプルなProducer-Consumerパターン

スレッドを1本だけ使うパターンから紹介します。特に小細工する必要も無いので実装もシンプルになります。

  • メインスレッドから非同期に実行して欲しい処理をワーカースレッドのキューに積む
  • ワーカースレッドはキューから処理を取り出して実行する

生産者(Producer)・消費者(Consumer)パターンと呼ばれる、非常に古典的なデザインパターンの1つです。

コードにするならこんな感じ。

#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>

class worker {
public:
    worker() : thread_([this]() { proc_worker(); }) {}
    ~worker() {
        wait_until_idle();
        request_termination();
        if (thread_.joinable()) {
            thread_.join();
        }
    };

    template<typename F>
    void run(F&& func) {
        std::unique_lock<std::mutex> lock(mutex_);
        if (is_requested_termination) { return; }
        queue_.emplace_back(func);
        cond_.notify_all();
    }

    void wait_until_idle() {
        std::unique_lock<std::mutex> lock(mutex_);
        cond_.wait(lock, [this]() { return queue_.empty() || is_requested_termination; });
    }

    void request_termination() {
        std::unique_lock<std::mutex> lock(mutex_);
        is_requested_termination = true;
        cond_.notify_all();
    }

private:
    void proc_worker() {
        while (true) {
            std::function<void()> task;
            {
                std::unique_lock lock(mutex_);
                cond_.wait(lock, [this]() { return !queue_.empty() || is_requested_termination; });
                if (is_requested_termination) break;
                task = queue_.front();
                queue_.pop_front();
                cond_.notify_all();
            }
            task();
        }
    }

    bool is_requested_termination{ false };
    std::thread thread_;
    std::deque<std::function<void()>> queue_;
    std::mutex mutex_;
    std::condition_variable cond_;
};

stdだけを使って最低限の機能を実装するとこんな感じでしょうか。最低限と言いつつも、結構色々と考慮しないといけないので、どうしてもコードは膨らみます。

run()にファンクタを渡すと、キューに積んでその場では即座にリターンします。
ワーカースレッド上からキューを1つずつ取り出し、実行することを繰り返します。
終了をリクエストすることで、ワーカースレッドの無限ループを抜けて、デストラクタでjoinできます。

マルチコアを活かすには

昨今の動作環境はマルチコアが当たり前ですので、ワーカースレッドが1つだけでは処理効率が上がりません。コア数分のスレッドを立てて、分散したくなります。ではどう分散するのが効率的でしょうか?

単純に考えると、上記のworkerを複数用意して、順番に処理を割り振ればいいと思うかも知れません。

constexpr size_t ThreadCount = 8;
worker workers[ThreadCount];
std::atomic<size_t> index = 0;

workers[(index++) % ThreadCount].run([](){ Hoge(); }); 

これは、投げる処理の実行時間がどれも概ね等しいのなら、それほど悪くはないアプローチです。しかし、ワーカーにファンクタを自由に投げつけられる以上、時間が均等であることを期待することはできません。

では、各スレッド上のキューに積まれたタスク数を取得できるようにして、少ないスレッドから利用するようにしてはどうでしょうか。

auto target = std::min_element(std::begin(workers), std::end(workers), [](auto& lhs, auto& rhs) { return lhs.get_queue_count() < rhs.get_queue_count(); });
target->run([](){ Hoge(); });

これもイマイチです。実行時間はタスクの個数では決まらないので、非常に重たいタスクを1つ抱えているところに処理が積まれかねません。

このように色々検討してみると、処理を積む側が実行スレッドを決定するアプローチにはどうしても難がありそうです。ワーカー側が処理を持っていくようにした方が筋が良いということになります。

1つのProducerに複数のConsumerが存在するパターン

というわけで、1つのキューから複数のスレッドがタスクをむしり取って実行するパターンを実装してみました。

#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>

template<size_t Count>
class worker_pool {
public:
    worker_pool() {
        int index = 0;
        for (auto& inner : inner_workers_) {
            inner.initialize(this, index++);
        }
    }
    ~worker_pool() {
        wait_until_idle();
        request_termination();
    };

    template<typename F>
    void run(F&& task) {
        std::unique_lock<std::mutex> lock(mutex_);
        if (is_requested_termination) { return; }

        queue_.emplace_back(task);
        for (auto& inner : inner_workers_) {
            inner.wakeup();
        }
    }

    void wait_until_idle() {
        {
            std::unique_lock<std::mutex> lock(mutex_);
            cond_.wait(lock, [this]() { return queue_.empty() || is_requested_termination; });
        }
        for (auto& inner : inner_workers_) {
            inner.wait_until_idle();
        }
    }

    void request_termination() {
        {
            std::unique_lock<std::mutex> lock(mutex_);
            is_requested_termination = true;
            cond_.notify_all();
        }
        for (auto& inner : inner_workers_) {
            inner.request_termination();
        }
    }

    std::function<void()> pull() {
        std::unique_lock<std::mutex> lock(mutex_);
        if (queue_.empty()) return {};

        auto task = queue_.front();
        queue_.pop_front();
        cond_.notify_all();
        return task;
    }

private:
    class inner_worker {
    public:
        inner_worker() : thread_([this]() { proc_worker(); }) {}
        ~inner_worker() {
            wait_until_idle();
            request_termination();
            if (thread_.joinable()) {
                thread_.join();
            }
        };
	
        void initialize(worker_pool* parent, int index) {
            std::unique_lock<std::mutex> lock(mutex_);
            parent_ = parent;
            index_ = index;
            cond_.notify_all();
        }
        
	void wakeup() {
            std::unique_lock<std::mutex> lock(mutex_);
            cond_.notify_all();
        }
        
	void wait_until_idle() {
            std::unique_lock<std::mutex> lock(mutex_);
            cond_.wait(lock, [this]() { return !current_task_ || is_requested_termination; });
        }
        
	void request_termination() {
            std::unique_lock<std::mutex> lock(mutex_);
            is_requested_termination = true;
            cond_.notify_all();
        }
	
    private:
        void proc_worker() {
            {
                std::unique_lock<std::mutex> lock(mutex_);
                cond_.wait(lock, [this]() { return parent_ != nullptr && index_ >= 0; });
            }
            while (true) {
                auto task = parent_->pull();
                if (!task) {
                    std::unique_lock<std::mutex> lock(mutex_);
                    current_task_ = {};
                    cond_.notify_all();
                    cond_.wait(lock, [&]() { return !!(current_task_ = parent_->pull()) || is_requested_termination; });
                    if (is_requested_termination) break;
                } else {
                    std::unique_lock<std::mutex> lock(mutex_);
                    current_task_ = std::move(task);
                }

                current_task_();
            }
        }

        worker_pool* parent_{ nullptr };
        int index_{ -1 };
        std::function<void()> current_task_{};
        bool is_requested_termination{ false };
        std::thread thread_;
        std::mutex mutex_;
        std::condition_variable cond_;
    };

    inner_worker inner_workers_[Count];
    bool is_requested_termination{ false };
    std::deque<std::function<void()>> queue_;
    std::mutex mutex_;
    std::condition_variable cond_;
};

扱ってる内容に対しては、コンパクトにまとまっている方だと思います。

基本的な流れは1スレッド版と同じですが、タスクが積まれたら全ワーカースレッドを叩き起こしに行くのが大きな違いです。その結果、一番最初に目覚めたスレッドか、あるいはちょうど以前のタスク実行が終わったスレッドがpullしに行くことで、次のタスクをゲットします。
タスクがゲットできなかったスレッドは、条件変数による待機に移行します。

再帰的なスレッドプールの利用に対応する

まぁまぁ良い感じになってきましたが、ここでもう1つ考えたいのが「ワーカーに投げた処理から更にワーカーに処理を投げた場合」の効率化です。そのまま投げるとキューの末尾に追加されるので、自身の後に大量にタスクが追加されていると、再投入したタスクの完了まで相当待つことになります。

再帰的に生成されたタスクは別のキューに積んで、そちらから先に処理すると良さそうな気がします。そのキューの持ち方ですが、ズバリ、各ワーカーごとに持たせてしまうのがベストです。

これを実装したものが以下のコードです。

#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>

template<size_t Count>
class worker_pool {
public:
    worker_pool() {
        int index = 0;
        for (auto& inner : inner_workers_) {
            inner.initialize(this, index++);
        }
    }
    ~worker_pool() {
        wait_until_idle();
        request_termination();
    };

    template<typename F>
    void run(F&& task) {
        auto current_thread_index = get_current_thread_index();
        {
            std::unique_lock<std::mutex> lock(mutex_);
	    // インデックス-1はワーカー外スレッドを指す
            if (current_thread_index == -1) {
                global_queue_.emplace_back(std::forward<F>(task));
            } else {
                inner_workers_[current_thread_index].push(std::forward<F>(task));
            }
        }
        wakeup_all(current_thread_index);
    }

    void wait_until_idle() {
        {
            std::unique_lock<std::mutex> lock(mutex_);
            cond_.wait(lock, [this]() { return global_queue_.empty() || is_requested_termination; });
        }
        for (auto& inner : inner_workers_) {
            inner.wait_until_idle();
        }
    }

    void request_termination() {
        {
            std::unique_lock<std::mutex> lock(mutex_);
            is_requested_termination = true;
            cond_.notify_all();
        }
        for (auto& inner : inner_workers_) {
            inner.request_termination();
        }
    }

    std::function<void()> steal_or_pull(int index) {
        for (int i = 0; i < Count; ++i) {
            if (i == index) {
                continue;
            }
            auto task = inner_workers_[i].steal();
            if (!!task) {
                return std::move(task);
            }
        }

        std::unique_lock<std::mutex> lock(mutex_);
        if (global_queue_.empty()) {
            return {};
        }

        auto task = global_queue_.front();
        global_queue_.pop_front();
        cond_.notify_all();
        return std::move(task);
    }

    void wakeup_all(int index) {
        int i = 0;
        for (auto& inner : inner_workers_) {
            if (i != index) {
                inner.wakeup();
            }
            ++i;
        }
    }

private:
    class inner_worker {
    public:
        inner_worker() : thread_([this]() { proc_worker(); }) {}
        ~inner_worker() {
            wait_until_idle();
            request_termination();
            if (thread_.joinable()) {
                thread_.join();
            }
        }

        void initialize(worker_pool* parent, int index) {
            std::unique_lock<std::mutex> lock(mutex_);
            parent_ = parent;
            index_ = index;
            cond_.notify_all();
        }

        template<typename F>
        void push(F&& task) {
            std::unique_lock<std::mutex> lock(mutex_);
            local_queue_.emplace_back(task);
        }

        std::function<void()> steal() {
            std::unique_lock<std::mutex> lock(mutex_);
            if (local_queue_.empty()) {
                return {};
            }

            auto task = local_queue_.front();
            local_queue_.pop_front();
            cond_.notify_all();
            return std::move(task);
        }

        void wakeup() {
            std::unique_lock<std::mutex> lock(mutex_);
            cond_.notify_all();
        }

        void wait_until_idle() {
            std::unique_lock<std::mutex> lock(mutex_);
            cond_.wait(lock, [this]() { return (local_queue_.empty() && !current_task_) || is_requested_termination; });
        }

        void request_termination() {
            std::unique_lock<std::mutex> lock(mutex_);
            is_requested_termination = true;
            cond_.notify_all();
        }

        std::thread::id get_thread_id() const {
            return thread_.get_id();
        }

    private:
        void wait_initialize() {
            std::unique_lock<std::mutex> lock(mutex_);
            cond_.wait(lock, [this]() { return parent_ != nullptr && index_ >= 0; });
        }
        
        void proc_worker() {
            wait_initialize();

            while (true) {
                bool is_assigned = false;
                {
                    // ローカルキュー末尾からの取り出しを優先する
                    std::unique_lock<std::mutex> lock(mutex_);
                    if (!local_queue_.empty()) {
                        current_task_ = local_queue_.back();
                        local_queue_.pop_back();
                        is_assigned = true;
                    }
                }
                if (!is_assigned) {
                    auto task = parent_->steal_or_pull(index_);
                    if (!task) {
                        std::unique_lock<std::mutex> lock(mutex_);
                        current_task_ = {};
                        cond_.notify_all();  // 何も処理していないことを通知する
                        cond_.wait(lock, [&]() {
                            // この述語内ではロックを取得しているはず
                            if (is_requested_termination) {
                                return true;
                            }
                            // steal_or_pullの時にロックを保持しているとデッドロックを起こす
                            lock.unlock();
                            auto task = parent_->steal_or_pull(index_);
                            lock.lock();
                            // current_task_の更新時に再度ロックを取得する
                            current_task_ = std::move(task);
                            return !!current_task_;
                        });
                        if (is_requested_termination) {
                            break;
                        }
                    }
                    else {
                        std::unique_lock<std::mutex> lock(mutex_);
                        current_task_ = std::move(task);
                    }
                }
                current_task_();
            }
        }

        worker_pool* parent_{ nullptr };
        int index_{ -1 };
        std::function<void()> current_task_{};
        bool is_requested_termination{ false };
        std::thread thread_;
        std::deque<std::function<void()>> local_queue_;
        std::mutex mutex_;
        std::condition_variable cond_;
    };

    int get_current_thread_index() {
        auto current_id = std::this_thread::get_id();
        int index = 0;
        for (auto& worker : inner_workers_) {
            if (current_id == worker.get_thread_id()) {
                return index;
            }
            ++index;
        }
        return -1;
    }

    inner_worker inner_workers_[Count];
    bool is_requested_termination{ false };
    std::deque<std::function<void()>> global_queue_;
    std::mutex mutex_;
    std::condition_variable cond_;
};

ほら、簡単……なわけないですね。シンプルに作っていたはずなのに。どうしてこうなった。

ワーカー外から追加されるタスクを積むキューをグローバルキューとして、ワーカー内から追加されるタスクはローカルキューに積むことにします。
ワーカーは、ローカルキューにタスクがある時はそちらから優先して消化しますが、取り出すタスクを末尾からにします。こうすることで、先ほどまで実行していたタスクから生成されたタスクを優先的に処理するようになり、キャッシュ効率なども良くなることが期待できます。

ローカルキューを持つようにしただけでは、1つのワーカー上で大量に再帰的なタスクが生成されると分散効率が悪くなります。そこで、他に暇しているワーカーがいたら、よそのワーカーのローカルキューからタスクをぶんどることにします。いわゆる「ワークスティーリング」と呼ばれるアルゴリズムです。
ワークスティーリングする時は、よそのローカルキューの先頭から持って行くことにします。ワーカーは自身のローカルキューについては末尾から処理するので、先頭の方に埋もれているタスクから処理してあげた方が、親タスク投入からの経過時間や分散効率などの観点でも良い結果が得られます。

優先順位としては

  1. 自身のワーカーから再帰的に投下されたタスク(ローカルキューの末尾)
  2. よそのワーカーに積まれて埋もれているタスク(ローカルキューの先頭)
  3. ワーカー外から投下されたタスク(グローバルキューの先頭)
    の順に処理しようと、各ワーカーは動きます。

参考資料

https://ufcpp.net/study/csharp/misc_task.html#thread_pool4
C#の大御所さまの資料ですが、わかりやすかったのでこれを元に実装してみたのがこの記事を書くきっかけでした。岩永さん、ありがとうございます。

つかれた

現状のコードはかなり保守的にロックを取っているので、アトミック変数を使った高速化ネタなどもあるんですが、とても長くなってしまったのでいったんここまで。やっぱりマルチスレッドは大変ですね。

Discussion