use {
super::exception::*,
super::process::Process,
super::*,
crate::object::*,
alloc::{boxed::Box, sync::Arc},
bitflags::bitflags,
core::{
any::Any,
future::Future,
ops::Deref,
pin::Pin,
task::{Context, Poll, Waker},
time::Duration,
},
futures::{channel::oneshot::*, future::FutureExt, pin_mut, select_biased},
kernel_hal::{sleep_until, GeneralRegs, UserContext},
spin::Mutex,
};
pub use self::thread_state::*;
mod thread_state;
/// A thread kernel object belonging to a [`Process`].
///
/// Mutable per-thread data lives in `inner`, behind a `spin::Mutex`.
pub struct Thread {
    /// Common kernel-object data (id, name, signals).
    base: KObjectBase,
    /// Instance counter generated by `define_count_helper!` — presumably tracks live
    /// `Thread` objects; confirm against the macro definition.
    _counter: CountHelper,
    /// The process that owns this thread.
    proc: Arc<Process>,
    /// Arbitrary extension data supplied to `create_with_ext` (`()` for `create`).
    ext: Box<dyn Any + Send + Sync>,
    /// Mutable thread state.
    inner: Mutex<ThreadInner>,
    /// The thread-level exception channel (`ExceptionChannelType::Thread`).
    exceptionate: Arc<Exceptionate>,
}
// Implements the kernel-object plumbing for `Thread`; the related object of a
// thread is the process that owns it.
impl_kobject!(Thread
    fn related_koid(&self) -> KoID {
        self.proc.id()
    }
);
// Provides the `CountHelper` type stored in `Thread::_counter`.
define_count_helper!(Thread);
/// Mutable state of a [`Thread`], protected by `Thread::inner`.
#[derive(Default)]
struct ThreadInner {
    /// Saved user-mode context.
    ///
    /// It is `None` while user code runs: `wait_for_run` takes it and
    /// `end_running` puts it back.
    context: Option<Box<UserContext>>,
    /// Number of outstanding `suspend` calls not yet matched by `resume`.
    suspend_count: usize,
    /// Waker of a pending `wait_for_run` future.
    ///
    /// It is woken by `resume` and `stop`.
    waker: Option<Waker>,
    /// Oneshot sender that aborts an in-progress `blocking_run` / `dying_run`.
    ///
    /// It is fired by `stop`.
    killer: Option<Sender<()>>,
    /// Raw lifecycle state. Use `state()` to get the effective state, which
    /// accounts for suspension.
    state: ThreadState,
    /// The exception currently being handled by this thread, if any.
    exception: Option<Arc<Exception>>,
    /// Whether this is the first thread of its process.
    first_thread: bool,
    /// Set when the thread was stopped via `kill` (as opposed to `exit`).
    killed: bool,
    /// Accumulated time added by `time_add`.
    ///
    /// Units are defined by callers (NOTE(review): presumably nanoseconds — confirm).
    time: u128,
    /// Per-thread flag bits.
    flags: ThreadFlag,
}
impl ThreadInner {
    /// Returns the effective state of the thread.
    ///
    /// A thread with a non-zero suspend count is reported as `Suspended`, but only
    /// when both of the following hold:
    /// - its context is present (it is not currently running user code);
    /// - it is not blocked in an exception, dying or dead.
    ///
    /// In every other case the raw state is returned.
    fn state(&self) -> ThreadState {
        match self.state {
            // Suspension never masks these states.
            ThreadState::BlockedException | ThreadState::Dying | ThreadState::Dead => self.state,
            _ if self.suspend_count != 0 && self.context.is_some() => ThreadState::Suspended,
            raw => raw,
        }
    }

