use {
crate::signal::*,
alloc::{boxed::Box, string::String, sync::Arc, vec::Vec},
core::{
fmt::Debug,
future::Future,
pin::Pin,
sync::atomic::*,
task::{Context, Poll},
},
downcast_rs::{impl_downcast, DowncastSync},
spin::Mutex,
};
pub use {super::*, handle::*, rights::*, signal::*};
mod handle;
mod rights;
mod signal;
/// Common interface implemented by every kernel object.
///
/// The `DowncastSync` supertrait (from `downcast_rs`) together with the `impl_downcast!`
/// invocation below enables downcasting a `dyn KernelObject` back to its concrete type.
/// Most implementors obtain the required methods from the `impl_kobject!` macro, which
/// forwards them to an embedded `KObjectBase`.
pub trait KernelObject: DowncastSync + Debug {
    /// Returns the object's kernel object id (koid).
    fn id(&self) -> KoID;
    /// Returns the name of the object's type (for `impl_kobject!` types, the
    /// stringified type name).
    fn type_name(&self) -> &str;
    /// Returns a copy of the object's current name.
    fn name(&self) -> alloc::string::String;
    /// Replaces the object's name.
    fn set_name(&self, name: &str);
    /// Returns the set of currently asserted signals.
    fn signal(&self) -> Signal;
    /// Asserts the bits in `signal`.
    fn signal_set(&self, signal: Signal);
    /// Clears the bits in `signal`.
    fn signal_clear(&self, signal: Signal);
    /// Clears the bits in `clear`, then asserts the bits in `set`
    /// (the order used by `KObjectBase::signal_change`).
    fn signal_change(&self, clear: Signal, set: Signal);
    /// Registers a callback observing signal changes. In `KObjectBase` the callback is
    /// invoked immediately with the current signals and then on every change, and it is
    /// dropped as soon as it returns `true`.
    fn add_signal_callback(&self, callback: SignalHandler);
    /// Looks up a child object by koid.
    ///
    /// The default implementation always fails with `ZxError::WRONG_TYPE`.
    fn get_child(&self, _id: KoID) -> ZxResult<Arc<dyn KernelObject>> {
        Err(ZxError::WRONG_TYPE)
    }
    /// Returns the object's peer.
    /// NOTE(review): presumably the other end of a paired object (channel, socket, ...) —
    /// confirm against overriding implementations.
    ///
    /// The default implementation fails with `ZxError::NOT_SUPPORTED`.
    fn peer(&self) -> ZxResult<Arc<dyn KernelObject>> {
        Err(ZxError::NOT_SUPPORTED)
    }
    /// Returns the koid of a related object. The default returns `0`
    /// (presumably meaning "no related object" — confirm).
    fn related_koid(&self) -> KoID {
        0
    }
    /// Returns the signal mask this object allows; defaults to `Signal::USER_ALL`.
    /// NOTE(review): presumably the signals user code may set/clear — confirm with callers.
    fn allowed_signals(&self) -> Signal {
        Signal::USER_ALL
    }
}
// Generate the downcast support for `dyn KernelObject` (thread-safe `sync` variant).
impl_downcast!(sync KernelObject);
/// State shared by kernel objects: an immutable koid plus lock-protected name and signals.
///
/// Types using `impl_kobject!` embed this as a field named `base`.
pub struct KObjectBase {
    /// Kernel object id, allocated once at construction by `KObjectBase::new_koid`.
    pub id: KoID,
    /// Mutable state, guarded by a (non-reentrant) `spin::Mutex`.
    inner: Mutex<KObjectBaseInner>,
}
/// The lock-protected part of `KObjectBase`.
#[derive(Default)]
struct KObjectBaseInner {
    /// Object name; empty unless set via `with`/`with_name`/`set_name`.
    name: String,
    /// Currently asserted signal bits.
    signal: Signal,
    /// Pending signal observers; each is removed once it returns `true`.
    signal_callbacks: Vec<SignalHandler>,
}
impl Default for KObjectBase {
fn default() -> Self {
KObjectBase {
id: Self::new_koid(),
inner: Default::default(),
}
}
}
impl KObjectBase {
    /// Creates a base with a fresh koid, an empty name and the default signal set.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a base with a fresh koid, an empty name and `signal` as initial signals.
    pub fn with_signal(signal: Signal) -> Self {
        Self::with("", signal)
    }

    /// Creates a base with a fresh koid, the given `name` and the default signal set.
    pub fn with_name(name: &str) -> Self {
        Self::with(name, Default::default())
    }

