C++实现简单线程池
关键要实现以下接口:
- 线程池构造函数,初始化线程容器
add
,向任务队列添加一个task
知识点:
- lambada
- future
- condition_variable
- 可变模版参数
- memory order
#ifndef GG_THREAD_POOL_H_ #define GG_THREAD_POOL_H_ #include <condition_variable> #include <mutex> #include <atomic> #include <thread> #include <functional> #include <future> #include <queue> #include <vector> namespace gg { class ThreadPool { public: using Task = std::function<void()>; explicit ThreadPool(size_t num_threads = 1) : stop_(false), total_threads_num_(num_threads) { if (total_threads_num_ == 0) { total_threads_num_ = std::thread::hardware_concurrency(); } for (size_t i = 0; i < total_threads_num_; ++i) { thread_pool_.push_back(std::thread([this]() { while (!stop_.load(std::memory_order_acquire)) { Task task; { std::unique_lock<std::mutex> ulk(task_queue_mutex_); task_cv_.wait(ulk, [this]() { return stop_.load(std::memory_order_acquire) || !task_queue_.empty(); }); if (stop_.load(std::memory_order_acquire)) return ; task = std::move(task_queue_.front()); task_queue_.pop(); } task(); } })); } } ~ThreadPool() { stop(); task_cv_.notify_all(); for (auto &t : thread_pool_) { if (t.joinable()) t.join(); } } size_t total_num_threads() const { return total_threads_num_; } template <class Function, class... Args> std::future<typename std::result_of<Function(Args...)>::type> add( Function&& f, Args&&... args) { if (is_stopped()) { throw std::runtime_error("std::thread pool is stopped"); } using return_type = typename std::result_of<Function(Args...)>::type; auto task = std::make_shared<std::packaged_task<return_type()>>( std::bind(std::forward<Function>(f), std::forward<Args>(args)...)); auto ret = task->get_future(); { std::lock_guard<std::mutex> guard(task_queue_mutex_); task_queue_.emplace([task]() { (*task)(); }); } task_cv_.notify_one(); return ret; } void stop() { stop_.store(true, std::memory_order_release); } bool is_stopped() { return stop_.load(std::memory_order_acquire); } private: std::queue<Task> task_queue_; std::vector<std::thread> thread_pool_; std::mutex task_queue_mutex_; std::condition_variable task_cv_; std::atomic<bool> stop_; size_t total_threads_num_; }; } // namespace gg
测试代码,每一个线程计算$2^n$:
#include "thread_pool.h" #include <iostream> int64_t power(int64_t x, int n) { int64_t ans = 1; while (n != 0) { if ((n&1) == 1) { ans *= x; } x *= x; n >>= 1; } return ans; } int main () { gg::ThreadPool thread_pool(8); using task_ret_type = std::future<int64_t>; const int task_num = 50; std::vector<task_ret_type> vec_task_future; for (int i = 0; i < task_num; ++i) { vec_task_future.emplace_back(thread_pool.add(power, 2, i)); } for (auto&& f : vec_task_future) { std::cout << f.get() << std::endl; } return 0; }
goudan-er SHARE · CPP
C++ programming