    /// Sets the raw state to `state` and updates the thread signals on `base` to
    /// match the resulting effective state.
    fn change_state(&mut self, state: ThreadState, base: &KObjectBase) {
        self.state = state;
        let effective = self.state();
        match effective {
            ThreadState::New | ThreadState::Dying => base.signal_clear(
                Signal::THREAD_RUNNING | Signal::THREAD_SUSPENDED | Signal::THREAD_TERMINATED,
            ),
            ThreadState::Suspended => base.signal_change(
                Signal::THREAD_RUNNING | Signal::THREAD_TERMINATED,
                Signal::THREAD_SUSPENDED,
            ),
            ThreadState::Dead => base.signal_change(
                Signal::THREAD_RUNNING | Signal::THREAD_SUSPENDED,
                Signal::THREAD_TERMINATED,
            ),
            // Running and all blocked variants are reported as running.
            _ => base.signal_change(
                Signal::THREAD_TERMINATED | Signal::THREAD_SUSPENDED,
                Signal::THREAD_RUNNING,
            ),
        }
    }
}
bitflags! {
    /// Per-thread flag bits.
    ///
    /// They are read with `Thread::flags` and modified with `Thread::update_flags`.
    #[derive(Default)]
    pub struct ThreadFlag: usize {
        /// Marks the thread as a virtual-CPU thread.
        ///
        /// NOTE(review): presumably set by hypervisor/VCPU support elsewhere in the
        /// crate; confirm where it is read.
        const VCPU = 1 << 3;
    }
}
/// Builds the executor future that drives a started thread.
///
/// It receives the [`CurrentThread`] handle. `Thread::start` and
/// `Thread::start_with_regs` spawn the returned future with `kernel_hal::Thread::spawn`.
pub type ThreadFn = fn(thread: CurrentThread) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
impl Thread {
    /// Creates a thread named `name` in `proc`, without extension data.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Process::add_thread`.
    pub fn create(proc: &Arc<Process>, name: &str) -> ZxResult<Arc<Self>> {
        Self::create_with_ext(proc, name, ())
    }

    /// Creates a thread named `name` in `proc`.
    ///
    /// `ext` is stored as extension data, retrievable later with [`Thread::ext`].
    /// The new thread is in [`ThreadState::New`] and owns a default `UserContext`.
    /// It is registered with `proc` before being returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Process::add_thread`.
    pub fn create_with_ext(
        proc: &Arc<Process>,
        name: &str,
        ext: impl Any + Send + Sync,
    ) -> ZxResult<Arc<Self>> {
        let thread = Arc::new(Thread {
            base: KObjectBase::with_name(name),
            _counter: CountHelper::new(),
            proc: proc.clone(),
            ext: Box::new(ext),
            exceptionate: Exceptionate::new(ExceptionChannelType::Thread),
            inner: Mutex::new(ThreadInner {
                context: Some(Box::new(UserContext::default())),
                ..Default::default()
            }),
        });
        proc.add_thread(thread.clone())?;
        Ok(thread)
    }

    /// Returns the process this thread belongs to.
    pub fn proc(&self) -> &Arc<Process> {
        &self.proc
    }

    /// Returns the extension data supplied at creation time.
    // The `&Box<..>` return type is kept as-is for existing callers.
    pub fn ext(&self) -> &Box<dyn Any + Send + Sync> {
        &self.ext
    }

    /// Starts a never-started thread and spawns `thread_fn` to drive it.
    ///
    /// Execution begins at `entry` with stack pointer `stack`. `arg1` and `arg2`
    /// go in the first two argument registers: `rdi`/`rsi` on x86_64 and
    /// `x0`/`x1` on aarch64.
    ///
    /// # Errors
    ///
    /// Returns `ZxError::BAD_STATE` in either of these cases:
    /// - the thread is not in [`ThreadState::New`] (already started, or killed);
    /// - its user context is currently taken.
    pub fn start(
        self: &Arc<Self>,
        entry: usize,
        stack: usize,
        arg1: usize,
        arg2: usize,
        thread_fn: ThreadFn,
    ) -> ZxResult {
        {
            let mut inner = self.inner.lock();
            // Starting twice would spawn a second executor for the same thread, and
            // starting a killed thread would overwrite `Dying` with `Running`.
            // The raw state is checked, so a thread suspended before start is still
            // allowed.
            if inner.state != ThreadState::New {
                return Err(ZxError::BAD_STATE);
            }
            let context = inner.context.as_mut().ok_or(ZxError::BAD_STATE)?;
            #[cfg(target_arch = "x86_64")]
            {
                context.general.rip = entry;
                context.general.rsp = stack;
                context.general.rdi = arg1;
                context.general.rsi = arg2;
                // NOTE(review): 0x3202 presumably sets IF, IOPL=3 and the reserved
                // bit 1 of RFLAGS — confirm.
                context.general.rflags |= 0x3202;
            }
            #[cfg(target_arch = "aarch64")]
            {
                context.elr = entry;
                context.sp = stack;
                context.general.x0 = arg1;
                context.general.x1 = arg2;
            }
            inner.change_state(ThreadState::Running, &self.base);
        }
        let vmtoken = self.proc().vmar().table_phys();
        kernel_hal::Thread::spawn(thread_fn(CurrentThread(self.clone())), vmtoken);
        Ok(())
    }