    /// Creates a base with a fresh koid, the given `name` and `signal` as initial signals.
    /// The callback list starts out empty.
    pub fn with(name: &str, signal: Signal) -> Self {
        let state = KObjectBaseInner {
            name: name.into(),
            signal,
            ..Default::default()
        };
        Self {
            id: Self::new_koid(),
            inner: Mutex::new(state),
        }
    }

    /// Allocates the next koid from a global (static) atomic counter.
    ///
    /// Ids start at 1024 and grow by one per call; `fetch_add` gives concurrent
    /// callers distinct ids.
    fn new_koid() -> KoID {
        static NEXT_KOID: AtomicU64 = AtomicU64::new(1024);
        NEXT_KOID.fetch_add(1, Ordering::SeqCst)
    }

    /// Returns a copy of the current name.
    pub fn name(&self) -> String {
        let state = self.inner.lock();
        state.name.clone()
    }

    /// Replaces the current name.
    pub fn set_name(&self, name: &str) {
        let mut state = self.inner.lock();
        state.name = name.into();
    }

    /// Returns the currently asserted signals.
    pub fn signal(&self) -> Signal {
        let state = self.inner.lock();
        state.signal
    }

    /// Clears the bits in `clear`, then asserts the bits in `set`, so a bit present in
    /// both ends up asserted.
    ///
    /// Callbacks are invoked only when the resulting set actually differs from the
    /// previous one. Each callback receives the new set and is removed if it returns
    /// `true`. Callbacks run while the internal `spin::Mutex` is held, so a callback must
    /// not call back into this object (the lock is not reentrant).
    pub fn signal_change(&self, clear: Signal, set: Signal) {
        let mut state = self.inner.lock();
        let previous = state.signal;
        let mut updated = previous;
        updated.remove(clear);
        updated.insert(set);
        if updated == previous {
            return;
        }
        state.signal = updated;
        state
            .signal_callbacks
            .retain(|observer| !observer(updated));
    }

    /// Asserts the bits in `signal`.
    pub fn signal_set(&self, signal: Signal) {
        self.signal_change(Signal::empty(), signal);
    }

    /// Clears the bits in `signal`.
    pub fn signal_clear(&self, signal: Signal) {
        self.signal_change(signal, Signal::empty());
    }

