0%

tokio源码实现简单分析

前言


Tokio 是一个 Rust 的异步运行时,它提供了一个完整的异步 I/O 框架。其实现在AI Code分析工具,比如Cursor、windsurf等基本都能分析出tokio核心实现,本文并不做八股文总结,仅尝试从任务调度实现角度分析,给出我们可以借鉴的设计思想,其中包括:

  • 工作窃取实现
  • 任务调度优化
  • 阻塞任务实现

先来看tokio简单示例:

  • 创建运行时
1
2
3
4
let rt = Runtime::builder()
.worker_threads(4)
.build()
.unwrap();
  • 提交任务
1
2
3
rt.spawn(async {
println!("Running on worker thread");
});
  • 提交阻塞任务
1
2
3
rt.spawn_blocking(|| {
// 执行阻塞操作
});

整体架构

作为异步 I/O 框架,其核心架构包含以下几个主要组件:

核心组件

1
2
3
4
5
6
7
8
9
/// src/runtime/runtime.rs
pub struct Runtime {
/// 调度器
scheduler: Scheduler,
/// 运行时句柄
handle: Handle,
/// 阻塞线程池
blocking_pool: BlockingPool,
}

调度器类型

调度器

1
2
3
4
5
6
7
/// src/runtime/runtime.rs
pub(super) enum Scheduler {
/// 单线程调度器
CurrentThread(CurrentThread),
/// 多线程调度器
MultiThread(MultiThread),
}

调度器句柄, 支持单线程与多线程模式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/// src/runtime/scheduler/current_thread/mod.rs
/// Handle to the current thread scheduler
pub(crate) struct Handle {
/// Scheduler state shared across threads
shared: Shared,

/// Resource driver handles
pub(crate) driver: driver::Handle,

/// Blocking pool spawner
pub(crate) blocking_spawner: blocking::Spawner,

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) task_hooks: TaskHooks,

/// If this is a `LocalRuntime`, flags the owning thread ID.
pub(crate) local_tid: Option<ThreadId>,
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/// src/runtime/scheduler/multi_thread/handle.rs
/// Handle to the multi thread scheduler
pub(crate) struct Handle {
/// Task spawner
pub(super) shared: worker::Shared,

/// Resource driver handles
pub(crate) driver: driver::Handle,

/// Blocking pool spawner
pub(crate) blocking_spawner: blocking::Spawner,

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) task_hooks: TaskHooks,
}

Worker

每一个Worker对应一个线程,其实也可以称Worker线程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/// src/runtime/scheduler/multi_thread/worker.rs
pub(super) struct Worker {
/// 调度器句柄
handle: Arc<Handle>,
/// Worker 索引
index: usize,
/// 核心数据结构
core: AtomicCell<Core>,
}

/// src/runtime/scheduler/multi_thread/worker.rs
pub(super) struct Core {
/// 本地任务队列
run_queue: queue::Local,
/// LIFO 槽位
lifo_slot: Option<Notified>,
/// 是否正在搜索任务
is_searching: bool,
/// 是否已关闭
is_shutdown: bool,
}

任务队列

这里我们更多指全局Inject队列,worker自身Local队列在core对象中管理。

1
2
3
4
5
6
7
8
9
10
11
/// src/runtime/scheduler/multi_thread/worker.rs
pub(crate) struct Shared {
/// 全局任务队列
pub(super) inject: inject::Shared,
/// 远程 Worker 列表
pub(super) remotes: Box<[Remote]>,
/// 空闲 Worker 管理
pub(super) idle: Idle,
/// 调度器配置
pub(super) config: Config,
}

Worker窃取算法

在分析窃取前,我们先简单总结下tokio全局Inject队列与Worker本地队列的关系:

  • Inject 队列:一个全局的队列,用于接收新任务和作为本地队列溢出时的缓冲区
  • Worker Local 队列:每个 worker 线程都有一个固定大小的本地队列,用于存储待执行的任务
  • (schedule_task)新任务提交时,如果当前线程是 worker 线程,则将任务放入当前 worker 的本地队列,否则放入全局 inject 队列
  • (push_overflow)当本地队列满时,将一半的任务移到 inject 队列

tokio的工作窃取其实策略很简单,优先随机从其他worker(线程)的队列中,取一半还未运行的任务,窃取至本worker运行,如果其他worker也没有,那从全局inject队列取任务,不过窃取也有限制:

  • LOCAL_QUEUE_CAPACITY 是队列的总容量(默认是 256)
  • 如果本队列已使用的空间超过总容量的一半(128),就不进行窃取,这应该是确保本队列有足够的空间来接收窃取的任务

任务调度流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
/// src/runtime/scheduler/multi_thread/worker.rs#Context.run
fn run(&self, mut core: Box<Core>) -> RunResult {

while !core.is_shutdown {

// 首先检查当前工作线程是否有可用任务
if let Some(task) = core.next_task(&self.worker) {
core = self.run_task(task, core)?;
continue;
}

// 如果没有本地任务,尝试从其他工作线程窃取任务
if let Some(task) = core.steal_work(&self.worker) {
// Found work, switch back to processing
core.stats.start_processing_scheduled_tasks();
core = self.run_task(task, core)?;
} else {
// 如果没有任务可执行,进入等待状态
core = if !self.defer.is_empty() {
self.park_timeout(core, Some(Duration::from_millis(0)))
} else {
self.park(core)
};
core.stats.start_processing_scheduled_tasks();
}
}
}