    /// Starts a never-started thread with the full general register set `regs`,
    /// then spawns `thread_fn` to drive it.
    ///
    /// # Errors
    ///
    /// Returns `ZxError::BAD_STATE` in either of these cases:
    /// - the thread is not in [`ThreadState::New`];
    /// - its user context is currently taken.
    pub fn start_with_regs(self: &Arc<Self>, regs: GeneralRegs, thread_fn: ThreadFn) -> ZxResult {
        {
            let mut inner = self.inner.lock();
            // Same double-start / revive-after-kill guard as `start`.
            if inner.state != ThreadState::New {
                return Err(ZxError::BAD_STATE);
            }
            let context = inner.context.as_mut().ok_or(ZxError::BAD_STATE)?;
            context.general = regs;
            #[cfg(target_arch = "x86_64")]
            {
                context.general.rflags |= 0x3202;
            }
            inner.change_state(ThreadState::Running, &self.base);
        }
        let vmtoken = self.proc().vmar().table_phys();
        kernel_hal::Thread::spawn(thread_fn(CurrentThread(self.clone())), vmtoken);
        Ok(())
    }

    /// Moves the thread towards termination.
    ///
    /// `killed` distinguishes `kill` from a voluntary `exit`.
    ///
    /// - If the thread is dead, this does nothing.
    /// - If it is already dying, a kill only fires the pending `killer`, which
    ///   aborts `dying_run`.
    /// - Otherwise the thread becomes `Dying`, and both `wait_for_run` and any
    ///   `blocking_run` are woken.
    fn stop(&self, killed: bool) {
        let mut inner = self.inner.lock();
        if inner.state == ThreadState::Dead {
            return;
        }
        if killed {
            inner.killed = true;
        }
        if inner.state == ThreadState::Dying {
            if killed {
                if let Some(killer) = inner.killer.take() {
                    killer.send(()).ok();
                }
            }
            return;
        }
        inner.change_state(ThreadState::Dying, &self.base);
        if let Some(waker) = inner.waker.take() {
            waker.wake();
        }
        if let Some(killer) = inner.killer.take() {
            killer.send(()).ok();
        }
    }

    /// Reads register state of `kind` into `buf` and returns the number of bytes
    /// written.
    ///
    /// This is only allowed while the thread is suspended or blocked in an
    /// exception.
    ///
    /// # Errors
    ///
    /// - `NOT_SUPPORTED` when the thread is in neither of those states but an
    ///   exception is pending.
    /// - `BAD_STATE` for any other state, or when the context is taken.
    /// - Any error from the context's `read_state`, such as a buffer that is too
    ///   small.
    pub fn read_state(&self, kind: ThreadStateKind, buf: &mut [u8]) -> ZxResult<usize> {
        let inner = self.inner.lock();
        let state = inner.state();
        if state != ThreadState::BlockedException && state != ThreadState::Suspended {
            if inner.exception.is_some() {
                return Err(ZxError::NOT_SUPPORTED);
            }
            return Err(ZxError::BAD_STATE);
        }
        let context = inner.context.as_ref().ok_or(ZxError::BAD_STATE)?;
        context.read_state(kind, buf)
    }

    /// Writes register state of `kind` from `buf`.
    ///
    /// State and error rules are the same as for [`Thread::read_state`].
    pub fn write_state(&self, kind: ThreadStateKind, buf: &[u8]) -> ZxResult {
        let mut inner = self.inner.lock();
        let state = inner.state();
        if state != ThreadState::BlockedException && state != ThreadState::Suspended {
            if inner.exception.is_some() {
                return Err(ZxError::NOT_SUPPORTED);
            }
            return Err(ZxError::BAD_STATE);
        }
        let context = inner.context.as_mut().ok_or(ZxError::BAD_STATE)?;
        context.write_state(kind, buf)
    }

    /// Returns a snapshot of the thread's state for info queries.
    ///
    /// `wait_exception_channel_type` is 0 when no exception is pending. The CPU
    /// affinity mask is always reported as all zeros.
    pub fn get_thread_info(&self) -> ThreadInfo {
        let inner = self.inner.lock();
        ThreadInfo {
            state: inner.state() as u32,
            wait_exception_channel_type: inner
                .exception
                .as_ref()
                .map_or(0, |exception| exception.current_channel_type() as u32),
            cpu_affinity_mask: [0u64; 8],
        }
    }

    /// Returns the report of the exception the thread is currently blocked on.
    ///
    /// # Errors
    ///
    /// Returns `BAD_STATE` if the thread is not in `BlockedException` or has no
    /// pending exception.
    pub fn get_thread_exception_info(&self) -> ZxResult<ExceptionReport> {
        let inner = self.inner.lock();
        if inner.state() != ThreadState::BlockedException {
            return Err(ZxError::BAD_STATE);
        }
        let report = inner.exception.as_ref().ok_or(ZxError::BAD_STATE)?.report();
        Ok(report)
    }