    /// Registers `callback` as a signal observer.
    ///
    /// The callback is first run immediately (under the lock) with the current signals.
    /// If it returns `true` it is considered satisfied and dropped. Otherwise it is stored
    /// and re-run on every subsequent signal change until it returns `true`.
    pub fn add_signal_callback(&self, callback: SignalHandler) {
        let mut state = self.inner.lock();
        let satisfied = callback(state.signal);
        if !satisfied {
            state.signal_callbacks.push(callback);
        }
    }
}
impl dyn KernelObject {
pub fn wait_signal(self: &Arc<Self>, signal: Signal) -> impl Future<Output = Signal> {
#[must_use = "wait_signal does nothing unless polled/`await`-ed"]
struct SignalFuture {
object: Arc<dyn KernelObject>,
signal: Signal,
first: bool,
}
impl Future for SignalFuture {
type Output = Signal;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let current_signal = self.object.signal();
if !(current_signal & self.signal).is_empty() {
return Poll::Ready(current_signal);
}
if self.first {
self.object.add_signal_callback(Box::new({
let signal = self.signal;
let waker = cx.waker().clone();
move |s| {
if (s & signal).is_empty() {
return false;
}
waker.wake_by_ref();
true
}
}));
self.first = false;
}
Poll::Pending
}
}
SignalFuture {
object: self.clone(),
signal,
first: true,
}
}
#[allow(unsafe_code)]
pub fn send_signal_to_port_async(self: &Arc<Self>, signal: Signal, port: &Arc<Port>, key: u64) {
let current_signal = self.signal();
if !(current_signal & signal).is_empty() {
port.push(PortPacketRepr {
key,
status: ZxError::OK,
data: PayloadRepr::Signal(PacketSignal {
trigger: signal,
observed: current_signal,
count: 1,
timestamp: 0,
_reserved1: 0,
}),
});
return;
}
self.add_signal_callback(Box::new({
let port = port.clone();
move |s| {
if (s & signal).is_empty() {
return false;
}
port.push(PortPacketRepr {
key,
status: ZxError::OK,
data: PayloadRepr::Signal(PacketSignal {
trigger: signal,
observed: s,
count: 1,
timestamp: 0,
_reserved1: 0,
}),
});
true
}
}));
}
}
pub fn wait_signal_many(
targets: &[(Arc<dyn KernelObject>, Signal)],
) -> impl Future<Output = Vec<Signal>> {
#[must_use = "wait_signal_many does nothing unless polled/`await`-ed"]
struct SignalManyFuture {
targets: Vec<(Arc<dyn KernelObject>, Signal)>,
first: bool,
}
impl SignalManyFuture {
fn happened(&self, current_signals: &[Signal]) -> bool {
self.targets
.iter()
.zip(current_signals)
.any(|(&(_, desired), ¤t)| !(current & desired).is_empty())
}
}
impl Future for SignalManyFuture {
type Output = Vec<Signal>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let current_signals: Vec<_> =
self.targets.iter().map(|(obj, _)| obj.signal()).collect();
if self.happened(¤t_signals) {
return Poll::Ready(current_signals);
}
if self.first {
for (object, signal) in self.targets.iter() {
object.add_signal_callback(Box::new({
let signal = *signal;
let waker = cx.waker().clone();
move |s| {
if (s & signal).is_empty() {
return false;
}
waker.wake_by_ref();
true
}
}));
}
self.first = false;
}
Poll::Pending
}
}
SignalManyFuture {
targets: Vec::from(targets),
first: true,
}
}
/// Implements `KernelObject` and `Debug` for a type that stores a `KObjectBase` in a
/// field named `base`.
///
/// Every required trait method is forwarded to `self.base`, except `type_name`, which
/// returns the stringified type name. Any extra tokens after the type name (for example
/// overrides of `get_child`, `peer`, `related_koid` or `allowed_signals`) are pasted
/// verbatim into the generated `impl KernelObject` block. The generated `Debug` output
/// has the shape `TypeName(koid, "name")`.
///
/// The expansion names `KernelObject`, `KoID`, `Signal` and `SignalHandler` unqualified,
/// so they must be in scope at the invocation site.
#[macro_export]
macro_rules! impl_kobject {
    ($class:ident $( $fn:tt )*) => {
        impl KernelObject for $class {
            // Identity and naming.
            fn id(&self) -> KoID {
                self.base.id
            }
            fn type_name(&self) -> &str {
                stringify!($class)
            }
            fn name(&self) -> alloc::string::String {
                self.base.name()
            }
            fn set_name(&self, name: &str) {
                self.base.set_name(name)
            }
            // Signal state, all delegated to the embedded base.
            fn signal(&self) -> Signal {
                self.base.signal()
            }
            fn signal_set(&self, signal: Signal) {
                self.base.signal_set(signal);
            }
            fn signal_clear(&self, signal: Signal) {
                self.base.signal_clear(signal);
            }
            fn signal_change(&self, clear: Signal, set: Signal) {
                self.base.signal_change(clear, set);
            }
            fn add_signal_callback(&self, callback: SignalHandler) {
                self.base.add_signal_callback(callback);
            }
            // Caller-supplied items (overrides of the provided methods).
            $( $fn )*
        }

        impl core::fmt::Debug for $class {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let mut tuple = f.debug_tuple(stringify!($class));
                tuple.field(&self.id());
                tuple.field(&self.name());
                tuple.finish()
            }
        }
    };
}
/// Defines a `CountHelper` type in the invoking module that counts creations and
/// destructions of `$class`.
///
/// `CountHelper::new()` adds 1 to the counter named `"<class>.create"`, and dropping a
/// `CountHelper` adds 1 to `"<class>.destroy"`. `kcounter!` must be in scope at the
/// invocation site. NOTE(review): it is assumed to declare a static counter exposing
/// `add` — confirm against its definition. Presumably the helper is stored as a field of
/// `$class` so that its lifetime tracks the object's — confirm at call sites.
#[macro_export]
macro_rules! define_count_helper {
    ($class:ident) => {
        // Zero-sized token whose construction and drop are counted.
        struct CountHelper(());
        impl CountHelper {
            fn new() -> Self {
                // Counter name is "<class>.create", built at compile time.
                kcounter!(CREATE_COUNT, concat!(stringify!($class), ".create"));
                CREATE_COUNT.add(1);
                CountHelper(())
            }
        }
        impl Drop for CountHelper {
            fn drop(&mut self) {
                // Counter name is "<class>.destroy", built at compile time.
                kcounter!(DESTROY_COUNT, concat!(stringify!($class), ".destroy"));
                DESTROY_COUNT.add(1);
            }
        }
    };
}
/// Kernel object id ("koid"). `KObjectBase` allocates these from a global counter that
/// starts at 1024.
pub type KoID = u64;
/// Signal observer callback: receives the object's signal set and returns `true` once it
/// is satisfied, at which point `KObjectBase` drops it from its callback list.
pub type SignalHandler = Box<dyn Fn(Signal) -> bool + Send>;
/// A kernel object with no state beyond its `KObjectBase` (koid, name, signals).
///
/// It implements `KernelObject` through `impl_kobject!` with no overrides, so every
/// provided trait method keeps its default behavior.
pub struct DummyObject {
    base: KObjectBase,
}