工作窃取实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/// src/runtime/scheduler/multi_thread/worker.rs#Core#steal_work
fn steal_work(&mut self, worker: &Worker) -> Option<Notified> {
// 1. 首先尝试将工作线程状态转换为"搜索"状态
if !self.transition_to_searching(worker) {
return None;
}

// 2. 获取远程工作线程的数量
let num = worker.handle.shared.remotes.len();
// 3. 随机选择一个起始工作线程
let start = self.rand.fastrand_n(num as u32) as usize;

// 4. 遍历所有工作线程尝试窃取任务
for i in 0..num {
let i = (start + i) % num;

// 5. 跳过自己,因为知道自己没有任务
if i == worker.index {
continue;
}

// 6. 尝试从目标工作线程窃取任务
let target = &worker.handle.shared.remotes[i];
if let Some(task) = target
.steal
.steal_into(&mut self.run_queue, &mut self.stats)
{
return Some(task);
}
}

// 7. 如果从其他工作线程没有窃取到任务,检查全局队列
worker.handle.next_remote_task()
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
 /// src/runtime/scheduler/multi_thread/queue.rs#Steal#steal_into
pub(crate) fn steal_into(
&self,
dst: &mut Local<T>,
dst_stats: &mut Stats,
) -> Option<task::Notified<T>> {
// 1. 获取目标队列的尾部位置
let dst_tail = unsafe { dst.inner.tail.unsync_load() };

// 2. 检查目标队列是否有足够空间
let (steal, _) = unpack(dst.inner.head.load(Acquire));
if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
return None;
}

// 3. 尝试窃取任务到目标队列
let mut n = self.steal_into2(dst, dst_tail);

if n == 0 {
return None;
}

// 4. 更新统计信息
dst_stats.incr_steal_count(n as u16);
dst_stats.incr_steal_operations();

// 5. 准备返回一个任务
n -= 1;
let ret_pos = dst_tail.wrapping_add(n);
let ret_idx = ret_pos as usize & MASK;

// 6. 获取要返回的任务
let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

if n == 0 {
return Some(ret);
}

// 7. 更新目标队列的尾部位置,使窃取的任务对消费者可见
dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

Some(ret)
}

fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
// 计算可窃取的任务数量
let n = src_tail.wrapping_sub(src_head_real);
// 计算可窃取的任务数量
let n = n - n / 2;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
 /// src/runtime/scheduler/multi_thread/queue.rs#Handle#next_remote_task
fn next_remote_task(&self) -> Option<Notified> {
// 1. 首先检查注入队列是否为空
if self.shared.inject.is_empty() {
return None;
}

// 2. 获取同步锁
let mut synced = self.shared.synced.lock();

// 3. 从注入队列中弹出一个任务
unsafe { self.shared.inject.pop(&mut synced.inject) }
}

任务调度优化

任务调度比较简单就是一LIFO算法调用。

1
2
3
4
impl Core {
// 使用 LIFO 槽位优化最近提交的任务
lifo_slot: Option<Notified>,
}

阻塞任务处理

个人理解,Tokio 的异步运行时主要设计用于处理 I/O 密集型任务,这里的阻塞任务更多是支持如文件 I/O、CPU 密集型计算等:

阻塞操作会占用异步运行时的工作线程,如果阻塞操作直接在异步运行时执行,会降低整体吞吐量,需要将阻塞操作与异步操作分离,以保持异步运行时的效率。

阻塞线程池

1
2
3
4
5
6
7
8
9
10
/// src/runtime/blocking/pool.rs
pub(crate) struct BlockingPool {
spawner: Spawner,
shutdown_rx: shutdown::Receiver,
}

#[derive(Clone)]
pub(crate) struct Spawner {
inner: Arc<Inner>,
}

任务提交

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
/// src/runtime/blocking/pool.rs
impl Spawner {
pub(crate) fn spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
// 1. 检查函数大小
let fn_size = std::mem::size_of::<F>();

// 2. 根据函数大小决定是用Box装箱
let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
self.spawn_blocking_inner(
Box::new(func), // 大函数装箱
Mandatory::NonMandatory,
SpawnMeta::new_unnamed(fn_size),
rt,
)
} else {
self.spawn_blocking_inner(
func, // 小函数直接传递
Mandatory::NonMandatory,
SpawnMeta::new_unnamed(fn_size),
rt,
)
};

// 3. 处理 spawn 结果
match spawn_result {
Ok(()) => join_handle,
// 兼容性处理:即使运行时正在关闭,也返回 join_handle
Err(SpawnError::ShuttingDown) => join_handle,
Err(SpawnError::NoThreads(e)) => {
panic!("OS can't spawn worker thread: {e}")
}
}
}
}