    /// Returns the effective state of the thread, which accounts for suspension.
    pub fn state(&self) -> ThreadState {
        self.inner.lock().state()
    }

    /// Adds `time` to the thread's accumulated time.
    pub fn time_add(&self, time: u128) {
        self.inner.lock().time += time;
    }

    /// Returns the accumulated time, saturated to `u64::MAX`.
    pub fn get_time(&self) -> u64 {
        let time = self.inner.lock().time;
        // Clamp rather than silently dropping the high bits of the u128.
        time.min(u128::from(u64::MAX)) as u64
    }

    /// Marks this thread as the first thread of its process.
    pub(super) fn set_first_thread(&self) {
        self.inner.lock().first_thread = true;
    }

    /// Returns whether this is the first thread of its process.
    pub fn is_first_thread(&self) -> bool {
        self.inner.lock().first_thread
    }

    /// Returns a copy of the thread's flags.
    pub fn flags(&self) -> ThreadFlag {
        self.inner.lock().flags
    }

    /// Modifies the thread's flags in place with `f`, while holding the inner lock.
    pub fn update_flags(&self, f: impl FnOnce(&mut ThreadFlag)) {
        f(&mut self.inner.lock().flags)
    }

    /// Sets the saved `fsbase` register.
    ///
    /// # Errors
    ///
    /// Returns `BAD_STATE` if the context is taken.
    #[cfg(target_arch = "x86_64")]
    pub fn set_fsbase(&self, fsbase: usize) -> ZxResult {
        let mut inner = self.inner.lock();
        let context = inner.context.as_mut().ok_or(ZxError::BAD_STATE)?;
        context.general.fsbase = fsbase;
        Ok(())
    }

    /// Sets the saved `gsbase` register.
    ///
    /// # Errors
    ///
    /// Returns `BAD_STATE` if the context is taken.
    #[cfg(target_arch = "x86_64")]
    pub fn set_gsbase(&self, gsbase: usize) -> ZxResult {
        let mut inner = self.inner.lock();
        let context = inner.context.as_mut().ok_or(ZxError::BAD_STATE)?;
        context.general.gsbase = gsbase;
        Ok(())
    }
}
impl Task for Thread {
    /// Kills the thread.
    ///
    /// The thread is marked as killed and moved to `Dying`, and any in-progress
    /// blocking operation is aborted.
    fn kill(&self) {
        self.stop(true);
    }

    /// Increments the suspend count and refreshes the thread signals from the
    /// resulting effective state.
    fn suspend(&self) {
        let mut guard = self.inner.lock();
        guard.suspend_count += 1;
        // Re-applying the unchanged raw state recomputes the effective state.
        let raw_state = guard.state;
        guard.change_state(raw_state, &self.base);
    }

    /// Decrements the suspend count.
    ///
    /// When the count reaches zero, the thread signals are refreshed and a pending
    /// `wait_for_run` is woken.
    ///
    /// # Panics
    ///
    /// Panics if the suspend count is already zero.
    fn resume(&self) {
        let mut guard = self.inner.lock();
        assert_ne!(guard.suspend_count, 0);
        guard.suspend_count -= 1;
        if guard.suspend_count > 0 {
            return;
        }
        let raw_state = guard.state;
        guard.change_state(raw_state, &self.base);
        if let Some(pending) = guard.waker.take() {
            pending.wake();
        }
    }

    /// Returns the thread-level exception channel.
    fn exceptionate(&self) -> Arc<Exceptionate> {
        Arc::clone(&self.exceptionate)
    }

    /// Threads have no debug exceptionate; this always panics.
    fn debug_exceptionate(&self) -> Arc<Exceptionate> {
        panic!("thread do not have debug exceptionate");
    }
}
/// Handle given to a thread's executor future (see [`ThreadFn`]).
///
/// Dropping it finalizes the thread.
pub struct CurrentThread(pub(super) Arc<Thread>);

impl Deref for CurrentThread {
    type Target = Arc<Thread>;

    /// Exposes the wrapped `Arc<Thread>`.
    fn deref(&self) -> &Arc<Thread> {
        let CurrentThread(thread) = self;
        thread
    }
}
impl Drop for CurrentThread {
fn drop(&mut self) {
let mut inner = self.inner.lock();
self.exceptionate.shutdown();
inner.change_state(ThreadState::Dead, &self.base);
self.proc().remove_thread(self.base.id);
}
}
impl CurrentThread {
    /// Voluntarily exits the current thread.
    ///
    /// The thread moves to [`ThreadState::Dying`] without being marked as killed.
    /// A pending `wait_for_run` is woken, and an in-progress `blocking_run` is
    /// aborted with `ZxError::STOP`.
    pub fn exit(&self) {
        self.stop(false);
    }

