facebook/wangle

/* * Copyright (c) 2016, Facebook, Inc. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. An additional grant * of patent rights can be found in the PATENTS file in the same directory. * */ #pragma once #include <folly/Executor.h> #include <wangle/concurrent/LifoSemMPMCQueue.h> #include <wangle/concurrent/NamedThreadFactory.h> #include <wangle/deprecated/rx/Observable.h> #include <folly/Baton.h> #include <folly/Memory.h> #include <folly/RWSpinLock.h> #include <algorithm> #include <mutex> #include <queue> #include <glog/logging.h> namespace wangle { class ThreadPoolExecutor : public virtual folly::Executor { public: explicit ThreadPoolExecutor( size_t numThreads, std::shared_ptr<ThreadFactory> threadFactory); ~ThreadPoolExecutor(); virtual void add(folly::Func func) override = 0; virtual void add( folly::Func func, std::chrono::milliseconds expiration, folly::Func expireCallback) = 0; void setThreadFactory(std::shared_ptr<ThreadFactory> threadFactory) { CHECK(numThreads() == 0); threadFactory_ = std::move(threadFactory); } std::shared_ptr<ThreadFactory> getThreadFactory(void) { return threadFactory_; } size_t numThreads(); void setNumThreads(size_t numThreads); /* * stop() is best effort - there is no guarantee that unexecuted tasks won't * be executed before it returns. Specifically, IOThreadPoolExecutor's stop() * behaves like join(). */ void stop(); void join(); struct PoolStats { PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0), pendingTaskCount(0), totalTaskCount(0) {} size_t threadCount, idleThreadCount, activeThreadCount; uint64_t pendingTaskCount, totalTaskCount; }; PoolStats getPoolStats(); struct TaskStats { TaskStats() : expired(false), waitTime(0), runTime(0) {} bool expired; std::chrono::nanoseconds waitTime; std::chrono::nanoseconds runTime; }; Subscription<TaskStats> subscribeToTaskStats( const ObserverPtr<TaskStats>& observer) { return taskStatsSubject_->subscribe(observer); } /** * Base class for threads created with ThreadPoolExecutor. * Some subclasses have methods that operate on these * handles. */ class ThreadHandle { public: virtual ~ThreadHandle() = default; }; /** * Observer interface for thread start/stop. * Provides hooks so actions can be taken when * threads are created */ class Observer { public: virtual void threadStarted(ThreadHandle*) = 0; virtual void threadStopped(ThreadHandle*) = 0; virtual void threadPreviouslyStarted(ThreadHandle* h) { threadStarted(h); } virtual void threadNotYetStopped(ThreadHandle* h) { threadStopped(h); } virtual ~Observer() = default; }; void addObserver(std::shared_ptr<Observer>); void removeObserver(std::shared_ptr<Observer>); protected: // Prerequisite: threadListLock_ writelocked void addThreads(size_t n); // Prerequisite: threadListLock_ writelocked void removeThreads(size_t n, bool isJoin); struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread : public ThreadHandle { explicit Thread(ThreadPoolExecutor* pool) : id(nextId++), handle(), idle(true), taskStatsSubject(pool->taskStatsSubject_) {} virtual ~Thread() = default; static std::atomic<uint64_t> nextId; uint64_t id; std::thread handle; bool idle; folly::Baton<> startupBaton; std::shared_ptr<Subject<TaskStats>> taskStatsSubject; }; typedef std::shared_ptr<Thread> ThreadPtr; struct Task { explicit Task( folly::Func&& func, std::chrono::milliseconds expiration, folly::Func&& expireCallback); folly::Func func_; TaskStats stats_; std::chrono::steady_clock::time_point enqueueTime_; std::chrono::milliseconds expiration_; folly::Func expireCallback_; }; static void runTask(const ThreadPtr& thread, Task&& task); // The function that will be bound to pool threads. It must call // thread->startupBaton.post() when it's ready to consume work. virtual void threadRun(ThreadPtr thread) = 0; // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue // Prerequisite: threadListLock_ writelocked virtual void stopThreads(size_t n) = 0; // Create a suitable Thread struct virtual ThreadPtr makeThread() { return std::make_shared<Thread>(this); } // Prerequisite: threadListLock_ readlocked virtual uint64_t getPendingTaskCount() = 0; class ThreadList { public: void add(const ThreadPtr& state) { auto it = std::lower_bound(vec_.begin(), vec_.end(), state, // compare method is a static method of class // and therefore cannot be inlined by compiler // as a template predicate of the STL algorithm // but wrapped up with the lambda function (lambda will be inlined) // compiler can inline compare method as well [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline return compare(ts1, ts2); }); vec_.insert(it, state); } void remove(const ThreadPtr& state) { auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, // the same as above [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline return compare(ts1, ts2); }); CHECK(itPair.first != vec_.end()); CHECK(std::next(itPair.first) == itPair.second); vec_.erase(itPair.first); } const std::vector<ThreadPtr>& get() const { return vec_; } private: static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) { return ts1->id < ts2->id; } std::vector<ThreadPtr> vec_; }; class StoppedThreadQueue : public BlockingQueue<ThreadPtr> { public: void add(ThreadPtr item) override; ThreadPtr take() override; size_t size() override; private: folly::LifoSem sem_; std::mutex mutex_; std::queue<ThreadPtr> queue_; }; std::shared_ptr<ThreadFactory> threadFactory_; ThreadList threadList_; folly::RWSpinLock threadListLock_; StoppedThreadQueue stoppedThreads_; std::atomic<bool> isJoin_; // whether the current downsizing is a join std::shared_ptr<Subject<TaskStats>> taskStatsSubject_; std::vector<std::shared_ptr<Observer>> observers_; }; } // namespace wangle