impl_kobject!(DummyObject);

impl DummyObject {
    /// Allocates a new `DummyObject` with a fresh koid, an empty name and the default
    /// signal set, returned behind an `Arc`.
    pub fn new() -> Arc<Self> {
        let base = KObjectBase::new();
        Arc::new(Self { base })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_std::sync::Barrier;
    use std::time::Duration;

    /// `wait_signal` resolves only once an awaited bit is asserted, and its output is the
    /// object's full signal set at that time.
    #[async_std::test]
    async fn wait() {
        let object = DummyObject::new();
        let barrier = Arc::new(Barrier::new(2));
        async_std::task::spawn({
            let object = object.clone();
            let barrier = barrier.clone();
            async move {
                // Presumably gives the main task time to start waiting first.
                async_std::task::sleep(Duration::from_millis(20)).await;
                // An unrelated signal pulse precedes the awaited READABLE.
                object.signal_set(Signal::USER_SIGNAL_0);
                object.signal_clear(Signal::USER_SIGNAL_0);
                object.signal_set(Signal::READABLE);
                // Rendezvous with the main task before asserting WRITABLE.
                barrier.wait().await;
                object.signal_set(Signal::WRITABLE);
            }
        });
        let object: Arc<dyn KernelObject> = object;
        let signal = object.wait_signal(Signal::READABLE).await;
        assert_eq!(signal, Signal::READABLE);
        barrier.wait().await;
        let signal = object.wait_signal(Signal::WRITABLE).await;
        // READABLE is still asserted, so the observed set contains both bits.
        assert_eq!(signal, Signal::READABLE | Signal::WRITABLE);
    }

    /// `wait_signal_many` resolves when any target matches and reports every target's
    /// current signals in input order.
    #[async_std::test]
    async fn wait_many() {
        let objs = [DummyObject::new(), DummyObject::new()];
        let barrier = Arc::new(Barrier::new(2));
        async_std::task::spawn({
            let objs = objs.clone();
            let barrier = barrier.clone();
            async move {
                // Presumably gives the main task time to start waiting first.
                async_std::task::sleep(Duration::from_millis(20)).await;
                objs[0].signal_set(Signal::READABLE);
                // Rendezvous with the main task before signalling the second object.
                barrier.wait().await;
                objs[1].signal_set(Signal::WRITABLE);
            }
        });
        let obj0: Arc<dyn KernelObject> = objs[0].clone();
        let obj1: Arc<dyn KernelObject> = objs[1].clone();
        let signals = wait_signal_many(&[
            (obj0.clone(), Signal::READABLE),
            (obj1.clone(), Signal::READABLE),
        ])
        .await;
        assert_eq!(signals, [Signal::READABLE, Signal::empty()]);
        barrier.wait().await;
        let signals = wait_signal_many(&[
            (obj0.clone(), Signal::WRITABLE),
            (obj1.clone(), Signal::WRITABLE),
        ])
        .await;
        // obj0 keeps READABLE from the first phase; obj1 now has WRITABLE.
        assert_eq!(signals, [Signal::READABLE, Signal::WRITABLE]);
    }

    /// Exercises the forwarded `KernelObject` methods, the provided defaults, and the
    /// `Debug` output generated by `impl_kobject!`.
    #[test]
    fn test_trait_with_dummy() {
        let dummy = DummyObject::new();
        assert_eq!(dummy.name(), String::from(""));
        dummy.set_name("test");
        assert_eq!(dummy.name(), String::from("test"));
        dummy.signal_set(Signal::WRITABLE);
        assert_eq!(dummy.signal(), Signal::WRITABLE);
        // Clearing WRITABLE and setting READABLE in one step.
        dummy.signal_change(Signal::WRITABLE, Signal::READABLE);
        assert_eq!(dummy.signal(), Signal::READABLE);
        // Provided (default) trait methods.
        assert_eq!(dummy.get_child(0).unwrap_err(), ZxError::WRONG_TYPE);
        assert_eq!(dummy.peer().unwrap_err(), ZxError::NOT_SUPPORTED);
        assert_eq!(dummy.related_koid(), 0);
        assert_eq!(dummy.allowed_signals(), Signal::USER_ALL);
        assert_eq!(
            format!("{:?}", dummy),
            format!("DummyObject({}, \"test\")", dummy.id())
        );
    }
}