    /// Returns a future that resolves to the thread's user context once the thread
    /// is not effectively suspended.
    ///
    /// The context is moved out of the thread while user code runs. Give it back
    /// with [`CurrentThread::end_running`].
    ///
    /// # Panics
    ///
    /// The returned future panics if the context has already been taken, i.e. if
    /// it completes twice without an `end_running` in between.
    pub fn wait_for_run(&self) -> impl Future<Output = Box<UserContext>> {
        #[must_use = "wait_for_run does nothing unless polled/`await`-ed"]
        struct RunnableChecker {
            thread: Arc<Thread>,
        }
        impl Future for RunnableChecker {
            type Output = Box<UserContext>;
            fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
                let mut inner = self.thread.inner.lock();
                if inner.state() != ThreadState::Suspended {
                    let context = inner
                        .context
                        .take()
                        .expect("user context already taken: missing end_running?");
                    Poll::Ready(context)
                } else {
                    // Polled again once `resume` (or `stop`) wakes this waker.
                    inner.waker = Some(cx.waker().clone());
                    Poll::Pending
                }
            }
        }
        RunnableChecker {
            thread: self.0.clone(),
        }
    }

    /// Gives back the user context taken by `wait_for_run`, then re-applies the
    /// raw state.
    ///
    /// Re-applying the state makes the signals reflect any suspension requested
    /// while user code was running.
    pub fn end_running(&self, context: Box<UserContext>) {
        let mut inner = self.inner.lock();
        inner.context = Some(context);
        let state = inner.state;
        inner.change_state(state, &self.base);
    }

    /// Runs `f` on the saved user context while holding the thread's inner lock.
    ///
    /// # Panics
    ///
    /// Panics if the context is currently taken by `wait_for_run`.
    pub fn with_context<T, F>(&self, f: F) -> T
    where
        F: FnOnce(&mut UserContext) -> T,
    {
        let mut inner = self.inner.lock();
        let context = inner
            .context
            .as_mut()
            .expect("user context is taken while the thread is running");
        f(&mut **context)
    }

    /// Runs `future` while the thread is reported as blocked in `state`.
    ///
    /// `select_biased!` polls the alternatives in this order:
    /// 1. `future` itself;
    /// 2. the thread being stopped (`ZxError::STOP`);
    /// 3. the absolute `deadline` passing, as per `kernel_hal::sleep_until`
    ///    (`ZxError::TIMED_OUT`);
    /// 4. `cancel_token`, if given, resolving (`ZxError::CANCELED`).
    ///
    /// Afterwards the previous raw state is restored, unless the thread is now
    /// dying.
    ///
    /// # Errors
    ///
    /// Returns `ZxError::STOP` immediately if the thread is already dying. Otherwise
    /// it returns one of the errors above, or whatever the future's output converts
    /// to via [`IntoResult`].
    ///
    /// # Panics
    ///
    /// Panics if the raw state was changed to something other than `state` while
    /// blocked, without the thread becoming `Dying`.
    pub async fn blocking_run<F, T, FT>(
        &self,
        future: F,
        state: ThreadState,
        deadline: Duration,
        cancel_token: Option<Receiver<()>>,
    ) -> ZxResult<T>
    where
        F: Future<Output = FT> + Unpin,
        FT: IntoResult<T>,
    {
        let (old_state, killed) = {
            let mut inner = self.inner.lock();
            if inner.state() == ThreadState::Dying {
                return Err(ZxError::STOP);
            }
            let (sender, receiver) = channel();
            inner.killer = Some(sender);
            let old_state = inner.state;
            inner.change_state(state, &self.base);
            (old_state, receiver)
        };
        let ret = if let Some(cancel_token) = cancel_token {
            select_biased! {
                ret = future.fuse() => ret.into_result(),
                _ = killed.fuse() => Err(ZxError::STOP),
                _ = sleep_until(deadline).fuse() => Err(ZxError::TIMED_OUT),
                _ = cancel_token.fuse() => Err(ZxError::CANCELED),
            }
        } else {
            select_biased! {
                ret = future.fuse() => ret.into_result(),
                _ = killed.fuse() => Err(ZxError::STOP),
                _ = sleep_until(deadline).fuse() => Err(ZxError::TIMED_OUT),
            }
        };
        let mut inner = self.inner.lock();
        inner.killer = None;
        if inner.state() == ThreadState::Dying {
            return ret;
        }
        assert_eq!(inner.state, state);
        inner.change_state(old_state, &self.base);
        ret
    }

    /// Raises an exception of `type_` on the current thread and waits for it to be
    /// handled.
    ///
    /// - For non-synthetic exceptions, the saved user context is logged and attached
    ///   to the exception.
    /// - `ThreadExiting` is only delivered to the process's debug exceptionate.
    /// - Every other type is handled with `Exception::handle`, with the thread
    ///   reported as [`ThreadState::BlockedException`] and no deadline.
    ///
    /// Handling errors are ignored. The pending exception is cleared on return.
    pub async fn handle_exception(&self, type_: ExceptionType) {
        let exception = {
            let mut inner = self.inner.lock();
            let cx = if !type_.is_synth() {
                inner.context.as_ref().map(|cx| cx.as_ref())
            } else {
                None
            };
            if !type_.is_synth() {
                error!(
                    "User mode exception: {:?} {:#x?}",
                    type_,
                    cx.expect("Architectural exception should has context")
                );
            }
            let exception = Exception::new(&self.0, type_, cx);
            inner.exception = Some(exception.clone());
            exception
        };
        if type_ == ExceptionType::ThreadExiting {
            let handled = self
                .0
                .proc()
                .debug_exceptionate()
                .send_exception(&exception);
            if let Ok(future) = handled {
                self.dying_run(future).await.ok();
            }
        } else {
            let future = exception.handle();
            pin_mut!(future);
            self.blocking_run(
                future,
                ThreadState::BlockedException,
                Duration::from_nanos(u64::max_value()),
                None,
            )
            .await
            .ok();
        }
        self.inner.lock().exception = None;
    }

    /// Runs `future` while the thread is dying. Only an explicit kill can abort it.
    ///
    /// # Errors
    ///
    /// Returns `ZxError::STOP` if the thread has already been killed, or if it is
    /// killed while the future is pending.
    async fn dying_run<F, T, FT>(&self, future: F) -> ZxResult<T>
    where
        F: Future<Output = FT> + Unpin,
        FT: IntoResult<T>,
    {
        let killed = {
            let mut inner = self.inner.lock();
            if inner.killed {
                return Err(ZxError::STOP);
            }
            let (sender, receiver) = channel::<()>();
            inner.killer = Some(sender);
            receiver
        };
        let ret = select_biased! {
            ret = future.fuse() => ret.into_result(),
            _ = killed.fuse() => Err(ZxError::STOP),
        };
        // Drop the now-orphaned kill sender, as `blocking_run` does.
        self.inner.lock().killer = None;
        ret
    }
}
/// Converts a future's output into a [`ZxResult`].
///
/// `CurrentThread::blocking_run` uses it so that futures may produce either a
/// plain value or a `ZxResult`.
pub trait IntoResult<T> {
    /// Performs the conversion.
    fn into_result(self) -> ZxResult<T>;
}

/// A plain value always converts to success.
impl<T> IntoResult<T> for T {
    #[inline]
    fn into_result(self) -> ZxResult<T> {
        core::result::Result::Ok(self)
    }
}

/// A `ZxResult` passes through unchanged.
impl<T> IntoResult<T> for ZxResult<T> {
    #[inline]
    fn into_result(self) -> ZxResult<T> {
        let passthrough: ZxResult<T> = self;
        passthrough
    }
}
/// Lifecycle state of a thread.
///
/// The discriminants are exposed to callers through `ThreadInfo::state`
/// (`state() as u32`). Blocked variants keep `3` (`Blocked`) in the low byte and
/// encode the blocking reason in the high byte.
///
/// NOTE(review): the values appear to mirror Zircon's `ZX_THREAD_STATE_*`
/// constants — confirm before changing any of them.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ThreadState {
    New = 0,
    Running = 1,
    Suspended = 2,
    Blocked = 3,
    Dying = 4,
    Dead = 5,
    BlockedException = 0x103,
    BlockedSleeping = 0x203,
    BlockedFutex = 0x303,
    BlockedPort = 0x403,
    BlockedChannel = 0x503,
    BlockedWaitOne = 0x603,
    BlockedWaitMany = 0x703,
    BlockedInterrupt = 0x803,
    BlockedPager = 0x903,
}
impl Default for ThreadState {
fn default() -> Self {
ThreadState::New
}
}
/// Thread information returned by `Thread::get_thread_info`.
///
/// It is `repr(C)`, so the field order and types define the layout.
/// NOTE(review): presumably this mirrors Zircon's `zx_info_thread_t` — confirm
/// before changing.
#[repr(C)]
pub struct ThreadInfo {
    /// Effective `ThreadState` discriminant.
    state: u32,
    /// Channel type of the pending exception, or 0 if none.
    wait_exception_channel_type: u32,
    /// CPU affinity mask; currently always all zeros.
    cpu_affinity_mask: [u64; 8],
}
// Unit tests for thread creation, start-up, blocking, state access and timing.
#[cfg(test)]
mod tests {
    use super::job::Job;
    use super::*;
    use kernel_hal::timer_now;

    // A new thread has no flags and reports its process as the related object.
    // It is also registered as a child of that process.
    #[test]
    fn create() {
        let root_job = Job::root();
        let proc = Process::create(&root_job, "proc").expect("failed to create process");
        let thread = Thread::create(&proc, "thread").expect("failed to create thread");
        assert_eq!(thread.flags(), ThreadFlag::empty());
        assert_eq!(thread.related_koid(), proc.id());
        let child = proc.get_child(thread.id()).unwrap().downcast_arc().unwrap();
        assert!(Arc::ptr_eq(&child, &thread));
    }

    // Starting the process runs its first thread with the given entry, stack and
    // arguments. Later starts are rejected with BAD_STATE. Once the executor
    // finishes, the thread is Dead and only this test still holds a reference.
    #[async_std::test]
    async fn start() {
        kernel_hal_unix::init();
        let root_job = Job::root();
        let proc = Process::create(&root_job, "proc").expect("failed to create process");
        let thread = Thread::create(&proc, "thread").expect("failed to create thread");
        let thread1 = Thread::create(&proc, "thread1").expect("failed to create thread");
        // Executor body: checks the x86_64 registers set up at start.
        // NOTE(review): only meaningful on x86_64 — these fields do not exist elsewhere.
        async fn new_thread(thread: CurrentThread) {
            let cx = thread.wait_for_run().await;
            assert_eq!(cx.general.rip, 1);
            assert_eq!(cx.general.rsp, 4);
            assert_eq!(cx.general.rdi, 3);
            assert_eq!(cx.general.rsi, 2);
            async_std::task::sleep(Duration::from_millis(10)).await;
            thread.end_running(cx);
        }
        let handle = Handle::new(proc.clone(), Rights::DEFAULT_PROCESS);
        proc.start(&thread, 1, 4, Some(handle.clone()), 2, |thread| {
            Box::pin(new_thread(thread))
        })
        .expect("failed to start thread");
        let info = proc.get_info();
        assert!(info.started && !info.has_exited && info.return_code == 0);
        assert_eq!(proc.status(), Status::Running);
        assert_eq!(thread.state(), ThreadState::Running);
        assert_eq!(
            proc.start(&thread, 1, 4, Some(handle.clone()), 2, |thread| Box::pin(
                new_thread(thread)
            )),
            Err(ZxError::BAD_STATE)
        );
        assert_eq!(
            proc.start(&thread1, 1, 4, Some(handle.clone()), 2, |thread| Box::pin(
                new_thread(thread)
            )),
            Err(ZxError::BAD_STATE)
        );
        // Let the spawned executor finish. Dropping its CurrentThread marks the
        // thread Dead and removes it from the process.
        async_std::task::sleep(core::time::Duration::from_millis(100)).await;
        assert_eq!(Arc::strong_count(&thread), 1);
        assert_eq!(thread.state(), ThreadState::Dead);
    }

    // blocking_run times out at the deadline. It is canceled when the handle backing
    // the cancel token is removed (presumably dropping the token's sender).
    #[async_std::test]
    async fn blocking_run() {
        let root_job = Job::root();
        let proc = Process::create(&root_job, "proc").expect("failed to create process");
        let thread = Thread::create(&proc, "thread").expect("failed to create thread");
        let thread = CurrentThread(thread);
        let handle = Handle::new(proc.clone(), Rights::DEFAULT_PROCESS);
        let handle_value = proc.add_handle(handle);
        let object = proc
            .get_dyn_object_with_rights(handle_value, Rights::WAIT)
            .unwrap();
        let cancel_token = proc.get_cancel_token(handle_value).unwrap();
        let future = object.wait_signal(Signal::READABLE);
        let deadline = timer_now() + Duration::from_millis(20);
        let result = thread
            .blocking_run(
                future,
                ThreadState::BlockedWaitOne,
                deadline.into(),
                Some(cancel_token),
            )
            .await;
        assert_eq!(result.err(), Some(ZxError::TIMED_OUT));
        let cancel_token = proc.get_cancel_token(handle_value).unwrap();
        let future = object.wait_signal(Signal::READABLE);
        let deadline = timer_now() + Duration::from_millis(20);
        // Remove the handle before the deadline to trigger cancellation.
        async_std::task::spawn({
            let proc = proc.clone();
            async move {
                async_std::task::sleep(Duration::from_millis(10)).await;
                proc.remove_handle(handle_value).unwrap();
            }
        });
        let result = thread
            .blocking_run(
                future,
                ThreadState::BlockedWaitOne,
                deadline.into(),
                Some(cancel_token),
            )
            .await;
        assert_eq!(result.err(), Some(ZxError::CANCELED));
    }

    // A new thread reports its state, has no pending exception, and has no
    // exception report.
    #[test]
    fn info() {
        let root_job = Job::root();
        let proc = Process::create(&root_job, "proc").expect("failed to create process");
        let thread = Thread::create(&proc, "thread").expect("failed to create thread");
        let info = thread.get_thread_info();
        assert!(info.state == thread.state() as u32 && info.wait_exception_channel_type == 0);
        assert_eq!(
            thread.get_thread_exception_info().err(),
            Some(ZxError::BAD_STATE)
        );
    }

    // Register access requires a suspended thread and a buffer of the right size.
    #[test]
    fn read_write_state() {
        let root_job = Job::root();
        let proc = Process::create(&root_job, "proc").expect("failed to create process");
        let thread = Thread::create(&proc, "thread").expect("failed to create thread");
        const SIZE: usize = core::mem::size_of::<GeneralRegs>();
        let mut buf = [0; 10];
        assert_eq!(
            thread.read_state(ThreadStateKind::General, &mut buf).err(),
            Some(ZxError::BAD_STATE)
        );
        assert_eq!(
            thread.write_state(ThreadStateKind::General, &buf).err(),
            Some(ZxError::BAD_STATE)
        );
        thread.suspend();
        assert_eq!(
            thread.read_state(ThreadStateKind::General, &mut buf).err(),
            Some(ZxError::BUFFER_TOO_SMALL)
        );
        assert_eq!(
            thread.write_state(ThreadStateKind::General, &buf).err(),
            Some(ZxError::BUFFER_TOO_SMALL)
        );
        let mut buf = [0; SIZE];
        assert!(thread
            .read_state(ThreadStateKind::General, &mut buf)
            .is_ok());
        assert!(thread.write_state(ThreadStateKind::General, &buf).is_ok());
    }

    // The extension data accessor is callable on a thread created without ext.
    #[test]
    fn ext() {
        let root_job = Job::root();
        let proc = Process::create(&root_job, "proc").expect("failed to create process");
        let thread = Thread::create(&proc, "thread").expect("failed to create thread");
        let _ext = thread.ext();
    }

    // A doubly-suspended thread's wait_for_run only completes after both resumes,
    // which come 10ms apart. When the executor returns, THREAD_TERMINATED is signaled.
    #[async_std::test]
    async fn wait_for_run() {
        let root_job = Job::root();
        let proc = Process::create(&root_job, "proc").expect("failed to create process");
        let thread = Thread::create(&proc, "thread").expect("failed to create thread");
        assert_eq!(thread.state(), ThreadState::New);
        thread
            .start(0, 0, 0, 0, |thread| Box::pin(new_thread(thread)))
            .unwrap();
        async fn new_thread(thread: CurrentThread) {
            assert_eq!(thread.state(), ThreadState::Running);
            let context = thread.wait_for_run().await;
            thread.end_running(context);
            thread.suspend();
            thread.suspend();
            assert_eq!(thread.state(), ThreadState::Suspended);
            async_std::task::spawn({
                let thread = (*thread).clone();
                async move {
                    async_std::task::sleep(Duration::from_millis(10)).await;
                    thread.resume();
                    async_std::task::sleep(Duration::from_millis(10)).await;
                    thread.resume();
                }
            });
            let time = timer_now();
            let _context = thread.wait_for_run().await;
            assert!(timer_now() - time >= Duration::from_millis(20));
        }
        let thread: Arc<dyn KernelObject> = thread;
        thread.wait_signal(Signal::THREAD_TERMINATED).await;
    }

    // time_add accumulates into the value reported by get_time.
    #[test]
    fn time() {
        let root_job = Job::root();
        let proc = Process::create(&root_job, "proc").expect("failed to create process");
        let thread = Thread::create(&proc, "thread").expect("failed to create thread");
        assert_eq!(thread.get_time(), 0);
        thread.time_add(10);
        assert_eq!(thread.get_time(), 10);
    }
}