diff options
Diffstat (limited to 'rust')
64 files changed, 5805 insertions, 432 deletions
diff --git a/rust/Makefile b/rust/Makefile index f70d5e244fee..b5e0a73b78f3 100644 --- a/rust/Makefile +++ b/rust/Makefile @@ -8,16 +8,16 @@ always-$(CONFIG_RUST) += exports_core_generated.h # Missing prototypes are expected in the helpers since these are exported # for Rust only, thus there is no header nor prototypes. -obj-$(CONFIG_RUST) += helpers.o -CFLAGS_REMOVE_helpers.o = -Wmissing-prototypes -Wmissing-declarations +obj-$(CONFIG_RUST) += helpers/helpers.o +CFLAGS_REMOVE_helpers/helpers.o = -Wmissing-prototypes -Wmissing-declarations always-$(CONFIG_RUST) += libmacros.so no-clean-files += libmacros.so always-$(CONFIG_RUST) += bindings/bindings_generated.rs bindings/bindings_helpers_generated.rs obj-$(CONFIG_RUST) += alloc.o bindings.o kernel.o -always-$(CONFIG_RUST) += exports_alloc_generated.h exports_bindings_generated.h \ - exports_kernel_generated.h +always-$(CONFIG_RUST) += exports_alloc_generated.h exports_helpers_generated.h \ + exports_bindings_generated.h exports_kernel_generated.h always-$(CONFIG_RUST) += uapi/uapi_generated.rs obj-$(CONFIG_RUST) += uapi.o @@ -44,17 +44,10 @@ rustc_sysroot := $(shell MAKEFLAGS= $(RUSTC) $(rust_flags) --print sysroot) rustc_host_target := $(shell $(RUSTC) --version --verbose | grep -F 'host: ' | cut -d' ' -f2) RUST_LIB_SRC ?= $(rustc_sysroot)/lib/rustlib/src/rust/library -ifeq ($(quiet),silent_) -cargo_quiet=-q +ifneq ($(quiet),) rust_test_quiet=-q rustdoc_test_quiet=--test-args -q rustdoc_test_kernel_quiet=>/dev/null -else ifeq ($(quiet),quiet_) -rust_test_quiet=-q -rustdoc_test_quiet=--test-args -q -rustdoc_test_kernel_quiet=>/dev/null -else -cargo_quiet=--verbose endif core-cfgs = \ @@ -70,6 +63,7 @@ quiet_cmd_rustdoc = RUSTDOC $(if $(rustdoc_host),H, ) $< OBJTREE=$(abspath $(objtree)) \ $(RUSTDOC) $(if $(rustdoc_host),$(rust_common_flags),$(rust_flags)) \ $(rustc_target_flags) -L$(objtree)/$(obj) \ + -Zunstable-options --generate-link-to-definition \ --output $(rustdoc_output) \ --crate-name $(subst rustdoc-,,$@) \ $(if $(rustdoc_host),,--sysroot=/dev/null) \ @@ -135,22 +129,21 @@ quiet_cmd_rustc_test_library = RUSTC TL $< @$(objtree)/include/generated/rustc_cfg $(rustc_target_flags) \ --crate-type $(if $(rustc_test_library_proc),proc-macro,rlib) \ --out-dir $(objtree)/$(obj)/test --cfg testlib \ - --sysroot $(objtree)/$(obj)/test/sysroot \ -L$(objtree)/$(obj)/test \ --crate-name $(subst rusttest-,,$(subst rusttestlib-,,$@)) $< -rusttestlib-build_error: $(src)/build_error.rs rusttest-prepare FORCE +rusttestlib-build_error: $(src)/build_error.rs FORCE +$(call if_changed,rustc_test_library) rusttestlib-macros: private rustc_target_flags = --extern proc_macro rusttestlib-macros: private rustc_test_library_proc = yes -rusttestlib-macros: $(src)/macros/lib.rs rusttest-prepare FORCE +rusttestlib-macros: $(src)/macros/lib.rs FORCE +$(call if_changed,rustc_test_library) -rusttestlib-bindings: $(src)/bindings/lib.rs rusttest-prepare FORCE +rusttestlib-bindings: $(src)/bindings/lib.rs FORCE +$(call if_changed,rustc_test_library) -rusttestlib-uapi: $(src)/uapi/lib.rs rusttest-prepare FORCE +rusttestlib-uapi: $(src)/uapi/lib.rs FORCE +$(call if_changed,rustc_test_library) quiet_cmd_rustdoc_test = RUSTDOC T $< @@ -159,7 +152,7 @@ quiet_cmd_rustdoc_test = RUSTDOC T $< $(RUSTDOC) --test $(rust_common_flags) \ @$(objtree)/include/generated/rustc_cfg \ $(rustc_target_flags) $(rustdoc_test_target_flags) \ - --sysroot $(objtree)/$(obj)/test/sysroot $(rustdoc_test_quiet) \ + $(rustdoc_test_quiet) \ -L$(objtree)/$(obj)/test --output $(rustdoc_output) \ --crate-name $(subst rusttest-,,$@) $< @@ -192,7 +185,6 @@ quiet_cmd_rustc_test = RUSTC T $< $(RUSTC) --test $(rust_common_flags) \ @$(objtree)/include/generated/rustc_cfg \ $(rustc_target_flags) --out-dir $(objtree)/$(obj)/test \ - --sysroot $(objtree)/$(obj)/test/sysroot \ -L$(objtree)/$(obj)/test \ --crate-name $(subst rusttest-,,$@) $<; \ $(objtree)/$(obj)/test/$(subst rusttest-,,$@) $(rust_test_quiet) \ @@ -200,60 +192,15 @@ quiet_cmd_rustc_test = RUSTC T $< rusttest: rusttest-macros rusttest-kernel -# This prepares a custom sysroot with our custom `alloc` instead of -# the standard one. -# -# This requires several hacks: -# - Unlike `core` and `alloc`, `std` depends on more than a dozen crates, -# including third-party crates that need to be downloaded, plus custom -# `build.rs` steps. Thus hardcoding things here is not maintainable. -# - `cargo` knows how to build the standard library, but it is an unstable -# feature so far (`-Zbuild-std`). -# - `cargo` only considers the use case of building the standard library -# to use it in a given package. Thus we need to create a dummy package -# and pick the generated libraries from there. -# - The usual ways of modifying the dependency graph in `cargo` do not seem -# to apply for the `-Zbuild-std` steps, thus we have to mislead it -# by modifying the sources in the sysroot. -# - To avoid messing with the user's Rust installation, we create a clone -# of the sysroot. However, `cargo` ignores `RUSTFLAGS` in the `-Zbuild-std` -# steps, thus we use a wrapper binary passed via `RUSTC` to pass the flag. -# -# In the future, we hope to avoid the whole ordeal by either: -# - Making the `test` crate not depend on `std` (either improving upstream -# or having our own custom crate). -# - Making the tests run in kernel space (requires the previous point). -# - Making `std` and friends be more like a "normal" crate, so that -# `-Zbuild-std` and related hacks are not needed. -quiet_cmd_rustsysroot = RUSTSYSROOT - cmd_rustsysroot = \ - rm -rf $(objtree)/$(obj)/test; \ - mkdir -p $(objtree)/$(obj)/test; \ - cp -a $(rustc_sysroot) $(objtree)/$(obj)/test/sysroot; \ - echo '\#!/bin/sh' > $(objtree)/$(obj)/test/rustc_sysroot; \ - echo "$(RUSTC) --sysroot=$(abspath $(objtree)/$(obj)/test/sysroot) \"\$$@\"" \ - >> $(objtree)/$(obj)/test/rustc_sysroot; \ - chmod u+x $(objtree)/$(obj)/test/rustc_sysroot; \ - $(CARGO) -q new $(objtree)/$(obj)/test/dummy; \ - RUSTC=$(objtree)/$(obj)/test/rustc_sysroot $(CARGO) $(cargo_quiet) \ - test -Zbuild-std --target $(rustc_host_target) \ - --manifest-path $(objtree)/$(obj)/test/dummy/Cargo.toml; \ - rm $(objtree)/$(obj)/test/sysroot/lib/rustlib/$(rustc_host_target)/lib/*; \ - cp $(objtree)/$(obj)/test/dummy/target/$(rustc_host_target)/debug/deps/* \ - $(objtree)/$(obj)/test/sysroot/lib/rustlib/$(rustc_host_target)/lib - -rusttest-prepare: FORCE - +$(call if_changed,rustsysroot) - rusttest-macros: private rustc_target_flags = --extern proc_macro rusttest-macros: private rustdoc_test_target_flags = --crate-type proc-macro -rusttest-macros: $(src)/macros/lib.rs rusttest-prepare FORCE +rusttest-macros: $(src)/macros/lib.rs FORCE +$(call if_changed,rustc_test) +$(call if_changed,rustdoc_test) rusttest-kernel: private rustc_target_flags = --extern alloc \ --extern build_error --extern macros --extern bindings --extern uapi -rusttest-kernel: $(src)/kernel/lib.rs rusttest-prepare \ +rusttest-kernel: $(src)/kernel/lib.rs \ rusttestlib-build_error rusttestlib-macros rusttestlib-bindings \ rusttestlib-uapi FORCE +$(call if_changed,rustc_test) @@ -281,7 +228,7 @@ bindgen_skip_c_flags := -mno-fp-ret-in-387 -mpreferred-stack-boundary=% \ -fno-reorder-blocks -fno-allow-store-data-races -fasan-shadow-offset=% \ -fzero-call-used-regs=% -fno-stack-clash-protection \ -fno-inline-functions-called-once -fsanitize=bounds-strict \ - -fstrict-flex-arrays=% \ + -fstrict-flex-arrays=% -fmin-function-alignment=% \ --param=% --param asan-% # Derived from `scripts/Makefile.clang`. @@ -324,7 +271,7 @@ quiet_cmd_bindgen = BINDGEN $@ cmd_bindgen = \ $(BINDGEN) $< $(bindgen_target_flags) \ --use-core --with-derive-default --ctypes-prefix core::ffi --no-layout-tests \ - --no-debug '.*' \ + --no-debug '.*' --enable-function-attribute-detection \ -o $@ -- $(bindgen_c_flags_final) -DMODULE \ $(bindgen_target_cflags) $(bindgen_target_extra) @@ -353,13 +300,13 @@ $(obj)/bindings/bindings_helpers_generated.rs: private bindgen_target_cflags = \ -I$(objtree)/$(obj) -Wno-missing-prototypes -Wno-missing-declarations $(obj)/bindings/bindings_helpers_generated.rs: private bindgen_target_extra = ; \ sed -Ei 's/pub fn rust_helper_([a-zA-Z0-9_]*)/#[link_name="rust_helper_\1"]\n pub fn \1/g' $@ -$(obj)/bindings/bindings_helpers_generated.rs: $(src)/helpers.c FORCE +$(obj)/bindings/bindings_helpers_generated.rs: $(src)/helpers/helpers.c FORCE $(call if_changed_dep,bindgen) quiet_cmd_exports = EXPORTS $@ cmd_exports = \ $(NM) -p --defined-only $< \ - | awk '/ (T|R|D) / {printf "EXPORT_SYMBOL_RUST_GPL(%s);\n",$$3}' > $@ + | awk '$$2~/(T|R|D|B)/ && $$3!~/__cfi/ {printf "EXPORT_SYMBOL_RUST_GPL(%s);\n",$$3}' > $@ $(obj)/exports_core_generated.h: $(obj)/core.o FORCE $(call if_changed,exports) @@ -367,6 +314,18 @@ $(obj)/exports_core_generated.h: $(obj)/core.o FORCE $(obj)/exports_alloc_generated.h: $(obj)/alloc.o FORCE $(call if_changed,exports) +# Even though Rust kernel modules should never use the bindings directly, +# symbols from the `bindings` crate and the C helpers need to be exported +# because Rust generics and inlined functions may not get their code generated +# in the crate where they are defined. Other helpers, called from non-inline +# functions, may not be exported, in principle. However, in general, the Rust +# compiler does not guarantee codegen will be performed for a non-inline +# function either. Therefore, we export all symbols from helpers and bindings. +# In the future, this may be revisited to reduce the number of exports after +# the compiler is informed about the places codegen is required. +$(obj)/exports_helpers_generated.h: $(obj)/helpers/helpers.o FORCE + $(call if_changed,exports) + $(obj)/exports_bindings_generated.h: $(obj)/bindings.o FORCE $(call if_changed,exports) @@ -383,9 +342,7 @@ quiet_cmd_rustc_procmacro = $(RUSTC_OR_CLIPPY_QUIET) P $@ --crate-name $(patsubst lib%.so,%,$(notdir $@)) $< # Procedural macros can only be used with the `rustc` that compiled it. -# Therefore, to get `libmacros.so` automatically recompiled when the compiler -# version changes, we add `core.o` as a dependency (even if it is not needed). -$(obj)/libmacros.so: $(src)/macros/lib.rs $(obj)/core.o FORCE +$(obj)/libmacros.so: $(src)/macros/lib.rs FORCE +$(call if_changed_dep,rustc_procmacro) quiet_cmd_rustc_library = $(if $(skip_clippy),RUSTC,$(RUSTC_OR_CLIPPY_QUIET)) L $@ @@ -398,18 +355,19 @@ quiet_cmd_rustc_library = $(if $(skip_clippy),RUSTC,$(RUSTC_OR_CLIPPY_QUIET)) L --crate-type rlib -L$(objtree)/$(obj) \ --crate-name $(patsubst %.o,%,$(notdir $@)) $< \ --sysroot=/dev/null \ - $(if $(rustc_objcopy),;$(OBJCOPY) $(rustc_objcopy) $@) + $(if $(rustc_objcopy),;$(OBJCOPY) $(rustc_objcopy) $@) \ + $(cmd_objtool) rust-analyzer: $(Q)$(srctree)/scripts/generate_rust_analyzer.py \ --cfgs='core=$(core-cfgs)' --cfgs='alloc=$(alloc-cfgs)' \ $(realpath $(srctree)) $(realpath $(objtree)) \ - $(RUST_LIB_SRC) $(KBUILD_EXTMOD) > \ + $(rustc_sysroot) $(RUST_LIB_SRC) $(KBUILD_EXTMOD) > \ $(if $(KBUILD_EXTMOD),$(extmod_prefix),$(objtree))/rust-project.json redirect-intrinsics = \ - __addsf3 __eqsf2 __gesf2 __lesf2 __ltsf2 __mulsf3 __nesf2 __unordsf2 \ - __adddf3 __ledf2 __ltdf2 __muldf3 __unorddf2 \ + __addsf3 __eqsf2 __extendsfdf2 __gesf2 __lesf2 __ltsf2 __mulsf3 __nesf2 __truncdfsf2 __unordsf2 \ + __adddf3 __eqdf2 __ledf2 __ltdf2 __muldf3 __unorddf2 \ __muloti4 __multi3 \ __udivmodti4 __udivti3 __umodti3 @@ -420,44 +378,50 @@ ifneq ($(or $(CONFIG_ARM64),$(and $(CONFIG_RISCV),$(CONFIG_64BIT))),) __ashlti3 __lshrti3 endif +define rule_rustc_library + $(call cmd_and_fixdep,rustc_library) + $(call cmd,gen_objtooldep) +endef + $(obj)/core.o: private skip_clippy = 1 -$(obj)/core.o: private skip_flags = -Dunreachable_pub +$(obj)/core.o: private skip_flags = -Wunreachable_pub $(obj)/core.o: private rustc_objcopy = $(foreach sym,$(redirect-intrinsics),--redefine-sym $(sym)=__rust$(sym)) $(obj)/core.o: private rustc_target_flags = $(core-cfgs) -$(obj)/core.o: $(RUST_LIB_SRC)/core/src/lib.rs FORCE - +$(call if_changed_dep,rustc_library) -ifdef CONFIG_X86_64 +$(obj)/core.o: $(RUST_LIB_SRC)/core/src/lib.rs \ + $(wildcard $(objtree)/include/config/RUSTC_VERSION_TEXT) FORCE + +$(call if_changed_rule,rustc_library) +ifneq ($(or $(CONFIG_X86_64),$(CONFIG_X86_32)),) $(obj)/core.o: scripts/target.json endif $(obj)/compiler_builtins.o: private rustc_objcopy = -w -W '__*' $(obj)/compiler_builtins.o: $(src)/compiler_builtins.rs $(obj)/core.o FORCE - +$(call if_changed_dep,rustc_library) + +$(call if_changed_rule,rustc_library) $(obj)/alloc.o: private skip_clippy = 1 -$(obj)/alloc.o: private skip_flags = -Dunreachable_pub +$(obj)/alloc.o: private skip_flags = -Wunreachable_pub $(obj)/alloc.o: private rustc_target_flags = $(alloc-cfgs) $(obj)/alloc.o: $(RUST_LIB_SRC)/alloc/src/lib.rs $(obj)/compiler_builtins.o FORCE - +$(call if_changed_dep,rustc_library) + +$(call if_changed_rule,rustc_library) $(obj)/build_error.o: $(src)/build_error.rs $(obj)/compiler_builtins.o FORCE - +$(call if_changed_dep,rustc_library) + +$(call if_changed_rule,rustc_library) $(obj)/bindings.o: $(src)/bindings/lib.rs \ $(obj)/compiler_builtins.o \ $(obj)/bindings/bindings_generated.rs \ $(obj)/bindings/bindings_helpers_generated.rs FORCE - +$(call if_changed_dep,rustc_library) + +$(call if_changed_rule,rustc_library) $(obj)/uapi.o: $(src)/uapi/lib.rs \ $(obj)/compiler_builtins.o \ $(obj)/uapi/uapi_generated.rs FORCE - +$(call if_changed_dep,rustc_library) + +$(call if_changed_rule,rustc_library) $(obj)/kernel.o: private rustc_target_flags = --extern alloc \ --extern build_error --extern macros --extern bindings --extern uapi $(obj)/kernel.o: $(src)/kernel/lib.rs $(obj)/alloc.o $(obj)/build_error.o \ $(obj)/libmacros.so $(obj)/bindings.o $(obj)/uapi.o FORCE - +$(call if_changed_dep,rustc_library) + +$(call if_changed_rule,rustc_library) endif # CONFIG_RUST diff --git a/rust/bindgen_parameters b/rust/bindgen_parameters index a721d466bee4..b7c7483123b7 100644 --- a/rust/bindgen_parameters +++ b/rust/bindgen_parameters @@ -24,3 +24,8 @@ # These functions use the `__preserve_most` calling convention, which neither bindgen # nor Rust currently understand, and which Clang currently declares to be unstable. --blocklist-function __list_.*_report + +# These constants are sometimes not recognized by bindgen depending on config. +# We use const helpers to aid bindgen, to avoid conflicts when constants are +# recognized, block generation of the non-helper constants. +--blocklist-item ARCH_SLAB_MINALIGN diff --git a/rust/bindings/bindings_helper.h b/rust/bindings/bindings_helper.h index ddb5644d4fd9..ae82e9c941af 100644 --- a/rust/bindings/bindings_helper.h +++ b/rust/bindings/bindings_helper.h @@ -7,8 +7,12 @@ */ #include <kunit/test.h> +#include <linux/blk-mq.h> +#include <linux/blk_types.h> +#include <linux/blkdev.h> #include <linux/errname.h> #include <linux/ethtool.h> +#include <linux/firmware.h> #include <linux/jiffies.h> #include <linux/mdio.h> #include <linux/phy.h> @@ -20,8 +24,11 @@ /* `bindgen` gets confused at certain things. */ const size_t RUST_CONST_HELPER_ARCH_SLAB_MINALIGN = ARCH_SLAB_MINALIGN; +const size_t RUST_CONST_HELPER_PAGE_SIZE = PAGE_SIZE; const gfp_t RUST_CONST_HELPER_GFP_ATOMIC = GFP_ATOMIC; const gfp_t RUST_CONST_HELPER_GFP_KERNEL = GFP_KERNEL; const gfp_t RUST_CONST_HELPER_GFP_KERNEL_ACCOUNT = GFP_KERNEL_ACCOUNT; const gfp_t RUST_CONST_HELPER_GFP_NOWAIT = GFP_NOWAIT; const gfp_t RUST_CONST_HELPER___GFP_ZERO = __GFP_ZERO; +const gfp_t RUST_CONST_HELPER___GFP_HIGHMEM = ___GFP_HIGHMEM; +const blk_features_t RUST_CONST_HELPER_BLK_FEAT_ROTATIONAL = BLK_FEAT_ROTATIONAL; diff --git a/rust/bindings/lib.rs b/rust/bindings/lib.rs index 40ddaee50d8b..93a1a3fc97bc 100644 --- a/rust/bindings/lib.rs +++ b/rust/bindings/lib.rs @@ -24,6 +24,7 @@ unsafe_op_in_unsafe_fn )] +#[allow(dead_code)] mod bindings_raw { // Use glob import here to expose all helpers. // Symbols defined within the module will take precedence to the glob import. diff --git a/rust/compiler_builtins.rs b/rust/compiler_builtins.rs index bba2922c6ef7..f14b8d7caf89 100644 --- a/rust/compiler_builtins.rs +++ b/rust/compiler_builtins.rs @@ -40,16 +40,19 @@ macro_rules! define_panicking_intrinsics( define_panicking_intrinsics!("`f32` should not be used", { __addsf3, __eqsf2, + __extendsfdf2, __gesf2, __lesf2, __ltsf2, __mulsf3, __nesf2, + __truncdfsf2, __unordsf2, }); define_panicking_intrinsics!("`f64` should not be used", { __adddf3, + __eqdf2, __ledf2, __ltdf2, __muldf3, diff --git a/rust/exports.c b/rust/exports.c index 3803c21d1403..e5695f3b45b7 100644 --- a/rust/exports.c +++ b/rust/exports.c @@ -17,6 +17,7 @@ #include "exports_core_generated.h" #include "exports_alloc_generated.h" +#include "exports_helpers_generated.h" #include "exports_bindings_generated.h" #include "exports_kernel_generated.h" diff --git a/rust/helpers.c b/rust/helpers.c deleted file mode 100644 index 2c37a0f5d7a8..000000000000 --- a/rust/helpers.c +++ /dev/null @@ -1,188 +0,0 @@ -// SPDX-License-Identifier: GPL-2.0 -/* - * Non-trivial C macros cannot be used in Rust. Similarly, inlined C functions - * cannot be called either. This file explicitly creates functions ("helpers") - * that wrap those so that they can be called from Rust. - * - * Even though Rust kernel modules should never use the bindings directly, some - * of these helpers need to be exported because Rust generics and inlined - * functions may not get their code generated in the crate where they are - * defined. Other helpers, called from non-inline functions, may not be - * exported, in principle. However, in general, the Rust compiler does not - * guarantee codegen will be performed for a non-inline function either. - * Therefore, this file exports all the helpers. In the future, this may be - * revisited to reduce the number of exports after the compiler is informed - * about the places codegen is required. - * - * All symbols are exported as GPL-only to guarantee no GPL-only feature is - * accidentally exposed. - * - * Sorted alphabetically. - */ - -#include <kunit/test-bug.h> -#include <linux/bug.h> -#include <linux/build_bug.h> -#include <linux/err.h> -#include <linux/errname.h> -#include <linux/mutex.h> -#include <linux/refcount.h> -#include <linux/sched/signal.h> -#include <linux/slab.h> -#include <linux/spinlock.h> -#include <linux/wait.h> -#include <linux/workqueue.h> - -__noreturn void rust_helper_BUG(void) -{ - BUG(); -} -EXPORT_SYMBOL_GPL(rust_helper_BUG); - -void rust_helper_mutex_lock(struct mutex *lock) -{ - mutex_lock(lock); -} -EXPORT_SYMBOL_GPL(rust_helper_mutex_lock); - -void rust_helper___spin_lock_init(spinlock_t *lock, const char *name, - struct lock_class_key *key) -{ -#ifdef CONFIG_DEBUG_SPINLOCK - __raw_spin_lock_init(spinlock_check(lock), name, key, LD_WAIT_CONFIG); -#else - spin_lock_init(lock); -#endif -} -EXPORT_SYMBOL_GPL(rust_helper___spin_lock_init); - -void rust_helper_spin_lock(spinlock_t *lock) -{ - spin_lock(lock); -} -EXPORT_SYMBOL_GPL(rust_helper_spin_lock); - -void rust_helper_spin_unlock(spinlock_t *lock) -{ - spin_unlock(lock); -} -EXPORT_SYMBOL_GPL(rust_helper_spin_unlock); - -void rust_helper_init_wait(struct wait_queue_entry *wq_entry) -{ - init_wait(wq_entry); -} -EXPORT_SYMBOL_GPL(rust_helper_init_wait); - -int rust_helper_signal_pending(struct task_struct *t) -{ - return signal_pending(t); -} -EXPORT_SYMBOL_GPL(rust_helper_signal_pending); - -refcount_t rust_helper_REFCOUNT_INIT(int n) -{ - return (refcount_t)REFCOUNT_INIT(n); -} -EXPORT_SYMBOL_GPL(rust_helper_REFCOUNT_INIT); - -void rust_helper_refcount_inc(refcount_t *r) -{ - refcount_inc(r); -} -EXPORT_SYMBOL_GPL(rust_helper_refcount_inc); - -bool rust_helper_refcount_dec_and_test(refcount_t *r) -{ - return refcount_dec_and_test(r); -} -EXPORT_SYMBOL_GPL(rust_helper_refcount_dec_and_test); - -__force void *rust_helper_ERR_PTR(long err) -{ - return ERR_PTR(err); -} -EXPORT_SYMBOL_GPL(rust_helper_ERR_PTR); - -bool rust_helper_IS_ERR(__force const void *ptr) -{ - return IS_ERR(ptr); -} -EXPORT_SYMBOL_GPL(rust_helper_IS_ERR); - -long rust_helper_PTR_ERR(__force const void *ptr) -{ - return PTR_ERR(ptr); -} -EXPORT_SYMBOL_GPL(rust_helper_PTR_ERR); - -const char *rust_helper_errname(int err) -{ - return errname(err); -} -EXPORT_SYMBOL_GPL(rust_helper_errname); - -struct task_struct *rust_helper_get_current(void) -{ - return current; -} -EXPORT_SYMBOL_GPL(rust_helper_get_current); - -void rust_helper_get_task_struct(struct task_struct *t) -{ - get_task_struct(t); -} -EXPORT_SYMBOL_GPL(rust_helper_get_task_struct); - -void rust_helper_put_task_struct(struct task_struct *t) -{ - put_task_struct(t); -} -EXPORT_SYMBOL_GPL(rust_helper_put_task_struct); - -struct kunit *rust_helper_kunit_get_current_test(void) -{ - return kunit_get_current_test(); -} -EXPORT_SYMBOL_GPL(rust_helper_kunit_get_current_test); - -void rust_helper_init_work_with_key(struct work_struct *work, work_func_t func, - bool onstack, const char *name, - struct lock_class_key *key) -{ - __init_work(work, onstack); - work->data = (atomic_long_t)WORK_DATA_INIT(); - lockdep_init_map(&work->lockdep_map, name, key, 0); - INIT_LIST_HEAD(&work->entry); - work->func = func; -} -EXPORT_SYMBOL_GPL(rust_helper_init_work_with_key); - -void * __must_check __realloc_size(2) -rust_helper_krealloc(const void *objp, size_t new_size, gfp_t flags) -{ - return krealloc(objp, new_size, flags); -} -EXPORT_SYMBOL_GPL(rust_helper_krealloc); - -/* - * `bindgen` binds the C `size_t` type as the Rust `usize` type, so we can - * use it in contexts where Rust expects a `usize` like slice (array) indices. - * `usize` is defined to be the same as C's `uintptr_t` type (can hold any - * pointer) but not necessarily the same as `size_t` (can hold the size of any - * single object). Most modern platforms use the same concrete integer type for - * both of them, but in case we find ourselves on a platform where - * that's not true, fail early instead of risking ABI or - * integer-overflow issues. - * - * If your platform fails this assertion, it means that you are in - * danger of integer-overflow bugs (even if you attempt to add - * `--no-size_t-is-usize`). It may be easiest to change the kernel ABI on - * your platform such that `size_t` matches `uintptr_t` (i.e., to increase - * `size_t`, because `uintptr_t` has to be at least as big as `size_t`). - */ -static_assert( - sizeof(size_t) == sizeof(uintptr_t) && - __alignof__(size_t) == __alignof__(uintptr_t), - "Rust code expects C `size_t` to match Rust `usize`" -); diff --git a/rust/helpers/blk.c b/rust/helpers/blk.c new file mode 100644 index 000000000000..cc9f4e6a2d23 --- /dev/null +++ b/rust/helpers/blk.c @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/blk-mq.h> +#include <linux/blkdev.h> + +void *rust_helper_blk_mq_rq_to_pdu(struct request *rq) +{ + return blk_mq_rq_to_pdu(rq); +} + +struct request *rust_helper_blk_mq_rq_from_pdu(void *pdu) +{ + return blk_mq_rq_from_pdu(pdu); +} diff --git a/rust/helpers/bug.c b/rust/helpers/bug.c new file mode 100644 index 000000000000..e2d13babc737 --- /dev/null +++ b/rust/helpers/bug.c @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/bug.h> + +__noreturn void rust_helper_BUG(void) +{ + BUG(); +} diff --git a/rust/helpers/build_assert.c b/rust/helpers/build_assert.c new file mode 100644 index 000000000000..6a54b2680b14 --- /dev/null +++ b/rust/helpers/build_assert.c @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/build_bug.h> + +/* + * `bindgen` binds the C `size_t` type as the Rust `usize` type, so we can + * use it in contexts where Rust expects a `usize` like slice (array) indices. + * `usize` is defined to be the same as C's `uintptr_t` type (can hold any + * pointer) but not necessarily the same as `size_t` (can hold the size of any + * single object). Most modern platforms use the same concrete integer type for + * both of them, but in case we find ourselves on a platform where + * that's not true, fail early instead of risking ABI or + * integer-overflow issues. + * + * If your platform fails this assertion, it means that you are in + * danger of integer-overflow bugs (even if you attempt to add + * `--no-size_t-is-usize`). It may be easiest to change the kernel ABI on + * your platform such that `size_t` matches `uintptr_t` (i.e., to increase + * `size_t`, because `uintptr_t` has to be at least as big as `size_t`). + */ +static_assert( + sizeof(size_t) == sizeof(uintptr_t) && + __alignof__(size_t) == __alignof__(uintptr_t), + "Rust code expects C `size_t` to match Rust `usize`" +); diff --git a/rust/helpers/build_bug.c b/rust/helpers/build_bug.c new file mode 100644 index 000000000000..e994f7b5928c --- /dev/null +++ b/rust/helpers/build_bug.c @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/errname.h> + +const char *rust_helper_errname(int err) +{ + return errname(err); +} diff --git a/rust/helpers/err.c b/rust/helpers/err.c new file mode 100644 index 000000000000..be3d45ef78a2 --- /dev/null +++ b/rust/helpers/err.c @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/err.h> +#include <linux/export.h> + +__force void *rust_helper_ERR_PTR(long err) +{ + return ERR_PTR(err); +} + +bool rust_helper_IS_ERR(__force const void *ptr) +{ + return IS_ERR(ptr); +} + +long rust_helper_PTR_ERR(__force const void *ptr) +{ + return PTR_ERR(ptr); +} diff --git a/rust/helpers/helpers.c b/rust/helpers/helpers.c new file mode 100644 index 000000000000..30f40149f3a9 --- /dev/null +++ b/rust/helpers/helpers.c @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Non-trivial C macros cannot be used in Rust. Similarly, inlined C functions + * cannot be called either. This file explicitly creates functions ("helpers") + * that wrap those so that they can be called from Rust. + * + * Sorted alphabetically. + */ + +#include "blk.c" +#include "bug.c" +#include "build_assert.c" +#include "build_bug.c" +#include "err.c" +#include "kunit.c" +#include "mutex.c" +#include "page.c" +#include "rbtree.c" +#include "refcount.c" +#include "signal.c" +#include "slab.c" +#include "spinlock.c" +#include "task.c" +#include "uaccess.c" +#include "wait.c" +#include "workqueue.c" diff --git a/rust/helpers/kunit.c b/rust/helpers/kunit.c new file mode 100644 index 000000000000..9d725067eb3b --- /dev/null +++ b/rust/helpers/kunit.c @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <kunit/test-bug.h> +#include <linux/export.h> + +struct kunit *rust_helper_kunit_get_current_test(void) +{ + return kunit_get_current_test(); +} diff --git a/rust/helpers/mutex.c b/rust/helpers/mutex.c new file mode 100644 index 000000000000..a17ca8cdb50c --- /dev/null +++ b/rust/helpers/mutex.c @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/mutex.h> + +void rust_helper_mutex_lock(struct mutex *lock) +{ + mutex_lock(lock); +} + +void rust_helper___mutex_init(struct mutex *mutex, const char *name, + struct lock_class_key *key) +{ + __mutex_init(mutex, name, key); +} diff --git a/rust/helpers/page.c b/rust/helpers/page.c new file mode 100644 index 000000000000..b3f2b8fbf87f --- /dev/null +++ b/rust/helpers/page.c @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/gfp.h> +#include <linux/highmem.h> + +struct page *rust_helper_alloc_pages(gfp_t gfp_mask, unsigned int order) +{ + return alloc_pages(gfp_mask, order); +} + +void *rust_helper_kmap_local_page(struct page *page) +{ + return kmap_local_page(page); +} + +void rust_helper_kunmap_local(const void *addr) +{ + kunmap_local(addr); +} diff --git a/rust/helpers/rbtree.c b/rust/helpers/rbtree.c new file mode 100644 index 000000000000..6d404b84a9b5 --- /dev/null +++ b/rust/helpers/rbtree.c @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/rbtree.h> + +void rust_helper_rb_link_node(struct rb_node *node, struct rb_node *parent, + struct rb_node **rb_link) +{ + rb_link_node(node, parent, rb_link); +} diff --git a/rust/helpers/refcount.c b/rust/helpers/refcount.c new file mode 100644 index 000000000000..f47afc148ec3 --- /dev/null +++ b/rust/helpers/refcount.c @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/refcount.h> + +refcount_t rust_helper_REFCOUNT_INIT(int n) +{ + return (refcount_t)REFCOUNT_INIT(n); +} + +void rust_helper_refcount_inc(refcount_t *r) +{ + refcount_inc(r); +} + +bool rust_helper_refcount_dec_and_test(refcount_t *r) +{ + return refcount_dec_and_test(r); +} diff --git a/rust/helpers/signal.c b/rust/helpers/signal.c new file mode 100644 index 000000000000..63c407f80c26 --- /dev/null +++ b/rust/helpers/signal.c @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/sched/signal.h> + +int rust_helper_signal_pending(struct task_struct *t) +{ + return signal_pending(t); +} diff --git a/rust/helpers/slab.c b/rust/helpers/slab.c new file mode 100644 index 000000000000..f043e087f9d6 --- /dev/null +++ b/rust/helpers/slab.c @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/slab.h> + +void * __must_check __realloc_size(2) +rust_helper_krealloc(const void *objp, size_t new_size, gfp_t flags) +{ + return krealloc(objp, new_size, flags); +} diff --git a/rust/helpers/spinlock.c b/rust/helpers/spinlock.c new file mode 100644 index 000000000000..acc1376b833c --- /dev/null +++ b/rust/helpers/spinlock.c @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/spinlock.h> + +void rust_helper___spin_lock_init(spinlock_t *lock, const char *name, + struct lock_class_key *key) +{ +#ifdef CONFIG_DEBUG_SPINLOCK + __raw_spin_lock_init(spinlock_check(lock), name, key, LD_WAIT_CONFIG); +#else + spin_lock_init(lock); +#endif +} + +void rust_helper_spin_lock(spinlock_t *lock) +{ + spin_lock(lock); +} + +void rust_helper_spin_unlock(spinlock_t *lock) +{ + spin_unlock(lock); +} diff --git a/rust/helpers/task.c b/rust/helpers/task.c new file mode 100644 index 000000000000..7ac789232d11 --- /dev/null +++ b/rust/helpers/task.c @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/sched/task.h> + +struct task_struct *rust_helper_get_current(void) +{ + return current; +} + +void rust_helper_get_task_struct(struct task_struct *t) +{ + get_task_struct(t); +} + +void rust_helper_put_task_struct(struct task_struct *t) +{ + put_task_struct(t); +} diff --git a/rust/helpers/uaccess.c b/rust/helpers/uaccess.c new file mode 100644 index 000000000000..f49076f813cd --- /dev/null +++ b/rust/helpers/uaccess.c @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/uaccess.h> + +unsigned long rust_helper_copy_from_user(void *to, const void __user *from, + unsigned long n) +{ + return copy_from_user(to, from, n); +} + +unsigned long rust_helper_copy_to_user(void __user *to, const void *from, + unsigned long n) +{ + return copy_to_user(to, from, n); +} diff --git a/rust/helpers/wait.c b/rust/helpers/wait.c new file mode 100644 index 000000000000..c7336bbf2750 --- /dev/null +++ b/rust/helpers/wait.c @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/wait.h> + +void rust_helper_init_wait(struct wait_queue_entry *wq_entry) +{ + init_wait(wq_entry); +} diff --git a/rust/helpers/workqueue.c b/rust/helpers/workqueue.c new file mode 100644 index 000000000000..f59427acc323 --- /dev/null +++ b/rust/helpers/workqueue.c @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/workqueue.h> + +void rust_helper_init_work_with_key(struct work_struct *work, work_func_t func, + bool onstack, const char *name, + struct lock_class_key *key) +{ + __init_work(work, onstack); + work->data = (atomic_long_t)WORK_DATA_INIT(); + lockdep_init_map(&work->lockdep_map, name, key, 0); + INIT_LIST_HEAD(&work->entry); + work->func = func; +} diff --git a/rust/kernel/alloc.rs b/rust/kernel/alloc.rs index 531b5e471cb1..1966bd407017 100644 --- a/rust/kernel/alloc.rs +++ b/rust/kernel/alloc.rs @@ -20,6 +20,13 @@ pub struct AllocError; #[derive(Clone, Copy)] pub struct Flags(u32); +impl Flags { + /// Get the raw representation of this flag. + pub(crate) fn as_raw(self) -> u32 { + self.0 + } +} + impl core::ops::BitOr for Flags { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { @@ -52,6 +59,14 @@ pub mod flags { /// This is normally or'd with other flags. pub const __GFP_ZERO: Flags = Flags(bindings::__GFP_ZERO); + /// Allow the allocation to be in high memory. + /// + /// Allocations in high memory may not be mapped into the kernel's address space, so this can't + /// be used with `kmalloc` and other similar methods. + /// + /// This is normally or'd with other flags. + pub const __GFP_HIGHMEM: Flags = Flags(bindings::__GFP_HIGHMEM); + /// Users can not sleep and need the allocation to succeed. /// /// A lower watermark is applied to allow access to "atomic reserves". The current @@ -66,7 +81,7 @@ pub mod flags { /// The same as [`GFP_KERNEL`], except the allocation is accounted to kmemcg. pub const GFP_KERNEL_ACCOUNT: Flags = Flags(bindings::GFP_KERNEL_ACCOUNT); - /// Ror kernel allocations that should not stall for direct reclaim, start physical IO or + /// For kernel allocations that should not stall for direct reclaim, start physical IO or /// use any filesystem callback. It is very likely to fail to allocate memory, even for very /// small allocations. pub const GFP_NOWAIT: Flags = Flags(bindings::GFP_NOWAIT); diff --git a/rust/kernel/alloc/allocator.rs b/rust/kernel/alloc/allocator.rs index 229642960cd1..e6ea601f38c6 100644 --- a/rust/kernel/alloc/allocator.rs +++ b/rust/kernel/alloc/allocator.rs @@ -18,23 +18,16 @@ pub(crate) unsafe fn krealloc_aligned(ptr: *mut u8, new_layout: Layout, flags: F // Customized layouts from `Layout::from_size_align()` can have size < align, so pad first. let layout = new_layout.pad_to_align(); - let mut size = layout.size(); - - if layout.align() > bindings::ARCH_SLAB_MINALIGN { - // The alignment requirement exceeds the slab guarantee, thus try to enlarge the size - // to use the "power-of-two" size/alignment guarantee (see comments in `kmalloc()` for - // more information). - // - // Note that `layout.size()` (after padding) is guaranteed to be a multiple of - // `layout.align()`, so `next_power_of_two` gives enough alignment guarantee. - size = size.next_power_of_two(); - } + // Note that `layout.size()` (after padding) is guaranteed to be a multiple of `layout.align()` + // which together with the slab guarantees means the `krealloc` will return a properly aligned + // object (see comments in `kmalloc()` for more information). + let size = layout.size(); // SAFETY: // - `ptr` is either null or a pointer returned from a previous `k{re}alloc()` by the // function safety requirement. - // - `size` is greater than 0 since it's either a `layout.size()` (which cannot be zero - // according to the function safety requirement) or a result from `next_power_of_two()`. + // - `size` is greater than 0 since it's from `layout.size()` (which cannot be zero according + // to the function safety requirement) unsafe { bindings::krealloc(ptr as *const core::ffi::c_void, size, flags.0) as *mut u8 } } diff --git a/rust/kernel/alloc/box_ext.rs b/rust/kernel/alloc/box_ext.rs index 829cb1c1cf9e..7009ad78d4e0 100644 --- a/rust/kernel/alloc/box_ext.rs +++ b/rust/kernel/alloc/box_ext.rs @@ -4,7 +4,7 @@ use super::{AllocError, Flags}; use alloc::boxed::Box; -use core::mem::MaybeUninit; +use core::{mem::MaybeUninit, ptr, result::Result}; /// Extensions to [`Box`]. pub trait BoxExt<T>: Sized { @@ -17,12 +17,32 @@ pub trait BoxExt<T>: Sized { /// /// The allocation may fail, in which case an error is returned. fn new_uninit(flags: Flags) -> Result<Box<MaybeUninit<T>>, AllocError>; + + /// Drops the contents, but keeps the allocation. + /// + /// # Examples + /// + /// ``` + /// use kernel::alloc::{flags, box_ext::BoxExt}; + /// let value = Box::new([0; 32], flags::GFP_KERNEL)?; + /// assert_eq!(*value, [0; 32]); + /// let mut value = Box::drop_contents(value); + /// // Now we can re-use `value`: + /// value.write([1; 32]); + /// // SAFETY: We just wrote to it. + /// let value = unsafe { value.assume_init() }; + /// assert_eq!(*value, [1; 32]); + /// # Ok::<(), Error>(()) + /// ``` + fn drop_contents(this: Self) -> Box<MaybeUninit<T>>; } impl<T> BoxExt<T> for Box<T> { fn new(x: T, flags: Flags) -> Result<Self, AllocError> { - let b = <Self as BoxExt<_>>::new_uninit(flags)?; - Ok(Box::write(b, x)) + let mut b = <Self as BoxExt<_>>::new_uninit(flags)?; + b.write(x); + // SAFETY: We just wrote to it. + Ok(unsafe { b.assume_init() }) } #[cfg(any(test, testlib))] @@ -53,4 +73,17 @@ impl<T> BoxExt<T> for Box<T> { // zero-sized types, we use `NonNull::dangling`. Ok(unsafe { Box::from_raw(ptr) }) } + + fn drop_contents(this: Self) -> Box<MaybeUninit<T>> { + let ptr = Box::into_raw(this); + // SAFETY: `ptr` is valid, because it came from `Box::into_raw`. + unsafe { ptr::drop_in_place(ptr) }; + + // CAST: `MaybeUninit<T>` is a transparent wrapper of `T`. + let ptr = ptr.cast::<MaybeUninit<T>>(); + + // SAFETY: `ptr` is valid for writes, because it came from `Box::into_raw` and it is valid for + // reads, since the pointer came from `Box::into_raw` and the type is `MaybeUninit<T>`. + unsafe { Box::from_raw(ptr) } + } } diff --git a/rust/kernel/block.rs b/rust/kernel/block.rs new file mode 100644 index 000000000000..150f710efe5b --- /dev/null +++ b/rust/kernel/block.rs @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Types for working with the block layer. + +pub mod mq; diff --git a/rust/kernel/block/mq.rs b/rust/kernel/block/mq.rs new file mode 100644 index 000000000000..fb0f393c1cea --- /dev/null +++ b/rust/kernel/block/mq.rs @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! This module provides types for implementing block drivers that interface the +//! blk-mq subsystem. +//! +//! To implement a block device driver, a Rust module must do the following: +//! +//! - Implement [`Operations`] for a type `T`. +//! - Create a [`TagSet<T>`]. +//! - Create a [`GenDisk<T>`], via the [`GenDiskBuilder`]. +//! - Add the disk to the system by calling [`GenDiskBuilder::build`] passing in +//! the `TagSet` reference. +//! +//! The types available in this module that have direct C counterparts are: +//! +//! - The [`TagSet`] type that abstracts the C type `struct tag_set`. +//! - The [`GenDisk`] type that abstracts the C type `struct gendisk`. +//! - The [`Request`] type that abstracts the C type `struct request`. +//! +//! The kernel will interface with the block device driver by calling the method +//! implementations of the `Operations` trait. +//! +//! IO requests are passed to the driver as [`kernel::types::ARef<Request>`] +//! instances. The `Request` type is a wrapper around the C `struct request`. +//! The driver must mark end of processing by calling one of the +//! `Request::end`, methods. Failure to do so can lead to deadlock or timeout +//! errors. Please note that the C function `blk_mq_start_request` is implicitly +//! called when the request is queued with the driver. +//! +//! The `TagSet` is responsible for creating and maintaining a mapping between +//! `Request`s and integer ids as well as carrying a pointer to the vtable +//! generated by `Operations`. This mapping is useful for associating +//! completions from hardware with the correct `Request` instance. The `TagSet` +//! determines the maximum queue depth by setting the number of `Request` +//! instances available to the driver, and it determines the number of queues to +//! instantiate for the driver. If possible, a driver should allocate one queue +//! per core, to keep queue data local to a core. +//! +//! One `TagSet` instance can be shared between multiple `GenDisk` instances. +//! This can be useful when implementing drivers where one piece of hardware +//! with one set of IO resources are represented to the user as multiple disks. +//! +//! One significant difference between block device drivers implemented with +//! these Rust abstractions and drivers implemented in C, is that the Rust +//! drivers have to own a reference count on the `Request` type when the IO is +//! in flight. This is to ensure that the C `struct request` instances backing +//! the Rust `Request` instances are live while the Rust driver holds a +//! reference to the `Request`. In addition, the conversion of an integer tag to +//! a `Request` via the `TagSet` would not be sound without this bookkeeping. +//! +//! [`GenDisk`]: gen_disk::GenDisk +//! [`GenDisk<T>`]: gen_disk::GenDisk +//! [`GenDiskBuilder`]: gen_disk::GenDiskBuilder +//! [`GenDiskBuilder::build`]: gen_disk::GenDiskBuilder::build +//! +//! # Example +//! +//! ```rust +//! use kernel::{ +//! alloc::flags, +//! block::mq::*, +//! new_mutex, +//! prelude::*, +//! sync::{Arc, Mutex}, +//! types::{ARef, ForeignOwnable}, +//! }; +//! +//! struct MyBlkDevice; +//! +//! #[vtable] +//! impl Operations for MyBlkDevice { +//! +//! fn queue_rq(rq: ARef<Request<Self>>, _is_last: bool) -> Result { +//! Request::end_ok(rq); +//! Ok(()) +//! } +//! +//! fn commit_rqs() {} +//! } +//! +//! let tagset: Arc<TagSet<MyBlkDevice>> = +//! Arc::pin_init(TagSet::new(1, 256, 1), flags::GFP_KERNEL)?; +//! let mut disk = gen_disk::GenDiskBuilder::new() +//! .capacity_sectors(4096) +//! .build(format_args!("myblk"), tagset)?; +//! +//! # Ok::<(), kernel::error::Error>(()) +//! ``` + +pub mod gen_disk; +mod operations; +mod raw_writer; +mod request; +mod tag_set; + +pub use operations::Operations; +pub use request::Request; +pub use tag_set::TagSet; diff --git a/rust/kernel/block/mq/gen_disk.rs b/rust/kernel/block/mq/gen_disk.rs new file mode 100644 index 000000000000..708125dce96a --- /dev/null +++ b/rust/kernel/block/mq/gen_disk.rs @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic disk abstraction. +//! +//! C header: [`include/linux/blkdev.h`](srctree/include/linux/blkdev.h) +//! C header: [`include/linux/blk_mq.h`](srctree/include/linux/blk_mq.h) + +use crate::block::mq::{raw_writer::RawWriter, Operations, TagSet}; +use crate::{bindings, error::from_err_ptr, error::Result, sync::Arc}; +use crate::{error, static_lock_class}; +use core::fmt::{self, Write}; + +/// A builder for [`GenDisk`]. +/// +/// Use this struct to configure and add new [`GenDisk`] to the VFS. +pub struct GenDiskBuilder { + rotational: bool, + logical_block_size: u32, + physical_block_size: u32, + capacity_sectors: u64, +} + +impl Default for GenDiskBuilder { + fn default() -> Self { + Self { + rotational: false, + logical_block_size: bindings::PAGE_SIZE as u32, + physical_block_size: bindings::PAGE_SIZE as u32, + capacity_sectors: 0, + } + } +} + +impl GenDiskBuilder { + /// Create a new instance. + pub fn new() -> Self { + Self::default() + } + + /// Set the rotational media attribute for the device to be built. + pub fn rotational(mut self, rotational: bool) -> Self { + self.rotational = rotational; + self + } + + /// Validate block size by verifying that it is between 512 and `PAGE_SIZE`, + /// and that it is a power of two. + fn validate_block_size(size: u32) -> Result<()> { + if !(512..=bindings::PAGE_SIZE as u32).contains(&size) || !size.is_power_of_two() { + Err(error::code::EINVAL) + } else { + Ok(()) + } + } + + /// Set the logical block size of the device to be built. + /// + /// This method will check that block size is a power of two and between 512 + /// and 4096. If not, an error is returned and the block size is not set. + /// + /// This is the smallest unit the storage device can address. It is + /// typically 4096 bytes. + pub fn logical_block_size(mut self, block_size: u32) -> Result<Self> { + Self::validate_block_size(block_size)?; + self.logical_block_size = block_size; + Ok(self) + } + + /// Set the physical block size of the device to be built. + /// + /// This method will check that block size is a power of two and between 512 + /// and 4096. If not, an error is returned and the block size is not set. + /// + /// This is the smallest unit a physical storage device can write + /// atomically. It is usually the same as the logical block size but may be + /// bigger. One example is SATA drives with 4096 byte physical block size + /// that expose a 512 byte logical block size to the operating system. + pub fn physical_block_size(mut self, block_size: u32) -> Result<Self> { + Self::validate_block_size(block_size)?; + self.physical_block_size = block_size; + Ok(self) + } + + /// Set the capacity of the device to be built, in sectors (512 bytes). + pub fn capacity_sectors(mut self, capacity: u64) -> Self { + self.capacity_sectors = capacity; + self + } + + /// Build a new `GenDisk` and add it to the VFS. + pub fn build<T: Operations>( + self, + name: fmt::Arguments<'_>, + tagset: Arc<TagSet<T>>, + ) -> Result<GenDisk<T>> { + // SAFETY: `bindings::queue_limits` contain only fields that are valid when zeroed. + let mut lim: bindings::queue_limits = unsafe { core::mem::zeroed() }; + + lim.logical_block_size = self.logical_block_size; + lim.physical_block_size = self.physical_block_size; + if self.rotational { + lim.features = bindings::BLK_FEAT_ROTATIONAL; + } + + // SAFETY: `tagset.raw_tag_set()` points to a valid and initialized tag set + let gendisk = from_err_ptr(unsafe { + bindings::__blk_mq_alloc_disk( + tagset.raw_tag_set(), + &mut lim, + core::ptr::null_mut(), + static_lock_class!().as_ptr(), + ) + })?; + + const TABLE: bindings::block_device_operations = bindings::block_device_operations { + submit_bio: None, + open: None, + release: None, + ioctl: None, + compat_ioctl: None, + check_events: None, + unlock_native_capacity: None, + getgeo: None, + set_read_only: None, + swap_slot_free_notify: None, + report_zones: None, + devnode: None, + alternative_gpt_sector: None, + get_unique_id: None, + // TODO: Set to THIS_MODULE. Waiting for const_refs_to_static feature to + // be merged (unstable in rustc 1.78 which is staged for linux 6.10) + // https://github.com/rust-lang/rust/issues/119618 + owner: core::ptr::null_mut(), + pr_ops: core::ptr::null_mut(), + free_disk: None, + poll_bio: None, + }; + + // SAFETY: `gendisk` is a valid pointer as we initialized it above + unsafe { (*gendisk).fops = &TABLE }; + + let mut raw_writer = RawWriter::from_array( + // SAFETY: `gendisk` points to a valid and initialized instance. We + // have exclusive access, since the disk is not added to the VFS + // yet. + unsafe { &mut (*gendisk).disk_name }, + )?; + raw_writer.write_fmt(name)?; + raw_writer.write_char('\0')?; + + // SAFETY: `gendisk` points to a valid and initialized instance of + // `struct gendisk`. `set_capacity` takes a lock to synchronize this + // operation, so we will not race. + unsafe { bindings::set_capacity(gendisk, self.capacity_sectors) }; + + crate::error::to_result( + // SAFETY: `gendisk` points to a valid and initialized instance of + // `struct gendisk`. + unsafe { + bindings::device_add_disk(core::ptr::null_mut(), gendisk, core::ptr::null_mut()) + }, + )?; + + // INVARIANT: `gendisk` was initialized above. + // INVARIANT: `gendisk` was added to the VFS via `device_add_disk` above. + Ok(GenDisk { + _tagset: tagset, + gendisk, + }) + } +} + +/// A generic block device. +/// +/// # Invariants +/// +/// - `gendisk` must always point to an initialized and valid `struct gendisk`. +/// - `gendisk` was added to the VFS through a call to +/// `bindings::device_add_disk`. +pub struct GenDisk<T: Operations> { + _tagset: Arc<TagSet<T>>, + gendisk: *mut bindings::gendisk, +} + +// SAFETY: `GenDisk` is an owned pointer to a `struct gendisk` and an `Arc` to a +// `TagSet` It is safe to send this to other threads as long as T is Send. +unsafe impl<T: Operations + Send> Send for GenDisk<T> {} + +impl<T: Operations> Drop for GenDisk<T> { + fn drop(&mut self) { + // SAFETY: By type invariant, `self.gendisk` points to a valid and + // initialized instance of `struct gendisk`, and it was previously added + // to the VFS. + unsafe { bindings::del_gendisk(self.gendisk) }; + } +} diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs new file mode 100644 index 000000000000..9ba7fdfeb4b2 --- /dev/null +++ b/rust/kernel/block/mq/operations.rs @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! This module provides an interface for blk-mq drivers to implement. +//! +//! C header: [`include/linux/blk-mq.h`](srctree/include/linux/blk-mq.h) + +use crate::{ + bindings, + block::mq::request::RequestDataWrapper, + block::mq::Request, + error::{from_result, Result}, + types::ARef, +}; +use core::{marker::PhantomData, sync::atomic::AtomicU64, sync::atomic::Ordering}; + +/// Implement this trait to interface blk-mq as block devices. +/// +/// To implement a block device driver, implement this trait as described in the +/// [module level documentation]. The kernel will use the implementation of the +/// functions defined in this trait to interface a block device driver. Note: +/// There is no need for an exit_request() implementation, because the `drop` +/// implementation of the [`Request`] type will be invoked by automatically by +/// the C/Rust glue logic. +/// +/// [module level documentation]: kernel::block::mq +#[macros::vtable] +pub trait Operations: Sized { + /// Called by the kernel to queue a request with the driver. If `is_last` is + /// `false`, the driver is allowed to defer committing the request. + fn queue_rq(rq: ARef<Request<Self>>, is_last: bool) -> Result; + + /// Called by the kernel to indicate that queued requests should be submitted. + fn commit_rqs(); + + /// Called by the kernel to poll the device for completed requests. Only + /// used for poll queues. + fn poll() -> bool { + crate::build_error(crate::error::VTABLE_DEFAULT_ERROR) + } +} + +/// A vtable for blk-mq to interact with a block device driver. +/// +/// A `bindings::blk_mq_ops` vtable is constructed from pointers to the `extern +/// "C"` functions of this struct, exposed through the `OperationsVTable::VTABLE`. +/// +/// For general documentation of these methods, see the kernel source +/// documentation related to `struct blk_mq_operations` in +/// [`include/linux/blk-mq.h`]. +/// +/// [`include/linux/blk-mq.h`]: srctree/include/linux/blk-mq.h +pub(crate) struct OperationsVTable<T: Operations>(PhantomData<T>); + +impl<T: Operations> OperationsVTable<T> { + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// - The caller of this function must ensure that the pointee of `bd` is + /// valid for reads for the duration of this function. + /// - This function must be called for an initialized and live `hctx`. That + /// is, `Self::init_hctx_callback` was called and + /// `Self::exit_hctx_callback()` was not yet called. + /// - `(*bd).rq` must point to an initialized and live `bindings:request`. + /// That is, `Self::init_request_callback` was called but + /// `Self::exit_request_callback` was not yet called for the request. + /// - `(*bd).rq` must be owned by the driver. That is, the block layer must + /// promise to not access the request until the driver calls + /// `bindings::blk_mq_end_request` for the request. + unsafe extern "C" fn queue_rq_callback( + _hctx: *mut bindings::blk_mq_hw_ctx, + bd: *const bindings::blk_mq_queue_data, + ) -> bindings::blk_status_t { + // SAFETY: `bd.rq` is valid as required by the safety requirement for + // this function. + let request = unsafe { &*(*bd).rq.cast::<Request<T>>() }; + + // One refcount for the ARef, one for being in flight + request.wrapper_ref().refcount().store(2, Ordering::Relaxed); + + // SAFETY: + // - We own a refcount that we took above. We pass that to `ARef`. + // - By the safety requirements of this function, `request` is a valid + // `struct request` and the private data is properly initialized. + // - `rq` will be alive until `blk_mq_end_request` is called and is + // reference counted by `ARef` until then. + let rq = unsafe { Request::aref_from_raw((*bd).rq) }; + + // SAFETY: We have exclusive access and we just set the refcount above. + unsafe { Request::start_unchecked(&rq) }; + + let ret = T::queue_rq( + rq, + // SAFETY: `bd` is valid as required by the safety requirement for + // this function. + unsafe { (*bd).last }, + ); + + if let Err(e) = ret { + e.to_blk_status() + } else { + bindings::BLK_STS_OK as _ + } + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. + unsafe extern "C" fn commit_rqs_callback(_hctx: *mut bindings::blk_mq_hw_ctx) { + T::commit_rqs() + } + + /// This function is called by the C kernel. It is not currently + /// implemented, and there is no way to exercise this code path. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. + unsafe extern "C" fn complete_callback(_rq: *mut bindings::request) {} + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. + unsafe extern "C" fn poll_callback( + _hctx: *mut bindings::blk_mq_hw_ctx, + _iob: *mut bindings::io_comp_batch, + ) -> core::ffi::c_int { + T::poll().into() + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. This + /// function may only be called once before `exit_hctx_callback` is called + /// for the same context. + unsafe extern "C" fn init_hctx_callback( + _hctx: *mut bindings::blk_mq_hw_ctx, + _tagset_data: *mut core::ffi::c_void, + _hctx_idx: core::ffi::c_uint, + ) -> core::ffi::c_int { + from_result(|| Ok(0)) + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// This function may only be called by blk-mq C infrastructure. + unsafe extern "C" fn exit_hctx_callback( + _hctx: *mut bindings::blk_mq_hw_ctx, + _hctx_idx: core::ffi::c_uint, + ) { + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// - This function may only be called by blk-mq C infrastructure. + /// - `_set` must point to an initialized `TagSet<T>`. + /// - `rq` must point to an initialized `bindings::request`. + /// - The allocation pointed to by `rq` must be at the size of `Request` + /// plus the size of `RequestDataWrapper`. + unsafe extern "C" fn init_request_callback( + _set: *mut bindings::blk_mq_tag_set, + rq: *mut bindings::request, + _hctx_idx: core::ffi::c_uint, + _numa_node: core::ffi::c_uint, + ) -> core::ffi::c_int { + from_result(|| { + // SAFETY: By the safety requirements of this function, `rq` points + // to a valid allocation. + let pdu = unsafe { Request::wrapper_ptr(rq.cast::<Request<T>>()) }; + + // SAFETY: The refcount field is allocated but not initialized, so + // it is valid for writes. + unsafe { RequestDataWrapper::refcount_ptr(pdu.as_ptr()).write(AtomicU64::new(0)) }; + + Ok(0) + }) + } + + /// This function is called by the C kernel. A pointer to this function is + /// installed in the `blk_mq_ops` vtable for the driver. + /// + /// # Safety + /// + /// - This function may only be called by blk-mq C infrastructure. + /// - `_set` must point to an initialized `TagSet<T>`. + /// - `rq` must point to an initialized and valid `Request`. + unsafe extern "C" fn exit_request_callback( + _set: *mut bindings::blk_mq_tag_set, + rq: *mut bindings::request, + _hctx_idx: core::ffi::c_uint, + ) { + // SAFETY: The tagset invariants guarantee that all requests are allocated with extra memory + // for the request data. + let pdu = unsafe { bindings::blk_mq_rq_to_pdu(rq) }.cast::<RequestDataWrapper>(); + + // SAFETY: `pdu` is valid for read and write and is properly initialised. + unsafe { core::ptr::drop_in_place(pdu) }; + } + + const VTABLE: bindings::blk_mq_ops = bindings::blk_mq_ops { + queue_rq: Some(Self::queue_rq_callback), + queue_rqs: None, + commit_rqs: Some(Self::commit_rqs_callback), + get_budget: None, + put_budget: None, + set_rq_budget_token: None, + get_rq_budget_token: None, + timeout: None, + poll: if T::HAS_POLL { + Some(Self::poll_callback) + } else { + None + }, + complete: Some(Self::complete_callback), + init_hctx: Some(Self::init_hctx_callback), + exit_hctx: Some(Self::exit_hctx_callback), + init_request: Some(Self::init_request_callback), + exit_request: Some(Self::exit_request_callback), + cleanup_rq: None, + busy: None, + map_queues: None, + #[cfg(CONFIG_BLK_DEBUG_FS)] + show_rq: None, + }; + + pub(crate) const fn build() -> &'static bindings::blk_mq_ops { + &Self::VTABLE + } +} diff --git a/rust/kernel/block/mq/raw_writer.rs b/rust/kernel/block/mq/raw_writer.rs new file mode 100644 index 000000000000..9222465d670b --- /dev/null +++ b/rust/kernel/block/mq/raw_writer.rs @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: GPL-2.0 + +use core::fmt::{self, Write}; + +use crate::error::Result; +use crate::prelude::EINVAL; + +/// A mutable reference to a byte buffer where a string can be written into. +/// +/// # Invariants +/// +/// `buffer` is always null terminated. +pub(crate) struct RawWriter<'a> { + buffer: &'a mut [u8], + pos: usize, +} + +impl<'a> RawWriter<'a> { + /// Create a new `RawWriter` instance. + fn new(buffer: &'a mut [u8]) -> Result<RawWriter<'a>> { + *(buffer.last_mut().ok_or(EINVAL)?) = 0; + + // INVARIANT: We null terminated the buffer above. + Ok(Self { buffer, pos: 0 }) + } + + pub(crate) fn from_array<const N: usize>( + a: &'a mut [core::ffi::c_char; N], + ) -> Result<RawWriter<'a>> { + Self::new( + // SAFETY: the buffer of `a` is valid for read and write as `u8` for + // at least `N` bytes. + unsafe { core::slice::from_raw_parts_mut(a.as_mut_ptr().cast::<u8>(), N) }, + ) + } +} + +impl Write for RawWriter<'_> { + fn write_str(&mut self, s: &str) -> fmt::Result { + let bytes = s.as_bytes(); + let len = bytes.len(); + + // We do not want to overwrite our null terminator + if self.pos + len > self.buffer.len() - 1 { + return Err(fmt::Error); + } + + // INVARIANT: We are not overwriting the last byte + self.buffer[self.pos..self.pos + len].copy_from_slice(bytes); + + self.pos += len; + + Ok(()) + } +} diff --git a/rust/kernel/block/mq/request.rs b/rust/kernel/block/mq/request.rs new file mode 100644 index 000000000000..a0e22827f3f4 --- /dev/null +++ b/rust/kernel/block/mq/request.rs @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! This module provides a wrapper for the C `struct request` type. +//! +//! C header: [`include/linux/blk-mq.h`](srctree/include/linux/blk-mq.h) + +use crate::{ + bindings, + block::mq::Operations, + error::Result, + types::{ARef, AlwaysRefCounted, Opaque}, +}; +use core::{ + marker::PhantomData, + ptr::{addr_of_mut, NonNull}, + sync::atomic::{AtomicU64, Ordering}, +}; + +/// A wrapper around a blk-mq `struct request`. This represents an IO request. +/// +/// # Implementation details +/// +/// There are four states for a request that the Rust bindings care about: +/// +/// A) Request is owned by block layer (refcount 0) +/// B) Request is owned by driver but with zero `ARef`s in existence +/// (refcount 1) +/// C) Request is owned by driver with exactly one `ARef` in existence +/// (refcount 2) +/// D) Request is owned by driver with more than one `ARef` in existence +/// (refcount > 2) +/// +/// +/// We need to track A and B to ensure we fail tag to request conversions for +/// requests that are not owned by the driver. +/// +/// We need to track C and D to ensure that it is safe to end the request and hand +/// back ownership to the block layer. +/// +/// The states are tracked through the private `refcount` field of +/// `RequestDataWrapper`. This structure lives in the private data area of the C +/// `struct request`. +/// +/// # Invariants +/// +/// * `self.0` is a valid `struct request` created by the C portion of the kernel. +/// * The private data area associated with this request must be an initialized +/// and valid `RequestDataWrapper<T>`. +/// * `self` is reference counted by atomic modification of +/// self.wrapper_ref().refcount(). +/// +#[repr(transparent)] +pub struct Request<T: Operations>(Opaque<bindings::request>, PhantomData<T>); + +impl<T: Operations> Request<T> { + /// Create an `ARef<Request>` from a `struct request` pointer. + /// + /// # Safety + /// + /// * The caller must own a refcount on `ptr` that is transferred to the + /// returned `ARef`. + /// * The type invariants for `Request` must hold for the pointee of `ptr`. + pub(crate) unsafe fn aref_from_raw(ptr: *mut bindings::request) -> ARef<Self> { + // INVARIANT: By the safety requirements of this function, invariants are upheld. + // SAFETY: By the safety requirement of this function, we own a + // reference count that we can pass to `ARef`. + unsafe { ARef::from_raw(NonNull::new_unchecked(ptr as *const Self as *mut Self)) } + } + + /// Notify the block layer that a request is going to be processed now. + /// + /// The block layer uses this hook to do proper initializations such as + /// starting the timeout timer. It is a requirement that block device + /// drivers call this function when starting to process a request. + /// + /// # Safety + /// + /// The caller must have exclusive ownership of `self`, that is + /// `self.wrapper_ref().refcount() == 2`. + pub(crate) unsafe fn start_unchecked(this: &ARef<Self>) { + // SAFETY: By type invariant, `self.0` is a valid `struct request` and + // we have exclusive access. + unsafe { bindings::blk_mq_start_request(this.0.get()) }; + } + + /// Try to take exclusive ownership of `this` by dropping the refcount to 0. + /// This fails if `this` is not the only `ARef` pointing to the underlying + /// `Request`. + /// + /// If the operation is successful, `Ok` is returned with a pointer to the + /// C `struct request`. If the operation fails, `this` is returned in the + /// `Err` variant. + fn try_set_end(this: ARef<Self>) -> Result<*mut bindings::request, ARef<Self>> { + // We can race with `TagSet::tag_to_rq` + if let Err(_old) = this.wrapper_ref().refcount().compare_exchange( + 2, + 0, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + return Err(this); + } + + let request_ptr = this.0.get(); + core::mem::forget(this); + + Ok(request_ptr) + } + + /// Notify the block layer that the request has been completed without errors. + /// + /// This function will return `Err` if `this` is not the only `ARef` + /// referencing the request. + pub fn end_ok(this: ARef<Self>) -> Result<(), ARef<Self>> { + let request_ptr = Self::try_set_end(this)?; + + // SAFETY: By type invariant, `this.0` was a valid `struct request`. The + // success of the call to `try_set_end` guarantees that there are no + // `ARef`s pointing to this request. Therefore it is safe to hand it + // back to the block layer. + unsafe { bindings::blk_mq_end_request(request_ptr, bindings::BLK_STS_OK as _) }; + + Ok(()) + } + + /// Return a pointer to the `RequestDataWrapper` stored in the private area + /// of the request structure. + /// + /// # Safety + /// + /// - `this` must point to a valid allocation of size at least size of + /// `Self` plus size of `RequestDataWrapper`. + pub(crate) unsafe fn wrapper_ptr(this: *mut Self) -> NonNull<RequestDataWrapper> { + let request_ptr = this.cast::<bindings::request>(); + // SAFETY: By safety requirements for this function, `this` is a + // valid allocation. + let wrapper_ptr = + unsafe { bindings::blk_mq_rq_to_pdu(request_ptr).cast::<RequestDataWrapper>() }; + // SAFETY: By C API contract, wrapper_ptr points to a valid allocation + // and is not null. + unsafe { NonNull::new_unchecked(wrapper_ptr) } + } + + /// Return a reference to the `RequestDataWrapper` stored in the private + /// area of the request structure. + pub(crate) fn wrapper_ref(&self) -> &RequestDataWrapper { + // SAFETY: By type invariant, `self.0` is a valid allocation. Further, + // the private data associated with this request is initialized and + // valid. The existence of `&self` guarantees that the private data is + // valid as a shared reference. + unsafe { Self::wrapper_ptr(self as *const Self as *mut Self).as_ref() } + } +} + +/// A wrapper around data stored in the private area of the C `struct request`. +pub(crate) struct RequestDataWrapper { + /// The Rust request refcount has the following states: + /// + /// - 0: The request is owned by C block layer. + /// - 1: The request is owned by Rust abstractions but there are no ARef references to it. + /// - 2+: There are `ARef` references to the request. + refcount: AtomicU64, +} + +impl RequestDataWrapper { + /// Return a reference to the refcount of the request that is embedding + /// `self`. + pub(crate) fn refcount(&self) -> &AtomicU64 { + &self.refcount + } + + /// Return a pointer to the refcount of the request that is embedding the + /// pointee of `this`. + /// + /// # Safety + /// + /// - `this` must point to a live allocation of at least the size of `Self`. + pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut AtomicU64 { + // SAFETY: Because of the safety requirements of this function, the + // field projection is safe. + unsafe { addr_of_mut!((*this).refcount) } + } +} + +// SAFETY: Exclusive access is thread-safe for `Request`. `Request` has no `&mut +// self` methods and `&self` methods that mutate `self` are internally +// synchronized. +unsafe impl<T: Operations> Send for Request<T> {} + +// SAFETY: Shared access is thread-safe for `Request`. `&self` methods that +// mutate `self` are internally synchronized` +unsafe impl<T: Operations> Sync for Request<T> {} + +/// Store the result of `op(target.load())` in target, returning new value of +/// target. +fn atomic_relaxed_op_return(target: &AtomicU64, op: impl Fn(u64) -> u64) -> u64 { + let old = target.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(op(x))); + + // SAFETY: Because the operation passed to `fetch_update` above always + // return `Some`, `old` will always be `Ok`. + let old = unsafe { old.unwrap_unchecked() }; + + op(old) +} + +/// Store the result of `op(target.load)` in `target` if `target.load() != +/// pred`, returning true if the target was updated. +fn atomic_relaxed_op_unless(target: &AtomicU64, op: impl Fn(u64) -> u64, pred: u64) -> bool { + target + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| { + if x == pred { + None + } else { + Some(op(x)) + } + }) + .is_ok() +} + +// SAFETY: All instances of `Request<T>` are reference counted. This +// implementation of `AlwaysRefCounted` ensure that increments to the ref count +// keeps the object alive in memory at least until a matching reference count +// decrement is executed. +unsafe impl<T: Operations> AlwaysRefCounted for Request<T> { + fn inc_ref(&self) { + let refcount = &self.wrapper_ref().refcount(); + + #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))] + let updated = atomic_relaxed_op_unless(refcount, |x| x + 1, 0); + + #[cfg(CONFIG_DEBUG_MISC)] + if !updated { + panic!("Request refcount zero on clone") + } + } + + unsafe fn dec_ref(obj: core::ptr::NonNull<Self>) { + // SAFETY: The type invariants of `ARef` guarantee that `obj` is valid + // for read. + let wrapper_ptr = unsafe { Self::wrapper_ptr(obj.as_ptr()).as_ptr() }; + // SAFETY: The type invariant of `Request` guarantees that the private + // data area is initialized and valid. + let refcount = unsafe { &*RequestDataWrapper::refcount_ptr(wrapper_ptr) }; + + #[cfg_attr(not(CONFIG_DEBUG_MISC), allow(unused_variables))] + let new_refcount = atomic_relaxed_op_return(refcount, |x| x - 1); + + #[cfg(CONFIG_DEBUG_MISC)] + if new_refcount == 0 { + panic!("Request reached refcount zero in Rust abstractions"); + } + } +} diff --git a/rust/kernel/block/mq/tag_set.rs b/rust/kernel/block/mq/tag_set.rs new file mode 100644 index 000000000000..f9a1ca655a35 --- /dev/null +++ b/rust/kernel/block/mq/tag_set.rs @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! This module provides the `TagSet` struct to wrap the C `struct blk_mq_tag_set`. +//! +//! C header: [`include/linux/blk-mq.h`](srctree/include/linux/blk-mq.h) + +use core::pin::Pin; + +use crate::{ + bindings, + block::mq::{operations::OperationsVTable, request::RequestDataWrapper, Operations}, + error, + prelude::PinInit, + try_pin_init, + types::Opaque, +}; +use core::{convert::TryInto, marker::PhantomData}; +use macros::{pin_data, pinned_drop}; + +/// A wrapper for the C `struct blk_mq_tag_set`. +/// +/// `struct blk_mq_tag_set` contains a `struct list_head` and so must be pinned. +/// +/// # Invariants +/// +/// - `inner` is initialized and valid. +#[pin_data(PinnedDrop)] +#[repr(transparent)] +pub struct TagSet<T: Operations> { + #[pin] + inner: Opaque<bindings::blk_mq_tag_set>, + _p: PhantomData<T>, +} + +impl<T: Operations> TagSet<T> { + /// Try to create a new tag set + pub fn new( + nr_hw_queues: u32, + num_tags: u32, + num_maps: u32, + ) -> impl PinInit<Self, error::Error> { + // SAFETY: `blk_mq_tag_set` only contains integers and pointers, which + // all are allowed to be 0. + let tag_set: bindings::blk_mq_tag_set = unsafe { core::mem::zeroed() }; + let tag_set = core::mem::size_of::<RequestDataWrapper>() + .try_into() + .map(|cmd_size| { + bindings::blk_mq_tag_set { + ops: OperationsVTable::<T>::build(), + nr_hw_queues, + timeout: 0, // 0 means default which is 30Hz in C + numa_node: bindings::NUMA_NO_NODE, + queue_depth: num_tags, + cmd_size, + flags: bindings::BLK_MQ_F_SHOULD_MERGE, + driver_data: core::ptr::null_mut::<core::ffi::c_void>(), + nr_maps: num_maps, + ..tag_set + } + }); + + try_pin_init!(TagSet { + inner <- PinInit::<_, error::Error>::pin_chain(Opaque::new(tag_set?), |tag_set| { + // SAFETY: we do not move out of `tag_set`. + let tag_set = unsafe { Pin::get_unchecked_mut(tag_set) }; + // SAFETY: `tag_set` is a reference to an initialized `blk_mq_tag_set`. + error::to_result( unsafe { bindings::blk_mq_alloc_tag_set(tag_set.get())}) + }), + _p: PhantomData, + }) + } + + /// Return the pointer to the wrapped `struct blk_mq_tag_set` + pub(crate) fn raw_tag_set(&self) -> *mut bindings::blk_mq_tag_set { + self.inner.get() + } +} + +#[pinned_drop] +impl<T: Operations> PinnedDrop for TagSet<T> { + fn drop(self: Pin<&mut Self>) { + // SAFETY: By type invariant `inner` is valid and has been properly + // initialized during construction. + unsafe { bindings::blk_mq_free_tag_set(self.inner.get()) }; + } +} diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs new file mode 100644 index 000000000000..c8199ee079ef --- /dev/null +++ b/rust/kernel/device.rs @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic devices that are part of the kernel's driver model. +//! +//! C header: [`include/linux/device.h`](srctree/include/linux/device.h) + +use crate::{ + bindings, + types::{ARef, Opaque}, +}; +use core::ptr; + +/// A reference-counted device. +/// +/// This structure represents the Rust abstraction for a C `struct device`. This implementation +/// abstracts the usage of an already existing C `struct device` within Rust code that we get +/// passed from the C side. +/// +/// An instance of this abstraction can be obtained temporarily or permanent. +/// +/// A temporary one is bound to the lifetime of the C `struct device` pointer used for creation. +/// A permanent instance is always reference-counted and hence not restricted by any lifetime +/// boundaries. +/// +/// For subsystems it is recommended to create a permanent instance to wrap into a subsystem +/// specific device structure (e.g. `pci::Device`). This is useful for passing it to drivers in +/// `T::probe()`, such that a driver can store the `ARef<Device>` (equivalent to storing a +/// `struct device` pointer in a C driver) for arbitrary purposes, e.g. allocating DMA coherent +/// memory. +/// +/// # Invariants +/// +/// A `Device` instance represents a valid `struct device` created by the C portion of the kernel. +/// +/// Instances of this type are always reference-counted, that is, a call to `get_device` ensures +/// that the allocation remains valid at least until the matching call to `put_device`. +/// +/// `bindings::device::release` is valid to be called from any thread, hence `ARef<Device>` can be +/// dropped from any thread. +#[repr(transparent)] +pub struct Device(Opaque<bindings::device>); + +impl Device { + /// Creates a new reference-counted abstraction instance of an existing `struct device` pointer. + /// + /// # Safety + /// + /// Callers must ensure that `ptr` is valid, non-null, and has a non-zero reference count, + /// i.e. it must be ensured that the reference count of the C `struct device` `ptr` points to + /// can't drop to zero, for the duration of this function call. + /// + /// It must also be ensured that `bindings::device::release` can be called from any thread. + /// While not officially documented, this should be the case for any `struct device`. + pub unsafe fn get_device(ptr: *mut bindings::device) -> ARef<Self> { + // SAFETY: By the safety requirements ptr is valid + unsafe { Self::as_ref(ptr) }.into() + } + + /// Obtain the raw `struct device *`. + pub(crate) fn as_raw(&self) -> *mut bindings::device { + self.0.get() + } + + /// Convert a raw C `struct device` pointer to a `&'a Device`. + /// + /// # Safety + /// + /// Callers must ensure that `ptr` is valid, non-null, and has a non-zero reference count, + /// i.e. it must be ensured that the reference count of the C `struct device` `ptr` points to + /// can't drop to zero, for the duration of this function call and the entire duration when the + /// returned reference exists. + pub unsafe fn as_ref<'a>(ptr: *mut bindings::device) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + unsafe { &*ptr.cast() } + } +} + +// SAFETY: Instances of `Device` are always reference-counted. +unsafe impl crate::types::AlwaysRefCounted for Device { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::get_device(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: ptr::NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::put_device(obj.cast().as_ptr()) } + } +} + +// SAFETY: As by the type invariant `Device` can be sent to any thread. +unsafe impl Send for Device {} + +// SAFETY: `Device` can be shared among threads because all immutable methods are protected by the +// synchronization in `struct device`. +unsafe impl Sync for Device {} diff --git a/rust/kernel/error.rs b/rust/kernel/error.rs index 55280ae9fe40..6f1587a2524e 100644 --- a/rust/kernel/error.rs +++ b/rust/kernel/error.rs @@ -126,11 +126,20 @@ impl Error { self.0 } + #[cfg(CONFIG_BLOCK)] + pub(crate) fn to_blk_status(self) -> bindings::blk_status_t { + // SAFETY: `self.0` is a valid error due to its invariant. + unsafe { bindings::errno_to_blk_status(self.0) } + } + /// Returns the error encoded as a pointer. #[allow(dead_code)] pub(crate) fn to_ptr<T>(self) -> *mut T { + #[cfg_attr(target_pointer_width = "32", allow(clippy::useless_conversion))] // SAFETY: `self.0` is a valid error due to its invariant. - unsafe { bindings::ERR_PTR(self.0.into()) as *mut _ } + unsafe { + bindings::ERR_PTR(self.0.into()) as *mut _ + } } /// Returns a string representing the error, if one exists. diff --git a/rust/kernel/firmware.rs b/rust/kernel/firmware.rs new file mode 100644 index 000000000000..13a374a5cdb7 --- /dev/null +++ b/rust/kernel/firmware.rs @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Firmware abstraction +//! +//! C header: [`include/linux/firmware.h`](srctree/include/linux/firmware.h) + +use crate::{bindings, device::Device, error::Error, error::Result, str::CStr}; +use core::ptr::NonNull; + +/// # Invariants +/// +/// One of the following: `bindings::request_firmware`, `bindings::firmware_request_nowarn`, +/// `bindings::firmware_request_platform`, `bindings::request_firmware_direct`. +struct FwFunc( + unsafe extern "C" fn(*mut *const bindings::firmware, *const i8, *mut bindings::device) -> i32, +); + +impl FwFunc { + fn request() -> Self { + Self(bindings::request_firmware) + } + + fn request_nowarn() -> Self { + Self(bindings::firmware_request_nowarn) + } +} + +/// Abstraction around a C `struct firmware`. +/// +/// This is a simple abstraction around the C firmware API. Just like with the C API, firmware can +/// be requested. Once requested the abstraction provides direct access to the firmware buffer as +/// `&[u8]`. The firmware is released once [`Firmware`] is dropped. +/// +/// # Invariants +/// +/// The pointer is valid, and has ownership over the instance of `struct firmware`. +/// +/// The `Firmware`'s backing buffer is not modified. +/// +/// # Examples +/// +/// ```no_run +/// # use kernel::{c_str, device::Device, firmware::Firmware}; +/// +/// # fn no_run() -> Result<(), Error> { +/// # // SAFETY: *NOT* safe, just for the example to get an `ARef<Device>` instance +/// # let dev = unsafe { Device::get_device(core::ptr::null_mut()) }; +/// +/// let fw = Firmware::request(c_str!("path/to/firmware.bin"), &dev)?; +/// let blob = fw.data(); +/// +/// # Ok(()) +/// # } +/// ``` +pub struct Firmware(NonNull<bindings::firmware>); + +impl Firmware { + fn request_internal(name: &CStr, dev: &Device, func: FwFunc) -> Result<Self> { + let mut fw: *mut bindings::firmware = core::ptr::null_mut(); + let pfw: *mut *mut bindings::firmware = &mut fw; + + // SAFETY: `pfw` is a valid pointer to a NULL initialized `bindings::firmware` pointer. + // `name` and `dev` are valid as by their type invariants. + let ret = unsafe { func.0(pfw as _, name.as_char_ptr(), dev.as_raw()) }; + if ret != 0 { + return Err(Error::from_errno(ret)); + } + + // SAFETY: `func` not bailing out with a non-zero error code, guarantees that `fw` is a + // valid pointer to `bindings::firmware`. + Ok(Firmware(unsafe { NonNull::new_unchecked(fw) })) + } + + /// Send a firmware request and wait for it. See also `bindings::request_firmware`. + pub fn request(name: &CStr, dev: &Device) -> Result<Self> { + Self::request_internal(name, dev, FwFunc::request()) + } + + /// Send a request for an optional firmware module. See also + /// `bindings::firmware_request_nowarn`. + pub fn request_nowarn(name: &CStr, dev: &Device) -> Result<Self> { + Self::request_internal(name, dev, FwFunc::request_nowarn()) + } + + fn as_raw(&self) -> *mut bindings::firmware { + self.0.as_ptr() + } + + /// Returns the size of the requested firmware in bytes. + pub fn size(&self) -> usize { + // SAFETY: `self.as_raw()` is valid by the type invariant. + unsafe { (*self.as_raw()).size } + } + + /// Returns the requested firmware as `&[u8]`. + pub fn data(&self) -> &[u8] { + // SAFETY: `self.as_raw()` is valid by the type invariant. Additionally, + // `bindings::firmware` guarantees, if successfully requested, that + // `bindings::firmware::data` has a size of `bindings::firmware::size` bytes. + unsafe { core::slice::from_raw_parts((*self.as_raw()).data, self.size()) } + } +} + +impl Drop for Firmware { + fn drop(&mut self) { + // SAFETY: `self.as_raw()` is valid by the type invariant. + unsafe { bindings::release_firmware(self.as_raw()) }; + } +} + +// SAFETY: `Firmware` only holds a pointer to a C `struct firmware`, which is safe to be used from +// any thread. +unsafe impl Send for Firmware {} + +// SAFETY: `Firmware` only holds a pointer to a C `struct firmware`, references to which are safe to +// be used from any thread. +unsafe impl Sync for Firmware {} diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs index 68605b633e73..a17ac8762d8f 100644 --- a/rust/kernel/init.rs +++ b/rust/kernel/init.rs @@ -213,6 +213,7 @@ use crate::{ alloc::{box_ext::BoxExt, AllocError, Flags}, error::{self, Error}, + sync::Arc, sync::UniqueArc, types::{Opaque, ScopeGuard}, }; @@ -742,6 +743,74 @@ macro_rules! try_init { }; } +/// Asserts that a field on a struct using `#[pin_data]` is marked with `#[pin]` ie. that it is +/// structurally pinned. +/// +/// # Example +/// +/// This will succeed: +/// ``` +/// use kernel::assert_pinned; +/// #[pin_data] +/// struct MyStruct { +/// #[pin] +/// some_field: u64, +/// } +/// +/// assert_pinned!(MyStruct, some_field, u64); +/// ``` +/// +/// This will fail: +// TODO: replace with `compile_fail` when supported. +/// ```ignore +/// use kernel::assert_pinned; +/// #[pin_data] +/// struct MyStruct { +/// some_field: u64, +/// } +/// +/// assert_pinned!(MyStruct, some_field, u64); +/// ``` +/// +/// Some uses of the macro may trigger the `can't use generic parameters from outer item` error. To +/// work around this, you may pass the `inline` parameter to the macro. The `inline` parameter can +/// only be used when the macro is invoked from a function body. +/// ``` +/// use kernel::assert_pinned; +/// #[pin_data] +/// struct Foo<T> { +/// #[pin] +/// elem: T, +/// } +/// +/// impl<T> Foo<T> { +/// fn project(self: Pin<&mut Self>) -> Pin<&mut T> { +/// assert_pinned!(Foo<T>, elem, T, inline); +/// +/// // SAFETY: The field is structurally pinned. +/// unsafe { self.map_unchecked_mut(|me| &mut me.elem) } +/// } +/// } +/// ``` +#[macro_export] +macro_rules! assert_pinned { + ($ty:ty, $field:ident, $field_ty:ty, inline) => { + let _ = move |ptr: *mut $field_ty| { + // SAFETY: This code is unreachable. + let data = unsafe { <$ty as $crate::init::__internal::HasPinData>::__pin_data() }; + let init = $crate::init::__internal::AlwaysFail::<$field_ty>::new(); + // SAFETY: This code is unreachable. + unsafe { data.$field(ptr, init) }.ok(); + }; + }; + + ($ty:ty, $field:ident, $field_ty:ty) => { + const _: () = { + $crate::assert_pinned!($ty, $field, $field_ty, inline); + }; + }; +} + /// A pin-initializer for the type `T`. /// /// To use this initializer, you will need a suitable memory location that can hold a `T`. This can @@ -843,11 +912,8 @@ where let val = unsafe { &mut *slot }; // SAFETY: `slot` is considered pinned. let val = unsafe { Pin::new_unchecked(val) }; - (self.1)(val).map_err(|e| { - // SAFETY: `slot` was initialized above. - unsafe { core::ptr::drop_in_place(slot) }; - e - }) + // SAFETY: `slot` was initialized above. + (self.1)(val).inspect_err(|_| unsafe { core::ptr::drop_in_place(slot) }) } } @@ -941,11 +1007,9 @@ where // SAFETY: All requirements fulfilled since this function is `__init`. unsafe { self.0.__pinned_init(slot)? }; // SAFETY: The above call initialized `slot` and we still have unique access. - (self.1)(unsafe { &mut *slot }).map_err(|e| { + (self.1)(unsafe { &mut *slot }).inspect_err(|_| // SAFETY: `slot` was initialized above. - unsafe { core::ptr::drop_in_place(slot) }; - e - }) + unsafe { core::ptr::drop_in_place(slot) }) } } @@ -1112,11 +1176,17 @@ unsafe impl<T, E> PinInit<T, E> for T { /// Smart pointer that can initialize memory in-place. pub trait InPlaceInit<T>: Sized { + /// Pinned version of `Self`. + /// + /// If a type already implicitly pins its pointee, `Pin<Self>` is unnecessary. In this case use + /// `Self`, otherwise just use `Pin<Self>`. + type PinnedSelf; + /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this /// type. /// /// If `T: !Unpin` it will not be able to move afterwards. - fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Pin<Self>, E> + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> where E: From<AllocError>; @@ -1124,7 +1194,7 @@ pub trait InPlaceInit<T>: Sized { /// type. /// /// If `T: !Unpin` it will not be able to move afterwards. - fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> error::Result<Pin<Self>> + fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> error::Result<Self::PinnedSelf> where Error: From<E>, { @@ -1153,19 +1223,35 @@ pub trait InPlaceInit<T>: Sized { } } +impl<T> InPlaceInit<T> for Arc<T> { + type PinnedSelf = Self; + + #[inline] + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> + where + E: From<AllocError>, + { + UniqueArc::try_pin_init(init, flags).map(|u| u.into()) + } + + #[inline] + fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + UniqueArc::try_init(init, flags).map(|u| u.into()) + } +} + impl<T> InPlaceInit<T> for Box<T> { + type PinnedSelf = Pin<Self>; + #[inline] - fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Pin<Self>, E> + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> where E: From<AllocError>, { - let mut this = <Box<_> as BoxExt<_>>::new_uninit(flags)?; - let slot = this.as_mut_ptr(); - // SAFETY: When init errors/panics, slot will get deallocated but not dropped, - // slot is valid and will not be moved, because we pin it later. - unsafe { init.__pinned_init(slot)? }; - // SAFETY: All fields have been initialized. - Ok(unsafe { this.assume_init() }.into()) + <Box<_> as BoxExt<_>>::new_uninit(flags)?.write_pin_init(init) } #[inline] @@ -1173,29 +1259,19 @@ impl<T> InPlaceInit<T> for Box<T> { where E: From<AllocError>, { - let mut this = <Box<_> as BoxExt<_>>::new_uninit(flags)?; - let slot = this.as_mut_ptr(); - // SAFETY: When init errors/panics, slot will get deallocated but not dropped, - // slot is valid. - unsafe { init.__init(slot)? }; - // SAFETY: All fields have been initialized. - Ok(unsafe { this.assume_init() }) + <Box<_> as BoxExt<_>>::new_uninit(flags)?.write_init(init) } } impl<T> InPlaceInit<T> for UniqueArc<T> { + type PinnedSelf = Pin<Self>; + #[inline] - fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Pin<Self>, E> + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> where E: From<AllocError>, { - let mut this = UniqueArc::new_uninit(flags)?; - let slot = this.as_mut_ptr(); - // SAFETY: When init errors/panics, slot will get deallocated but not dropped, - // slot is valid and will not be moved, because we pin it later. - unsafe { init.__pinned_init(slot)? }; - // SAFETY: All fields have been initialized. - Ok(unsafe { this.assume_init() }.into()) + UniqueArc::new_uninit(flags)?.write_pin_init(init) } #[inline] @@ -1203,13 +1279,67 @@ impl<T> InPlaceInit<T> for UniqueArc<T> { where E: From<AllocError>, { - let mut this = UniqueArc::new_uninit(flags)?; - let slot = this.as_mut_ptr(); + UniqueArc::new_uninit(flags)?.write_init(init) + } +} + +/// Smart pointer containing uninitialized memory and that can write a value. +pub trait InPlaceWrite<T> { + /// The type `Self` turns into when the contents are initialized. + type Initialized; + + /// Use the given initializer to write a value into `self`. + /// + /// Does not drop the current value and considers it as uninitialized memory. + fn write_init<E>(self, init: impl Init<T, E>) -> Result<Self::Initialized, E>; + + /// Use the given pin-initializer to write a value into `self`. + /// + /// Does not drop the current value and considers it as uninitialized memory. + fn write_pin_init<E>(self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E>; +} + +impl<T> InPlaceWrite<T> for Box<MaybeUninit<T>> { + type Initialized = Box<T>; + + fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { + let slot = self.as_mut_ptr(); // SAFETY: When init errors/panics, slot will get deallocated but not dropped, // slot is valid. unsafe { init.__init(slot)? }; // SAFETY: All fields have been initialized. - Ok(unsafe { this.assume_init() }) + Ok(unsafe { self.assume_init() }) + } + + fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid and will not be moved, because we pin it later. + unsafe { init.__pinned_init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }.into()) + } +} + +impl<T> InPlaceWrite<T> for UniqueArc<MaybeUninit<T>> { + type Initialized = UniqueArc<T>; + + fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid. + unsafe { init.__init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }) + } + + fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid and will not be moved, because we pin it later. + unsafe { init.__pinned_init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }.into()) } } diff --git a/rust/kernel/init/__internal.rs b/rust/kernel/init/__internal.rs index db3372619ecd..13cefd37512f 100644 --- a/rust/kernel/init/__internal.rs +++ b/rust/kernel/init/__internal.rs @@ -228,3 +228,32 @@ impl OnlyCallFromDrop { Self(()) } } + +/// Initializer that always fails. +/// +/// Used by [`assert_pinned!`]. +/// +/// [`assert_pinned!`]: crate::assert_pinned +pub struct AlwaysFail<T: ?Sized> { + _t: PhantomData<T>, +} + +impl<T: ?Sized> AlwaysFail<T> { + /// Creates a new initializer that always fails. + pub fn new() -> Self { + Self { _t: PhantomData } + } +} + +impl<T: ?Sized> Default for AlwaysFail<T> { + fn default() -> Self { + Self::new() + } +} + +// SAFETY: `__pinned_init` always fails, which is always okay. +unsafe impl<T: ?Sized> PinInit<T, ()> for AlwaysFail<T> { + unsafe fn __pinned_init(self, _slot: *mut T) -> Result<(), ()> { + Err(()) + } +} diff --git a/rust/kernel/init/macros.rs b/rust/kernel/init/macros.rs index 02ecedc4ae7a..9a0c4650ef67 100644 --- a/rust/kernel/init/macros.rs +++ b/rust/kernel/init/macros.rs @@ -145,7 +145,7 @@ //! } //! } //! // Implement the internal `PinData` trait that marks the pin-data struct as a pin-data -//! // struct. This is important to ensure that no user can implement a rouge `__pin_data` +//! // struct. This is important to ensure that no user can implement a rogue `__pin_data` //! // function without using `unsafe`. //! unsafe impl<T> ::kernel::init::__internal::PinData for __ThePinData<T> { //! type Datee = Bar<T>; @@ -156,7 +156,7 @@ //! // case no such fields exist, hence this is almost empty. The two phantomdata fields exist //! // for two reasons: //! // - `__phantom`: every generic must be used, since we cannot really know which generics -//! // are used, we declere all and then use everything here once. +//! // are used, we declare all and then use everything here once. //! // - `__phantom_pin`: uses the `'__pin` lifetime and ensures that this struct is invariant //! // over it. The lifetime is needed to work around the limitation that trait bounds must //! // not be trivial, e.g. the user has a `#[pin] PhantomPinned` field -- this is diff --git a/rust/kernel/kunit.rs b/rust/kernel/kunit.rs index 0ba77276ae7e..824da0e9738a 100644 --- a/rust/kernel/kunit.rs +++ b/rust/kernel/kunit.rs @@ -18,7 +18,7 @@ pub fn err(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - b"\x013%pA\0".as_ptr() as _, + c"\x013%pA".as_ptr() as _, &args as *const _ as *const c_void, ); } @@ -34,7 +34,7 @@ pub fn info(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - b"\x016%pA\0".as_ptr() as _, + c"\x016%pA".as_ptr() as _, &args as *const _ as *const c_void, ); } diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index fbd91a48ff8b..b5f4b3ce6b48 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -27,16 +27,25 @@ compile_error!("Missing kernel configuration for conditional compilation"); extern crate self as kernel; pub mod alloc; +#[cfg(CONFIG_BLOCK)] +pub mod block; mod build_assert; +pub mod device; pub mod error; +#[cfg(CONFIG_RUST_FW_LOADER_ABSTRACTIONS)] +pub mod firmware; pub mod init; pub mod ioctl; #[cfg(CONFIG_KUNIT)] pub mod kunit; +pub mod list; #[cfg(CONFIG_NET)] pub mod net; +pub mod page; pub mod prelude; pub mod print; +pub mod rbtree; +pub mod sizes; mod static_assert; #[doc(hidden)] pub mod std_vendor; @@ -45,6 +54,7 @@ pub mod sync; pub mod task; pub mod time; pub mod types; +pub mod uaccess; pub mod workqueue; #[doc(hidden)] diff --git a/rust/kernel/list.rs b/rust/kernel/list.rs new file mode 100644 index 000000000000..5b4aec29eb67 --- /dev/null +++ b/rust/kernel/list.rs @@ -0,0 +1,686 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! A linked list implementation. + +use crate::init::PinInit; +use crate::sync::ArcBorrow; +use crate::types::Opaque; +use core::iter::{DoubleEndedIterator, FusedIterator}; +use core::marker::PhantomData; +use core::ptr; + +mod impl_list_item_mod; +pub use self::impl_list_item_mod::{ + impl_has_list_links, impl_has_list_links_self_ptr, impl_list_item, HasListLinks, HasSelfPtr, +}; + +mod arc; +pub use self::arc::{impl_list_arc_safe, AtomicTracker, ListArc, ListArcSafe, TryNewListArc}; + +mod arc_field; +pub use self::arc_field::{define_list_arc_field_getter, ListArcField}; + +/// A linked list. +/// +/// All elements in this linked list will be [`ListArc`] references to the value. Since a value can +/// only have one `ListArc` (for each pair of prev/next pointers), this ensures that the same +/// prev/next pointers are not used for several linked lists. +/// +/// # Invariants +/// +/// * If the list is empty, then `first` is null. Otherwise, `first` points at the `ListLinks` +/// field of the first element in the list. +/// * All prev/next pointers in `ListLinks` fields of items in the list are valid and form a cycle. +/// * For every item in the list, the list owns the associated [`ListArc`] reference and has +/// exclusive access to the `ListLinks` field. +pub struct List<T: ?Sized + ListItem<ID>, const ID: u64 = 0> { + first: *mut ListLinksFields, + _ty: PhantomData<ListArc<T, ID>>, +} + +// SAFETY: This is a container of `ListArc<T, ID>`, and access to the container allows the same +// type of access to the `ListArc<T, ID>` elements. +unsafe impl<T, const ID: u64> Send for List<T, ID> +where + ListArc<T, ID>: Send, + T: ?Sized + ListItem<ID>, +{ +} +// SAFETY: This is a container of `ListArc<T, ID>`, and access to the container allows the same +// type of access to the `ListArc<T, ID>` elements. +unsafe impl<T, const ID: u64> Sync for List<T, ID> +where + ListArc<T, ID>: Sync, + T: ?Sized + ListItem<ID>, +{ +} + +/// Implemented by types where a [`ListArc<Self>`] can be inserted into a [`List`]. +/// +/// # Safety +/// +/// Implementers must ensure that they provide the guarantees documented on methods provided by +/// this trait. +/// +/// [`ListArc<Self>`]: ListArc +pub unsafe trait ListItem<const ID: u64 = 0>: ListArcSafe<ID> { + /// Views the [`ListLinks`] for this value. + /// + /// # Guarantees + /// + /// If there is a previous call to `prepare_to_insert` and there is no call to `post_remove` + /// since the most recent such call, then this returns the same pointer as the one returned by + /// the most recent call to `prepare_to_insert`. + /// + /// Otherwise, the returned pointer points at a read-only [`ListLinks`] with two null pointers. + /// + /// # Safety + /// + /// The provided pointer must point at a valid value. (It need not be in an `Arc`.) + unsafe fn view_links(me: *const Self) -> *mut ListLinks<ID>; + + /// View the full value given its [`ListLinks`] field. + /// + /// Can only be used when the value is in a list. + /// + /// # Guarantees + /// + /// * Returns the same pointer as the one passed to the most recent call to `prepare_to_insert`. + /// * The returned pointer is valid until the next call to `post_remove`. + /// + /// # Safety + /// + /// * The provided pointer must originate from the most recent call to `prepare_to_insert`, or + /// from a call to `view_links` that happened after the most recent call to + /// `prepare_to_insert`. + /// * Since the most recent call to `prepare_to_insert`, the `post_remove` method must not have + /// been called. + unsafe fn view_value(me: *mut ListLinks<ID>) -> *const Self; + + /// This is called when an item is inserted into a [`List`]. + /// + /// # Guarantees + /// + /// The caller is granted exclusive access to the returned [`ListLinks`] until `post_remove` is + /// called. + /// + /// # Safety + /// + /// * The provided pointer must point at a valid value in an [`Arc`]. + /// * Calls to `prepare_to_insert` and `post_remove` on the same value must alternate. + /// * The caller must own the [`ListArc`] for this value. + /// * The caller must not give up ownership of the [`ListArc`] unless `post_remove` has been + /// called after this call to `prepare_to_insert`. + /// + /// [`Arc`]: crate::sync::Arc + unsafe fn prepare_to_insert(me: *const Self) -> *mut ListLinks<ID>; + + /// This undoes a previous call to `prepare_to_insert`. + /// + /// # Guarantees + /// + /// The returned pointer is the pointer that was originally passed to `prepare_to_insert`. + /// + /// # Safety + /// + /// The provided pointer must be the pointer returned by the most recent call to + /// `prepare_to_insert`. + unsafe fn post_remove(me: *mut ListLinks<ID>) -> *const Self; +} + +#[repr(C)] +#[derive(Copy, Clone)] +struct ListLinksFields { + next: *mut ListLinksFields, + prev: *mut ListLinksFields, +} + +/// The prev/next pointers for an item in a linked list. +/// +/// # Invariants +/// +/// The fields are null if and only if this item is not in a list. +#[repr(transparent)] +pub struct ListLinks<const ID: u64 = 0> { + // This type is `!Unpin` for aliasing reasons as the pointers are part of an intrusive linked + // list. + inner: Opaque<ListLinksFields>, +} + +// SAFETY: The only way to access/modify the pointers inside of `ListLinks<ID>` is via holding the +// associated `ListArc<T, ID>`. Since that type correctly implements `Send`, it is impossible to +// move this an instance of this type to a different thread if the pointees are `!Send`. +unsafe impl<const ID: u64> Send for ListLinks<ID> {} +// SAFETY: The type is opaque so immutable references to a ListLinks are useless. Therefore, it's +// okay to have immutable access to a ListLinks from several threads at once. +unsafe impl<const ID: u64> Sync for ListLinks<ID> {} + +impl<const ID: u64> ListLinks<ID> { + /// Creates a new initializer for this type. + pub fn new() -> impl PinInit<Self> { + // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will + // not be constructed in an `Arc` that already has a `ListArc`. + ListLinks { + inner: Opaque::new(ListLinksFields { + prev: ptr::null_mut(), + next: ptr::null_mut(), + }), + } + } + + /// # Safety + /// + /// `me` must be dereferenceable. + #[inline] + unsafe fn fields(me: *mut Self) -> *mut ListLinksFields { + // SAFETY: The caller promises that the pointer is valid. + unsafe { Opaque::raw_get(ptr::addr_of!((*me).inner)) } + } + + /// # Safety + /// + /// `me` must be dereferenceable. + #[inline] + unsafe fn from_fields(me: *mut ListLinksFields) -> *mut Self { + me.cast() + } +} + +/// Similar to [`ListLinks`], but also contains a pointer to the full value. +/// +/// This type can be used instead of [`ListLinks`] to support lists with trait objects. +#[repr(C)] +pub struct ListLinksSelfPtr<T: ?Sized, const ID: u64 = 0> { + /// The `ListLinks` field inside this value. + /// + /// This is public so that it can be used with `impl_has_list_links!`. + pub inner: ListLinks<ID>, + // UnsafeCell is not enough here because we use `Opaque::uninit` as a dummy value, and + // `ptr::null()` doesn't work for `T: ?Sized`. + self_ptr: Opaque<*const T>, +} + +// SAFETY: The fields of a ListLinksSelfPtr can be moved across thread boundaries. +unsafe impl<T: ?Sized + Send, const ID: u64> Send for ListLinksSelfPtr<T, ID> {} +// SAFETY: The type is opaque so immutable references to a ListLinksSelfPtr are useless. Therefore, +// it's okay to have immutable access to a ListLinks from several threads at once. +// +// Note that `inner` being a public field does not prevent this type from being opaque, since +// `inner` is a opaque type. +unsafe impl<T: ?Sized + Sync, const ID: u64> Sync for ListLinksSelfPtr<T, ID> {} + +impl<T: ?Sized, const ID: u64> ListLinksSelfPtr<T, ID> { + /// The offset from the [`ListLinks`] to the self pointer field. + pub const LIST_LINKS_SELF_PTR_OFFSET: usize = core::mem::offset_of!(Self, self_ptr); + + /// Creates a new initializer for this type. + pub fn new() -> impl PinInit<Self> { + // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will + // not be constructed in an `Arc` that already has a `ListArc`. + Self { + inner: ListLinks { + inner: Opaque::new(ListLinksFields { + prev: ptr::null_mut(), + next: ptr::null_mut(), + }), + }, + self_ptr: Opaque::uninit(), + } + } +} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> List<T, ID> { + /// Creates a new empty list. + pub const fn new() -> Self { + Self { + first: ptr::null_mut(), + _ty: PhantomData, + } + } + + /// Returns whether this list is empty. + pub fn is_empty(&self) -> bool { + self.first.is_null() + } + + /// Add the provided item to the back of the list. + pub fn push_back(&mut self, item: ListArc<T, ID>) { + let raw_item = ListArc::into_raw(item); + // SAFETY: + // * We just got `raw_item` from a `ListArc`, so it's in an `Arc`. + // * Since we have ownership of the `ListArc`, `post_remove` must have been called after + // the most recent call to `prepare_to_insert`, if any. + // * We own the `ListArc`. + // * Removing items from this list is always done using `remove_internal_inner`, which + // calls `post_remove` before giving up ownership. + let list_links = unsafe { T::prepare_to_insert(raw_item) }; + // SAFETY: We have not yet called `post_remove`, so `list_links` is still valid. + let item = unsafe { ListLinks::fields(list_links) }; + + if self.first.is_null() { + self.first = item; + // SAFETY: The caller just gave us ownership of these fields. + // INVARIANT: A linked list with one item should be cyclic. + unsafe { + (*item).next = item; + (*item).prev = item; + } + } else { + let next = self.first; + // SAFETY: By the type invariant, this pointer is valid or null. We just checked that + // it's not null, so it must be valid. + let prev = unsafe { (*next).prev }; + // SAFETY: Pointers in a linked list are never dangling, and the caller just gave us + // ownership of the fields on `item`. + // INVARIANT: This correctly inserts `item` between `prev` and `next`. + unsafe { + (*item).next = next; + (*item).prev = prev; + (*prev).next = item; + (*next).prev = item; + } + } + } + + /// Add the provided item to the front of the list. + pub fn push_front(&mut self, item: ListArc<T, ID>) { + let raw_item = ListArc::into_raw(item); + // SAFETY: + // * We just got `raw_item` from a `ListArc`, so it's in an `Arc`. + // * If this requirement is violated, then the previous caller of `prepare_to_insert` + // violated the safety requirement that they can't give up ownership of the `ListArc` + // until they call `post_remove`. + // * We own the `ListArc`. + // * Removing items] from this list is always done using `remove_internal_inner`, which + // calls `post_remove` before giving up ownership. + let list_links = unsafe { T::prepare_to_insert(raw_item) }; + // SAFETY: We have not yet called `post_remove`, so `list_links` is still valid. + let item = unsafe { ListLinks::fields(list_links) }; + + if self.first.is_null() { + // SAFETY: The caller just gave us ownership of these fields. + // INVARIANT: A linked list with one item should be cyclic. + unsafe { + (*item).next = item; + (*item).prev = item; + } + } else { + let next = self.first; + // SAFETY: We just checked that `next` is non-null. + let prev = unsafe { (*next).prev }; + // SAFETY: Pointers in a linked list are never dangling, and the caller just gave us + // ownership of the fields on `item`. + // INVARIANT: This correctly inserts `item` between `prev` and `next`. + unsafe { + (*item).next = next; + (*item).prev = prev; + (*prev).next = item; + (*next).prev = item; + } + } + self.first = item; + } + + /// Removes the last item from this list. + pub fn pop_back(&mut self) -> Option<ListArc<T, ID>> { + if self.first.is_null() { + return None; + } + + // SAFETY: We just checked that the list is not empty. + let last = unsafe { (*self.first).prev }; + // SAFETY: The last item of this list is in this list. + Some(unsafe { self.remove_internal(last) }) + } + + /// Removes the first item from this list. + pub fn pop_front(&mut self) -> Option<ListArc<T, ID>> { + if self.first.is_null() { + return None; + } + + // SAFETY: The first item of this list is in this list. + Some(unsafe { self.remove_internal(self.first) }) + } + + /// Removes the provided item from this list and returns it. + /// + /// This returns `None` if the item is not in the list. (Note that by the safety requirements, + /// this means that the item is not in any list.) + /// + /// # Safety + /// + /// `item` must not be in a different linked list (with the same id). + pub unsafe fn remove(&mut self, item: &T) -> Option<ListArc<T, ID>> { + let mut item = unsafe { ListLinks::fields(T::view_links(item)) }; + // SAFETY: The user provided a reference, and reference are never dangling. + // + // As for why this is not a data race, there are two cases: + // + // * If `item` is not in any list, then these fields are read-only and null. + // * If `item` is in this list, then we have exclusive access to these fields since we + // have a mutable reference to the list. + // + // In either case, there's no race. + let ListLinksFields { next, prev } = unsafe { *item }; + + debug_assert_eq!(next.is_null(), prev.is_null()); + if !next.is_null() { + // This is really a no-op, but this ensures that `item` is a raw pointer that was + // obtained without going through a pointer->reference->pointer conversion roundtrip. + // This ensures that the list is valid under the more restrictive strict provenance + // ruleset. + // + // SAFETY: We just checked that `next` is not null, and it's not dangling by the + // list invariants. + unsafe { + debug_assert_eq!(item, (*next).prev); + item = (*next).prev; + } + + // SAFETY: We just checked that `item` is in a list, so the caller guarantees that it + // is in this list. The pointers are in the right order. + Some(unsafe { self.remove_internal_inner(item, next, prev) }) + } else { + None + } + } + + /// Removes the provided item from the list. + /// + /// # Safety + /// + /// `item` must point at an item in this list. + unsafe fn remove_internal(&mut self, item: *mut ListLinksFields) -> ListArc<T, ID> { + // SAFETY: The caller promises that this pointer is not dangling, and there's no data race + // since we have a mutable reference to the list containing `item`. + let ListLinksFields { next, prev } = unsafe { *item }; + // SAFETY: The pointers are ok and in the right order. + unsafe { self.remove_internal_inner(item, next, prev) } + } + + /// Removes the provided item from the list. + /// + /// # Safety + /// + /// The `item` pointer must point at an item in this list, and we must have `(*item).next == + /// next` and `(*item).prev == prev`. + unsafe fn remove_internal_inner( + &mut self, + item: *mut ListLinksFields, + next: *mut ListLinksFields, + prev: *mut ListLinksFields, + ) -> ListArc<T, ID> { + // SAFETY: We have exclusive access to the pointers of items in the list, and the prev/next + // pointers are always valid for items in a list. + // + // INVARIANT: There are three cases: + // * If the list has at least three items, then after removing the item, `prev` and `next` + // will be next to each other. + // * If the list has two items, then the remaining item will point at itself. + // * If the list has one item, then `next == prev == item`, so these writes have no + // effect. The list remains unchanged and `item` is still in the list for now. + unsafe { + (*next).prev = prev; + (*prev).next = next; + } + // SAFETY: We have exclusive access to items in the list. + // INVARIANT: `item` is being removed, so the pointers should be null. + unsafe { + (*item).prev = ptr::null_mut(); + (*item).next = ptr::null_mut(); + } + // INVARIANT: There are three cases: + // * If `item` was not the first item, then `self.first` should remain unchanged. + // * If `item` was the first item and there is another item, then we just updated + // `prev->next` to `next`, which is the new first item, and setting `item->next` to null + // did not modify `prev->next`. + // * If `item` was the only item in the list, then `prev == item`, and we just set + // `item->next` to null, so this correctly sets `first` to null now that the list is + // empty. + if self.first == item { + // SAFETY: The `prev` pointer is the value that `item->prev` had when it was in this + // list, so it must be valid. There is no race since `prev` is still in the list and we + // still have exclusive access to the list. + self.first = unsafe { (*prev).next }; + } + + // SAFETY: `item` used to be in the list, so it is dereferenceable by the type invariants + // of `List`. + let list_links = unsafe { ListLinks::from_fields(item) }; + // SAFETY: Any pointer in the list originates from a `prepare_to_insert` call. + let raw_item = unsafe { T::post_remove(list_links) }; + // SAFETY: The above call to `post_remove` guarantees that we can recreate the `ListArc`. + unsafe { ListArc::from_raw(raw_item) } + } + + /// Moves all items from `other` into `self`. + /// + /// The items of `other` are added to the back of `self`, so the last item of `other` becomes + /// the last item of `self`. + pub fn push_all_back(&mut self, other: &mut List<T, ID>) { + // First, we insert the elements into `self`. At the end, we make `other` empty. + if self.is_empty() { + // INVARIANT: All of the elements in `other` become elements of `self`. + self.first = other.first; + } else if !other.is_empty() { + let other_first = other.first; + // SAFETY: The other list is not empty, so this pointer is valid. + let other_last = unsafe { (*other_first).prev }; + let self_first = self.first; + // SAFETY: The self list is not empty, so this pointer is valid. + let self_last = unsafe { (*self_first).prev }; + + // SAFETY: We have exclusive access to both lists, so we can update the pointers. + // INVARIANT: This correctly sets the pointers to merge both lists. We do not need to + // update `self.first` because the first element of `self` does not change. + unsafe { + (*self_first).prev = other_last; + (*other_last).next = self_first; + (*self_last).next = other_first; + (*other_first).prev = self_last; + } + } + + // INVARIANT: The other list is now empty, so update its pointer. + other.first = ptr::null_mut(); + } + + /// Returns a cursor to the first element of the list. + /// + /// If the list is empty, this returns `None`. + pub fn cursor_front(&mut self) -> Option<Cursor<'_, T, ID>> { + if self.first.is_null() { + None + } else { + Some(Cursor { + current: self.first, + list: self, + }) + } + } + + /// Creates an iterator over the list. + pub fn iter(&self) -> Iter<'_, T, ID> { + // INVARIANT: If the list is empty, both pointers are null. Otherwise, both pointers point + // at the first element of the same list. + Iter { + current: self.first, + stop: self.first, + _ty: PhantomData, + } + } +} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> Default for List<T, ID> { + fn default() -> Self { + List::new() + } +} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> Drop for List<T, ID> { + fn drop(&mut self) { + while let Some(item) = self.pop_front() { + drop(item); + } + } +} + +/// An iterator over a [`List`]. +/// +/// # Invariants +/// +/// * There must be a [`List`] that is immutably borrowed for the duration of `'a`. +/// * The `current` pointer is null or points at a value in that [`List`]. +/// * The `stop` pointer is equal to the `first` field of that [`List`]. +#[derive(Clone)] +pub struct Iter<'a, T: ?Sized + ListItem<ID>, const ID: u64 = 0> { + current: *mut ListLinksFields, + stop: *mut ListLinksFields, + _ty: PhantomData<&'a ListArc<T, ID>>, +} + +impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> Iterator for Iter<'a, T, ID> { + type Item = ArcBorrow<'a, T>; + + fn next(&mut self) -> Option<ArcBorrow<'a, T>> { + if self.current.is_null() { + return None; + } + + let current = self.current; + + // SAFETY: We just checked that `current` is not null, so it is in a list, and hence not + // dangling. There's no race because the iterator holds an immutable borrow to the list. + let next = unsafe { (*current).next }; + // INVARIANT: If `current` was the last element of the list, then this updates it to null. + // Otherwise, we update it to the next element. + self.current = if next != self.stop { + next + } else { + ptr::null_mut() + }; + + // SAFETY: The `current` pointer points at a value in the list. + let item = unsafe { T::view_value(ListLinks::from_fields(current)) }; + // SAFETY: + // * All values in a list are stored in an `Arc`. + // * The value cannot be removed from the list for the duration of the lifetime annotated + // on the returned `ArcBorrow`, because removing it from the list would require mutable + // access to the list. However, the `ArcBorrow` is annotated with the iterator's + // lifetime, and the list is immutably borrowed for that lifetime. + // * Values in a list never have a `UniqueArc` reference. + Some(unsafe { ArcBorrow::from_raw(item) }) + } +} + +/// A cursor into a [`List`]. +/// +/// # Invariants +/// +/// The `current` pointer points a value in `list`. +pub struct Cursor<'a, T: ?Sized + ListItem<ID>, const ID: u64 = 0> { + current: *mut ListLinksFields, + list: &'a mut List<T, ID>, +} + +impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> Cursor<'a, T, ID> { + /// Access the current element of this cursor. + pub fn current(&self) -> ArcBorrow<'_, T> { + // SAFETY: The `current` pointer points a value in the list. + let me = unsafe { T::view_value(ListLinks::from_fields(self.current)) }; + // SAFETY: + // * All values in a list are stored in an `Arc`. + // * The value cannot be removed from the list for the duration of the lifetime annotated + // on the returned `ArcBorrow`, because removing it from the list would require mutable + // access to the cursor or the list. However, the `ArcBorrow` holds an immutable borrow + // on the cursor, which in turn holds a mutable borrow on the list, so any such + // mutable access requires first releasing the immutable borrow on the cursor. + // * Values in a list never have a `UniqueArc` reference, because the list has a `ListArc` + // reference, and `UniqueArc` references must be unique. + unsafe { ArcBorrow::from_raw(me) } + } + + /// Move the cursor to the next element. + pub fn next(self) -> Option<Cursor<'a, T, ID>> { + // SAFETY: The `current` field is always in a list. + let next = unsafe { (*self.current).next }; + + if next == self.list.first { + None + } else { + // INVARIANT: Since `self.current` is in the `list`, its `next` pointer is also in the + // `list`. + Some(Cursor { + current: next, + list: self.list, + }) + } + } + + /// Move the cursor to the previous element. + pub fn prev(self) -> Option<Cursor<'a, T, ID>> { + // SAFETY: The `current` field is always in a list. + let prev = unsafe { (*self.current).prev }; + + if self.current == self.list.first { + None + } else { + // INVARIANT: Since `self.current` is in the `list`, its `prev` pointer is also in the + // `list`. + Some(Cursor { + current: prev, + list: self.list, + }) + } + } + + /// Remove the current element from the list. + pub fn remove(self) -> ListArc<T, ID> { + // SAFETY: The `current` pointer always points at a member of the list. + unsafe { self.list.remove_internal(self.current) } + } +} + +impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> FusedIterator for Iter<'a, T, ID> {} + +impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> IntoIterator for &'a List<T, ID> { + type IntoIter = Iter<'a, T, ID>; + type Item = ArcBorrow<'a, T>; + + fn into_iter(self) -> Iter<'a, T, ID> { + self.iter() + } +} + +/// An owning iterator into a [`List`]. +pub struct IntoIter<T: ?Sized + ListItem<ID>, const ID: u64 = 0> { + list: List<T, ID>, +} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> Iterator for IntoIter<T, ID> { + type Item = ListArc<T, ID>; + + fn next(&mut self) -> Option<ListArc<T, ID>> { + self.list.pop_front() + } +} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> FusedIterator for IntoIter<T, ID> {} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> DoubleEndedIterator for IntoIter<T, ID> { + fn next_back(&mut self) -> Option<ListArc<T, ID>> { + self.list.pop_back() + } +} + +impl<T: ?Sized + ListItem<ID>, const ID: u64> IntoIterator for List<T, ID> { + type IntoIter = IntoIter<T, ID>; + type Item = ListArc<T, ID>; + + fn into_iter(self) -> IntoIter<T, ID> { + IntoIter { list: self } + } +} diff --git a/rust/kernel/list/arc.rs b/rust/kernel/list/arc.rs new file mode 100644 index 000000000000..d801b9dc6291 --- /dev/null +++ b/rust/kernel/list/arc.rs @@ -0,0 +1,521 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! A wrapper around `Arc` for linked lists. + +use crate::alloc::{AllocError, Flags}; +use crate::prelude::*; +use crate::sync::{Arc, ArcBorrow, UniqueArc}; +use core::marker::{PhantomPinned, Unsize}; +use core::ops::Deref; +use core::pin::Pin; +use core::sync::atomic::{AtomicBool, Ordering}; + +/// Declares that this type has some way to ensure that there is exactly one `ListArc` instance for +/// this id. +/// +/// Types that implement this trait should include some kind of logic for keeping track of whether +/// a [`ListArc`] exists or not. We refer to this logic as "the tracking inside `T`". +/// +/// We allow the case where the tracking inside `T` thinks that a [`ListArc`] exists, but actually, +/// there isn't a [`ListArc`]. However, we do not allow the opposite situation where a [`ListArc`] +/// exists, but the tracking thinks it doesn't. This is because the former can at most result in us +/// failing to create a [`ListArc`] when the operation could succeed, whereas the latter can result +/// in the creation of two [`ListArc`] references. Only the latter situation can lead to memory +/// safety issues. +/// +/// A consequence of the above is that you may implement the tracking inside `T` by not actually +/// keeping track of anything. To do this, you always claim that a [`ListArc`] exists, even if +/// there isn't one. This implementation is allowed by the above rule, but it means that +/// [`ListArc`] references can only be created if you have ownership of *all* references to the +/// refcounted object, as you otherwise have no way of knowing whether a [`ListArc`] exists. +pub trait ListArcSafe<const ID: u64 = 0> { + /// Informs the tracking inside this type that it now has a [`ListArc`] reference. + /// + /// This method may be called even if the tracking inside this type thinks that a `ListArc` + /// reference exists. (But only if that's not actually the case.) + /// + /// # Safety + /// + /// Must not be called if a [`ListArc`] already exist for this value. + unsafe fn on_create_list_arc_from_unique(self: Pin<&mut Self>); + + /// Informs the tracking inside this type that there is no [`ListArc`] reference anymore. + /// + /// # Safety + /// + /// Must only be called if there is no [`ListArc`] reference, but the tracking thinks there is. + unsafe fn on_drop_list_arc(&self); +} + +/// Declares that this type is able to safely attempt to create `ListArc`s at any time. +/// +/// # Safety +/// +/// The guarantees of `try_new_list_arc` must be upheld. +pub unsafe trait TryNewListArc<const ID: u64 = 0>: ListArcSafe<ID> { + /// Attempts to convert an `Arc<Self>` into an `ListArc<Self>`. Returns `true` if the + /// conversion was successful. + /// + /// This method should not be called directly. Use [`ListArc::try_from_arc`] instead. + /// + /// # Guarantees + /// + /// If this call returns `true`, then there is no [`ListArc`] pointing to this value. + /// Additionally, this call will have transitioned the tracking inside `Self` from not thinking + /// that a [`ListArc`] exists, to thinking that a [`ListArc`] exists. + fn try_new_list_arc(&self) -> bool; +} + +/// Declares that this type supports [`ListArc`]. +/// +/// This macro supports a few different strategies for implementing the tracking inside the type: +/// +/// * The `untracked` strategy does not actually keep track of whether a [`ListArc`] exists. When +/// using this strategy, the only way to create a [`ListArc`] is using a [`UniqueArc`]. +/// * The `tracked_by` strategy defers the tracking to a field of the struct. The user much specify +/// which field to defer the tracking to. The field must implement [`ListArcSafe`]. If the field +/// implements [`TryNewListArc`], then the type will also implement [`TryNewListArc`]. +/// +/// The `tracked_by` strategy is usually used by deferring to a field of type +/// [`AtomicTracker`]. However, it is also possible to defer the tracking to another struct +/// using also using this macro. +#[macro_export] +macro_rules! impl_list_arc_safe { + (impl$({$($generics:tt)*})? ListArcSafe<$num:tt> for $t:ty { untracked; } $($rest:tt)*) => { + impl$(<$($generics)*>)? $crate::list::ListArcSafe<$num> for $t { + unsafe fn on_create_list_arc_from_unique(self: ::core::pin::Pin<&mut Self>) {} + unsafe fn on_drop_list_arc(&self) {} + } + $crate::list::impl_list_arc_safe! { $($rest)* } + }; + + (impl$({$($generics:tt)*})? ListArcSafe<$num:tt> for $t:ty { + tracked_by $field:ident : $fty:ty; + } $($rest:tt)*) => { + impl$(<$($generics)*>)? $crate::list::ListArcSafe<$num> for $t { + unsafe fn on_create_list_arc_from_unique(self: ::core::pin::Pin<&mut Self>) { + $crate::assert_pinned!($t, $field, $fty, inline); + + // SAFETY: This field is structurally pinned as per the above assertion. + let field = unsafe { + ::core::pin::Pin::map_unchecked_mut(self, |me| &mut me.$field) + }; + // SAFETY: The caller promises that there is no `ListArc`. + unsafe { + <$fty as $crate::list::ListArcSafe<$num>>::on_create_list_arc_from_unique(field) + }; + } + unsafe fn on_drop_list_arc(&self) { + // SAFETY: The caller promises that there is no `ListArc` reference, and also + // promises that the tracking thinks there is a `ListArc` reference. + unsafe { <$fty as $crate::list::ListArcSafe<$num>>::on_drop_list_arc(&self.$field) }; + } + } + unsafe impl$(<$($generics)*>)? $crate::list::TryNewListArc<$num> for $t + where + $fty: TryNewListArc<$num>, + { + fn try_new_list_arc(&self) -> bool { + <$fty as $crate::list::TryNewListArc<$num>>::try_new_list_arc(&self.$field) + } + } + $crate::list::impl_list_arc_safe! { $($rest)* } + }; + + () => {}; +} +pub use impl_list_arc_safe; + +/// A wrapper around [`Arc`] that's guaranteed unique for the given id. +/// +/// The `ListArc` type can be thought of as a special reference to a refcounted object that owns the +/// permission to manipulate the `next`/`prev` pointers stored in the refcounted object. By ensuring +/// that each object has only one `ListArc` reference, the owner of that reference is assured +/// exclusive access to the `next`/`prev` pointers. When a `ListArc` is inserted into a [`List`], +/// the [`List`] takes ownership of the `ListArc` reference. +/// +/// There are various strategies to ensuring that a value has only one `ListArc` reference. The +/// simplest is to convert a [`UniqueArc`] into a `ListArc`. However, the refcounted object could +/// also keep track of whether a `ListArc` exists using a boolean, which could allow for the +/// creation of new `ListArc` references from an [`Arc`] reference. Whatever strategy is used, the +/// relevant tracking is referred to as "the tracking inside `T`", and the [`ListArcSafe`] trait +/// (and its subtraits) are used to update the tracking when a `ListArc` is created or destroyed. +/// +/// Note that we allow the case where the tracking inside `T` thinks that a `ListArc` exists, but +/// actually, there isn't a `ListArc`. However, we do not allow the opposite situation where a +/// `ListArc` exists, but the tracking thinks it doesn't. This is because the former can at most +/// result in us failing to create a `ListArc` when the operation could succeed, whereas the latter +/// can result in the creation of two `ListArc` references. +/// +/// While this `ListArc` is unique for the given id, there still might exist normal `Arc` +/// references to the object. +/// +/// # Invariants +/// +/// * Each reference counted object has at most one `ListArc` for each value of `ID`. +/// * The tracking inside `T` is aware that a `ListArc` reference exists. +/// +/// [`List`]: crate::list::List +#[repr(transparent)] +pub struct ListArc<T, const ID: u64 = 0> +where + T: ListArcSafe<ID> + ?Sized, +{ + arc: Arc<T>, +} + +impl<T: ListArcSafe<ID>, const ID: u64> ListArc<T, ID> { + /// Constructs a new reference counted instance of `T`. + #[inline] + pub fn new(contents: T, flags: Flags) -> Result<Self, AllocError> { + Ok(Self::from(UniqueArc::new(contents, flags)?)) + } + + /// Use the given initializer to in-place initialize a `T`. + /// + /// If `T: !Unpin` it will not be able to move afterwards. + // We don't implement `InPlaceInit` because `ListArc` is implicitly pinned. This is similar to + // what we do for `Arc`. + #[inline] + pub fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + Ok(Self::from(UniqueArc::try_pin_init(init, flags)?)) + } + + /// Use the given initializer to in-place initialize a `T`. + /// + /// This is equivalent to [`ListArc<T>::pin_init`], since a [`ListArc`] is always pinned. + #[inline] + pub fn init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + Ok(Self::from(UniqueArc::try_init(init, flags)?)) + } +} + +impl<T, const ID: u64> From<UniqueArc<T>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + /// Convert a [`UniqueArc`] into a [`ListArc`]. + #[inline] + fn from(unique: UniqueArc<T>) -> Self { + Self::from(Pin::from(unique)) + } +} + +impl<T, const ID: u64> From<Pin<UniqueArc<T>>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + /// Convert a pinned [`UniqueArc`] into a [`ListArc`]. + #[inline] + fn from(mut unique: Pin<UniqueArc<T>>) -> Self { + // SAFETY: We have a `UniqueArc`, so there is no `ListArc`. + unsafe { T::on_create_list_arc_from_unique(unique.as_mut()) }; + let arc = Arc::from(unique); + // SAFETY: We just called `on_create_list_arc_from_unique` on an arc without a `ListArc`, + // so we can create a `ListArc`. + unsafe { Self::transmute_from_arc(arc) } + } +} + +impl<T, const ID: u64> ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + /// Creates two `ListArc`s from a [`UniqueArc`]. + /// + /// The two ids must be different. + #[inline] + pub fn pair_from_unique<const ID2: u64>(unique: UniqueArc<T>) -> (Self, ListArc<T, ID2>) + where + T: ListArcSafe<ID2>, + { + Self::pair_from_pin_unique(Pin::from(unique)) + } + + /// Creates two `ListArc`s from a pinned [`UniqueArc`]. + /// + /// The two ids must be different. + #[inline] + pub fn pair_from_pin_unique<const ID2: u64>( + mut unique: Pin<UniqueArc<T>>, + ) -> (Self, ListArc<T, ID2>) + where + T: ListArcSafe<ID2>, + { + build_assert!(ID != ID2); + + // SAFETY: We have a `UniqueArc`, so there is no `ListArc`. + unsafe { <T as ListArcSafe<ID>>::on_create_list_arc_from_unique(unique.as_mut()) }; + // SAFETY: We have a `UniqueArc`, so there is no `ListArc`. + unsafe { <T as ListArcSafe<ID2>>::on_create_list_arc_from_unique(unique.as_mut()) }; + + let arc1 = Arc::from(unique); + let arc2 = Arc::clone(&arc1); + + // SAFETY: We just called `on_create_list_arc_from_unique` on an arc without a `ListArc` + // for both IDs (which are different), so we can create two `ListArc`s. + unsafe { + ( + Self::transmute_from_arc(arc1), + ListArc::transmute_from_arc(arc2), + ) + } + } + + /// Try to create a new `ListArc`. + /// + /// This fails if this value already has a `ListArc`. + pub fn try_from_arc(arc: Arc<T>) -> Result<Self, Arc<T>> + where + T: TryNewListArc<ID>, + { + if arc.try_new_list_arc() { + // SAFETY: The `try_new_list_arc` method returned true, so we made the tracking think + // that a `ListArc` exists. This lets us create a `ListArc`. + Ok(unsafe { Self::transmute_from_arc(arc) }) + } else { + Err(arc) + } + } + + /// Try to create a new `ListArc`. + /// + /// This fails if this value already has a `ListArc`. + pub fn try_from_arc_borrow(arc: ArcBorrow<'_, T>) -> Option<Self> + where + T: TryNewListArc<ID>, + { + if arc.try_new_list_arc() { + // SAFETY: The `try_new_list_arc` method returned true, so we made the tracking think + // that a `ListArc` exists. This lets us create a `ListArc`. + Some(unsafe { Self::transmute_from_arc(Arc::from(arc)) }) + } else { + None + } + } + + /// Try to create a new `ListArc`. + /// + /// If it's not possible to create a new `ListArc`, then the `Arc` is dropped. This will never + /// run the destructor of the value. + pub fn try_from_arc_or_drop(arc: Arc<T>) -> Option<Self> + where + T: TryNewListArc<ID>, + { + match Self::try_from_arc(arc) { + Ok(list_arc) => Some(list_arc), + Err(arc) => Arc::into_unique_or_drop(arc).map(Self::from), + } + } + + /// Transmutes an [`Arc`] into a `ListArc` without updating the tracking inside `T`. + /// + /// # Safety + /// + /// * The value must not already have a `ListArc` reference. + /// * The tracking inside `T` must think that there is a `ListArc` reference. + #[inline] + unsafe fn transmute_from_arc(arc: Arc<T>) -> Self { + // INVARIANT: By the safety requirements, the invariants on `ListArc` are satisfied. + Self { arc } + } + + /// Transmutes a `ListArc` into an [`Arc`] without updating the tracking inside `T`. + /// + /// After this call, the tracking inside `T` will still think that there is a `ListArc` + /// reference. + #[inline] + fn transmute_to_arc(self) -> Arc<T> { + // Use a transmute to skip destructor. + // + // SAFETY: ListArc is repr(transparent). + unsafe { core::mem::transmute(self) } + } + + /// Convert ownership of this `ListArc` into a raw pointer. + /// + /// The returned pointer is indistinguishable from pointers returned by [`Arc::into_raw`]. The + /// tracking inside `T` will still think that a `ListArc` exists after this call. + #[inline] + pub fn into_raw(self) -> *const T { + Arc::into_raw(Self::transmute_to_arc(self)) + } + + /// Take ownership of the `ListArc` from a raw pointer. + /// + /// # Safety + /// + /// * `ptr` must satisfy the safety requirements of [`Arc::from_raw`]. + /// * The value must not already have a `ListArc` reference. + /// * The tracking inside `T` must think that there is a `ListArc` reference. + #[inline] + pub unsafe fn from_raw(ptr: *const T) -> Self { + // SAFETY: The pointer satisfies the safety requirements for `Arc::from_raw`. + let arc = unsafe { Arc::from_raw(ptr) }; + // SAFETY: The value doesn't already have a `ListArc` reference, but the tracking thinks it + // does. + unsafe { Self::transmute_from_arc(arc) } + } + + /// Converts the `ListArc` into an [`Arc`]. + #[inline] + pub fn into_arc(self) -> Arc<T> { + let arc = Self::transmute_to_arc(self); + // SAFETY: There is no longer a `ListArc`, but the tracking thinks there is. + unsafe { T::on_drop_list_arc(&arc) }; + arc + } + + /// Clone a `ListArc` into an [`Arc`]. + #[inline] + pub fn clone_arc(&self) -> Arc<T> { + self.arc.clone() + } + + /// Returns a reference to an [`Arc`] from the given [`ListArc`]. + /// + /// This is useful when the argument of a function call is an [`&Arc`] (e.g., in a method + /// receiver), but we have a [`ListArc`] instead. + /// + /// [`&Arc`]: Arc + #[inline] + pub fn as_arc(&self) -> &Arc<T> { + &self.arc + } + + /// Returns an [`ArcBorrow`] from the given [`ListArc`]. + /// + /// This is useful when the argument of a function call is an [`ArcBorrow`] (e.g., in a method + /// receiver), but we have an [`Arc`] instead. Getting an [`ArcBorrow`] is free when optimised. + #[inline] + pub fn as_arc_borrow(&self) -> ArcBorrow<'_, T> { + self.arc.as_arc_borrow() + } + + /// Compare whether two [`ListArc`] pointers reference the same underlying object. + #[inline] + pub fn ptr_eq(this: &Self, other: &Self) -> bool { + Arc::ptr_eq(&this.arc, &other.arc) + } +} + +impl<T, const ID: u64> Deref for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + self.arc.deref() + } +} + +impl<T, const ID: u64> Drop for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + #[inline] + fn drop(&mut self) { + // SAFETY: There is no longer a `ListArc`, but the tracking thinks there is by the type + // invariants on `Self`. + unsafe { T::on_drop_list_arc(&self.arc) }; + } +} + +impl<T, const ID: u64> AsRef<Arc<T>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + ?Sized, +{ + #[inline] + fn as_ref(&self) -> &Arc<T> { + self.as_arc() + } +} + +// This is to allow [`ListArc`] (and variants) to be used as the type of `self`. +impl<T, const ID: u64> core::ops::Receiver for ListArc<T, ID> where T: ListArcSafe<ID> + ?Sized {} + +// This is to allow coercion from `ListArc<T>` to `ListArc<U>` if `T` can be converted to the +// dynamically-sized type (DST) `U`. +impl<T, U, const ID: u64> core::ops::CoerceUnsized<ListArc<U, ID>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + Unsize<U> + ?Sized, + U: ListArcSafe<ID> + ?Sized, +{ +} + +// This is to allow `ListArc<U>` to be dispatched on when `ListArc<T>` can be coerced into +// `ListArc<U>`. +impl<T, U, const ID: u64> core::ops::DispatchFromDyn<ListArc<U, ID>> for ListArc<T, ID> +where + T: ListArcSafe<ID> + Unsize<U> + ?Sized, + U: ListArcSafe<ID> + ?Sized, +{ +} + +/// A utility for tracking whether a [`ListArc`] exists using an atomic. +/// +/// # Invariant +/// +/// If the boolean is `false`, then there is no [`ListArc`] for this value. +#[repr(transparent)] +pub struct AtomicTracker<const ID: u64 = 0> { + inner: AtomicBool, + // This value needs to be pinned to justify the INVARIANT: comment in `AtomicTracker::new`. + _pin: PhantomPinned, +} + +impl<const ID: u64> AtomicTracker<ID> { + /// Creates a new initializer for this type. + pub fn new() -> impl PinInit<Self> { + // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will + // not be constructed in an `Arc` that already has a `ListArc`. + Self { + inner: AtomicBool::new(false), + _pin: PhantomPinned, + } + } + + fn project_inner(self: Pin<&mut Self>) -> &mut AtomicBool { + // SAFETY: The `inner` field is not structurally pinned, so we may obtain a mutable + // reference to it even if we only have a pinned reference to `self`. + unsafe { &mut Pin::into_inner_unchecked(self).inner } + } +} + +impl<const ID: u64> ListArcSafe<ID> for AtomicTracker<ID> { + unsafe fn on_create_list_arc_from_unique(self: Pin<&mut Self>) { + // INVARIANT: We just created a ListArc, so the boolean should be true. + *self.project_inner().get_mut() = true; + } + + unsafe fn on_drop_list_arc(&self) { + // INVARIANT: We just dropped a ListArc, so the boolean should be false. + self.inner.store(false, Ordering::Release); + } +} + +// SAFETY: If this method returns `true`, then by the type invariant there is no `ListArc` before +// this call, so it is okay to create a new `ListArc`. +// +// The acquire ordering will synchronize with the release store from the destruction of any +// previous `ListArc`, so if there was a previous `ListArc`, then the destruction of the previous +// `ListArc` happens-before the creation of the new `ListArc`. +unsafe impl<const ID: u64> TryNewListArc<ID> for AtomicTracker<ID> { + fn try_new_list_arc(&self) -> bool { + // INVARIANT: If this method returns true, then the boolean used to be false, and is no + // longer false, so it is okay for the caller to create a new [`ListArc`]. + self.inner + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + } +} diff --git a/rust/kernel/list/arc_field.rs b/rust/kernel/list/arc_field.rs new file mode 100644 index 000000000000..2330f673427a --- /dev/null +++ b/rust/kernel/list/arc_field.rs @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! A field that is exclusively owned by a [`ListArc`]. +//! +//! This can be used to have reference counted struct where one of the reference counted pointers +//! has exclusive access to a field of the struct. +//! +//! [`ListArc`]: crate::list::ListArc + +use core::cell::UnsafeCell; + +/// A field owned by a specific [`ListArc`]. +/// +/// [`ListArc`]: crate::list::ListArc +pub struct ListArcField<T, const ID: u64 = 0> { + value: UnsafeCell<T>, +} + +// SAFETY: If the inner type is thread-safe, then it's also okay for `ListArc` to be thread-safe. +unsafe impl<T: Send + Sync, const ID: u64> Send for ListArcField<T, ID> {} +// SAFETY: If the inner type is thread-safe, then it's also okay for `ListArc` to be thread-safe. +unsafe impl<T: Send + Sync, const ID: u64> Sync for ListArcField<T, ID> {} + +impl<T, const ID: u64> ListArcField<T, ID> { + /// Creates a new `ListArcField`. + pub fn new(value: T) -> Self { + Self { + value: UnsafeCell::new(value), + } + } + + /// Access the value when we have exclusive access to the `ListArcField`. + /// + /// This allows access to the field using an `UniqueArc` instead of a `ListArc`. + pub fn get_mut(&mut self) -> &mut T { + self.value.get_mut() + } + + /// Unsafely assert that you have shared access to the `ListArc` for this field. + /// + /// # Safety + /// + /// The caller must have shared access to the `ListArc<ID>` containing the struct with this + /// field for the duration of the returned reference. + pub unsafe fn assert_ref(&self) -> &T { + // SAFETY: The caller has shared access to the `ListArc`, so they also have shared access + // to this field. + unsafe { &*self.value.get() } + } + + /// Unsafely assert that you have mutable access to the `ListArc` for this field. + /// + /// # Safety + /// + /// The caller must have mutable access to the `ListArc<ID>` containing the struct with this + /// field for the duration of the returned reference. + #[allow(clippy::mut_from_ref)] + pub unsafe fn assert_mut(&self) -> &mut T { + // SAFETY: The caller has exclusive access to the `ListArc`, so they also have exclusive + // access to this field. + unsafe { &mut *self.value.get() } + } +} + +/// Defines getters for a [`ListArcField`]. +#[macro_export] +macro_rules! define_list_arc_field_getter { + ($pub:vis fn $name:ident(&self $(<$id:tt>)?) -> &$typ:ty { $field:ident } + $($rest:tt)* + ) => { + $pub fn $name<'a>(self: &'a $crate::list::ListArc<Self $(, $id)?>) -> &'a $typ { + let field = &(&**self).$field; + // SAFETY: We have a shared reference to the `ListArc`. + unsafe { $crate::list::ListArcField::<$typ $(, $id)?>::assert_ref(field) } + } + + $crate::list::define_list_arc_field_getter!($($rest)*); + }; + + ($pub:vis fn $name:ident(&mut self $(<$id:tt>)?) -> &mut $typ:ty { $field:ident } + $($rest:tt)* + ) => { + $pub fn $name<'a>(self: &'a mut $crate::list::ListArc<Self $(, $id)?>) -> &'a mut $typ { + let field = &(&**self).$field; + // SAFETY: We have a mutable reference to the `ListArc`. + unsafe { $crate::list::ListArcField::<$typ $(, $id)?>::assert_mut(field) } + } + + $crate::list::define_list_arc_field_getter!($($rest)*); + }; + + () => {}; +} +pub use define_list_arc_field_getter; diff --git a/rust/kernel/list/impl_list_item_mod.rs b/rust/kernel/list/impl_list_item_mod.rs new file mode 100644 index 000000000000..a0438537cee1 --- /dev/null +++ b/rust/kernel/list/impl_list_item_mod.rs @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Helpers for implementing list traits safely. + +use crate::list::ListLinks; + +/// Declares that this type has a `ListLinks<ID>` field at a fixed offset. +/// +/// This trait is only used to help implement `ListItem` safely. If `ListItem` is implemented +/// manually, then this trait is not needed. Use the [`impl_has_list_links!`] macro to implement +/// this trait. +/// +/// # Safety +/// +/// All values of this type must have a `ListLinks<ID>` field at the given offset. +/// +/// The behavior of `raw_get_list_links` must not be changed. +pub unsafe trait HasListLinks<const ID: u64 = 0> { + /// The offset of the `ListLinks` field. + const OFFSET: usize; + + /// Returns a pointer to the [`ListLinks<T, ID>`] field. + /// + /// # Safety + /// + /// The provided pointer must point at a valid struct of type `Self`. + /// + /// [`ListLinks<T, ID>`]: ListLinks + // We don't really need this method, but it's necessary for the implementation of + // `impl_has_list_links!` to be correct. + #[inline] + unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut ListLinks<ID> { + // SAFETY: The caller promises that the pointer is valid. The implementer promises that the + // `OFFSET` constant is correct. + unsafe { (ptr as *mut u8).add(Self::OFFSET) as *mut ListLinks<ID> } + } +} + +/// Implements the [`HasListLinks`] trait for the given type. +#[macro_export] +macro_rules! impl_has_list_links { + ($(impl$(<$($implarg:ident),*>)? + HasListLinks$(<$id:tt>)? + for $self:ident $(<$($selfarg:ty),*>)? + { self$(.$field:ident)* } + )*) => {$( + // SAFETY: The implementation of `raw_get_list_links` only compiles if the field has the + // right type. + // + // The behavior of `raw_get_list_links` is not changed since the `addr_of_mut!` macro is + // equivalent to the pointer offset operation in the trait definition. + unsafe impl$(<$($implarg),*>)? $crate::list::HasListLinks$(<$id>)? for + $self $(<$($selfarg),*>)? + { + const OFFSET: usize = ::core::mem::offset_of!(Self, $($field).*) as usize; + + #[inline] + unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut $crate::list::ListLinks$(<$id>)? { + // SAFETY: The caller promises that the pointer is not dangling. We know that this + // expression doesn't follow any pointers, as the `offset_of!` invocation above + // would otherwise not compile. + unsafe { ::core::ptr::addr_of_mut!((*ptr)$(.$field)*) } + } + } + )*}; +} +pub use impl_has_list_links; + +/// Declares that the `ListLinks<ID>` field in this struct is inside a `ListLinksSelfPtr<T, ID>`. +/// +/// # Safety +/// +/// The `ListLinks<ID>` field of this struct at the offset `HasListLinks<ID>::OFFSET` must be +/// inside a `ListLinksSelfPtr<T, ID>`. +pub unsafe trait HasSelfPtr<T: ?Sized, const ID: u64 = 0> +where + Self: HasListLinks<ID>, +{ +} + +/// Implements the [`HasListLinks`] and [`HasSelfPtr`] traits for the given type. +#[macro_export] +macro_rules! impl_has_list_links_self_ptr { + ($(impl$({$($implarg:tt)*})? + HasSelfPtr<$item_type:ty $(, $id:tt)?> + for $self:ident $(<$($selfarg:ty),*>)? + { self.$field:ident } + )*) => {$( + // SAFETY: The implementation of `raw_get_list_links` only compiles if the field has the + // right type. + unsafe impl$(<$($implarg)*>)? $crate::list::HasSelfPtr<$item_type $(, $id)?> for + $self $(<$($selfarg),*>)? + {} + + unsafe impl$(<$($implarg)*>)? $crate::list::HasListLinks$(<$id>)? for + $self $(<$($selfarg),*>)? + { + const OFFSET: usize = ::core::mem::offset_of!(Self, $field) as usize; + + #[inline] + unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut $crate::list::ListLinks$(<$id>)? { + // SAFETY: The caller promises that the pointer is not dangling. + let ptr: *mut $crate::list::ListLinksSelfPtr<$item_type $(, $id)?> = + unsafe { ::core::ptr::addr_of_mut!((*ptr).$field) }; + ptr.cast() + } + } + )*}; +} +pub use impl_has_list_links_self_ptr; + +/// Implements the [`ListItem`] trait for the given type. +/// +/// Requires that the type implements [`HasListLinks`]. Use the [`impl_has_list_links!`] macro to +/// implement that trait. +/// +/// [`ListItem`]: crate::list::ListItem +#[macro_export] +macro_rules! impl_list_item { + ( + $(impl$({$($generics:tt)*})? ListItem<$num:tt> for $t:ty { + using ListLinks; + })* + ) => {$( + // SAFETY: See GUARANTEES comment on each method. + unsafe impl$(<$($generics)*>)? $crate::list::ListItem<$num> for $t { + // GUARANTEES: + // * This returns the same pointer as `prepare_to_insert` because `prepare_to_insert` + // is implemented in terms of `view_links`. + // * By the type invariants of `ListLinks`, the `ListLinks` has two null pointers when + // this value is not in a list. + unsafe fn view_links(me: *const Self) -> *mut $crate::list::ListLinks<$num> { + // SAFETY: The caller guarantees that `me` points at a valid value of type `Self`. + unsafe { + <Self as $crate::list::HasListLinks<$num>>::raw_get_list_links(me.cast_mut()) + } + } + + // GUARANTEES: + // * `me` originates from the most recent call to `prepare_to_insert`, which just added + // `offset` to the pointer passed to `prepare_to_insert`. This method subtracts + // `offset` from `me` so it returns the pointer originally passed to + // `prepare_to_insert`. + // * The pointer remains valid until the next call to `post_remove` because the caller + // of the most recent call to `prepare_to_insert` promised to retain ownership of the + // `ListArc` containing `Self` until the next call to `post_remove`. The value cannot + // be destroyed while a `ListArc` reference exists. + unsafe fn view_value(me: *mut $crate::list::ListLinks<$num>) -> *const Self { + let offset = <Self as $crate::list::HasListLinks<$num>>::OFFSET; + // SAFETY: `me` originates from the most recent call to `prepare_to_insert`, so it + // points at the field at offset `offset` in a value of type `Self`. Thus, + // subtracting `offset` from `me` is still in-bounds of the allocation. + unsafe { (me as *const u8).sub(offset) as *const Self } + } + + // GUARANTEES: + // This implementation of `ListItem` will not give out exclusive access to the same + // `ListLinks` several times because calls to `prepare_to_insert` and `post_remove` + // must alternate and exclusive access is given up when `post_remove` is called. + // + // Other invocations of `impl_list_item!` also cannot give out exclusive access to the + // same `ListLinks` because you can only implement `ListItem` once for each value of + // `ID`, and the `ListLinks` fields only work with the specified `ID`. + unsafe fn prepare_to_insert(me: *const Self) -> *mut $crate::list::ListLinks<$num> { + // SAFETY: The caller promises that `me` points at a valid value. + unsafe { <Self as $crate::list::ListItem<$num>>::view_links(me) } + } + + // GUARANTEES: + // * `me` originates from the most recent call to `prepare_to_insert`, which just added + // `offset` to the pointer passed to `prepare_to_insert`. This method subtracts + // `offset` from `me` so it returns the pointer originally passed to + // `prepare_to_insert`. + unsafe fn post_remove(me: *mut $crate::list::ListLinks<$num>) -> *const Self { + let offset = <Self as $crate::list::HasListLinks<$num>>::OFFSET; + // SAFETY: `me` originates from the most recent call to `prepare_to_insert`, so it + // points at the field at offset `offset` in a value of type `Self`. Thus, + // subtracting `offset` from `me` is still in-bounds of the allocation. + unsafe { (me as *const u8).sub(offset) as *const Self } + } + } + )*}; + + ( + $(impl$({$($generics:tt)*})? ListItem<$num:tt> for $t:ty { + using ListLinksSelfPtr; + })* + ) => {$( + // SAFETY: See GUARANTEES comment on each method. + unsafe impl$(<$($generics)*>)? $crate::list::ListItem<$num> for $t { + // GUARANTEES: + // This implementation of `ListItem` will not give out exclusive access to the same + // `ListLinks` several times because calls to `prepare_to_insert` and `post_remove` + // must alternate and exclusive access is given up when `post_remove` is called. + // + // Other invocations of `impl_list_item!` also cannot give out exclusive access to the + // same `ListLinks` because you can only implement `ListItem` once for each value of + // `ID`, and the `ListLinks` fields only work with the specified `ID`. + unsafe fn prepare_to_insert(me: *const Self) -> *mut $crate::list::ListLinks<$num> { + // SAFETY: The caller promises that `me` points at a valid value of type `Self`. + let links_field = unsafe { <Self as $crate::list::ListItem<$num>>::view_links(me) }; + + let spoff = $crate::list::ListLinksSelfPtr::<Self, $num>::LIST_LINKS_SELF_PTR_OFFSET; + // Goes via the offset as the field is private. + // + // SAFETY: The constant is equal to `offset_of!(ListLinksSelfPtr, self_ptr)`, so + // the pointer stays in bounds of the allocation. + let self_ptr = unsafe { (links_field as *const u8).add(spoff) } + as *const $crate::types::Opaque<*const Self>; + let cell_inner = $crate::types::Opaque::raw_get(self_ptr); + + // SAFETY: This value is not accessed in any other places than `prepare_to_insert`, + // `post_remove`, or `view_value`. By the safety requirements of those methods, + // none of these three methods may be called in parallel with this call to + // `prepare_to_insert`, so this write will not race with any other access to the + // value. + unsafe { ::core::ptr::write(cell_inner, me) }; + + links_field + } + + // GUARANTEES: + // * This returns the same pointer as `prepare_to_insert` because `prepare_to_insert` + // returns the return value of `view_links`. + // * By the type invariants of `ListLinks`, the `ListLinks` has two null pointers when + // this value is not in a list. + unsafe fn view_links(me: *const Self) -> *mut $crate::list::ListLinks<$num> { + // SAFETY: The caller promises that `me` points at a valid value of type `Self`. + unsafe { <Self as HasListLinks<$num>>::raw_get_list_links(me.cast_mut()) } + } + + // This function is also used as the implementation of `post_remove`, so the caller + // may choose to satisfy the safety requirements of `post_remove` instead of the safety + // requirements for `view_value`. + // + // GUARANTEES: (always) + // * This returns the same pointer as the one passed to the most recent call to + // `prepare_to_insert` since that call wrote that pointer to this location. The value + // is only modified in `prepare_to_insert`, so it has not been modified since the + // most recent call. + // + // GUARANTEES: (only when using the `view_value` safety requirements) + // * The pointer remains valid until the next call to `post_remove` because the caller + // of the most recent call to `prepare_to_insert` promised to retain ownership of the + // `ListArc` containing `Self` until the next call to `post_remove`. The value cannot + // be destroyed while a `ListArc` reference exists. + unsafe fn view_value(links_field: *mut $crate::list::ListLinks<$num>) -> *const Self { + let spoff = $crate::list::ListLinksSelfPtr::<Self, $num>::LIST_LINKS_SELF_PTR_OFFSET; + // SAFETY: The constant is equal to `offset_of!(ListLinksSelfPtr, self_ptr)`, so + // the pointer stays in bounds of the allocation. + let self_ptr = unsafe { (links_field as *const u8).add(spoff) } + as *const ::core::cell::UnsafeCell<*const Self>; + let cell_inner = ::core::cell::UnsafeCell::raw_get(self_ptr); + // SAFETY: This is not a data race, because the only function that writes to this + // value is `prepare_to_insert`, but by the safety requirements the + // `prepare_to_insert` method may not be called in parallel with `view_value` or + // `post_remove`. + unsafe { ::core::ptr::read(cell_inner) } + } + + // GUARANTEES: + // The first guarantee of `view_value` is exactly what `post_remove` guarantees. + unsafe fn post_remove(me: *mut $crate::list::ListLinks<$num>) -> *const Self { + // SAFETY: This specific implementation of `view_value` allows the caller to + // promise the safety requirements of `post_remove` instead of the safety + // requirements for `view_value`. + unsafe { <Self as $crate::list::ListItem<$num>>::view_value(me) } + } + } + )*}; +} +pub use impl_list_item; diff --git a/rust/kernel/net/phy.rs b/rust/kernel/net/phy.rs index fd40b703d224..910ce867480a 100644 --- a/rust/kernel/net/phy.rs +++ b/rust/kernel/net/phy.rs @@ -7,8 +7,9 @@ //! C headers: [`include/linux/phy.h`](srctree/include/linux/phy.h). use crate::{error::*, prelude::*, types::Opaque}; +use core::{marker::PhantomData, ptr::addr_of_mut}; -use core::marker::PhantomData; +pub mod reg; /// PHY state machine states. /// @@ -58,8 +59,9 @@ pub enum DuplexMode { /// /// # Invariants /// -/// Referencing a `phy_device` using this struct asserts that you are in -/// a context where all methods defined on this struct are safe to call. +/// - Referencing a `phy_device` using this struct asserts that you are in +/// a context where all methods defined on this struct are safe to call. +/// - This struct always has a valid `self.0.mdio.dev`. /// /// [`struct phy_device`]: srctree/include/linux/phy.h // During the calls to most functions in [`Driver`], the C side (`PHYLIB`) holds a lock that is @@ -76,9 +78,11 @@ impl Device { /// /// # Safety /// - /// For the duration of 'a, the pointer must point at a valid `phy_device`, - /// and the caller must be in a context where all methods defined on this struct - /// are safe to call. + /// For the duration of `'a`, + /// - the pointer must point at a valid `phy_device`, and the caller + /// must be in a context where all methods defined on this struct + /// are safe to call. + /// - `(*ptr).mdio.dev` must be a valid. unsafe fn from_raw<'a>(ptr: *mut bindings::phy_device) -> &'a mut Self { // CAST: `Self` is a `repr(transparent)` wrapper around `bindings::phy_device`. let ptr = ptr.cast::<Self>(); @@ -175,32 +179,15 @@ impl Device { unsafe { (*phydev).duplex = v }; } - /// Reads a given C22 PHY register. + /// Reads a PHY register. // This function reads a hardware register and updates the stats so takes `&mut self`. - pub fn read(&mut self, regnum: u16) -> Result<u16> { - let phydev = self.0.get(); - // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. - // So it's just an FFI call, open code of `phy_read()` with a valid `phy_device` pointer - // `phydev`. - let ret = unsafe { - bindings::mdiobus_read((*phydev).mdio.bus, (*phydev).mdio.addr, regnum.into()) - }; - if ret < 0 { - Err(Error::from_errno(ret)) - } else { - Ok(ret as u16) - } + pub fn read<R: reg::Register>(&mut self, reg: R) -> Result<u16> { + reg.read(self) } - /// Writes a given C22 PHY register. - pub fn write(&mut self, regnum: u16, val: u16) -> Result { - let phydev = self.0.get(); - // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. - // So it's just an FFI call, open code of `phy_write()` with a valid `phy_device` pointer - // `phydev`. - to_result(unsafe { - bindings::mdiobus_write((*phydev).mdio.bus, (*phydev).mdio.addr, regnum.into(), val) - }) + /// Writes a PHY register. + pub fn write<R: reg::Register>(&mut self, reg: R, val: u16) -> Result { + reg.write(self, val) } /// Reads a paged register. @@ -265,16 +252,8 @@ impl Device { } /// Checks the link status and updates current link state. - pub fn genphy_read_status(&mut self) -> Result<u16> { - let phydev = self.0.get(); - // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. - // So it's just an FFI call. - let ret = unsafe { bindings::genphy_read_status(phydev) }; - if ret < 0 { - Err(Error::from_errno(ret)) - } else { - Ok(ret as u16) - } + pub fn genphy_read_status<R: reg::Register>(&mut self) -> Result<u16> { + R::read_status(self) } /// Updates the link status. @@ -302,6 +281,14 @@ impl Device { } } +impl AsRef<kernel::device::Device> for Device { + fn as_ref(&self) -> &kernel::device::Device { + let phydev = self.0.get(); + // SAFETY: The struct invariant ensures that `mdio.dev` is valid. + unsafe { kernel::device::Device::as_ref(addr_of_mut!((*phydev).mdio.dev)) } + } +} + /// Defines certain other features this PHY supports (like interrupts). /// /// These flag values are used in [`Driver::FLAGS`]. @@ -341,6 +328,21 @@ impl<T: Driver> Adapter<T> { /// # Safety /// /// `phydev` must be passed by the corresponding callback in `phy_driver`. + unsafe extern "C" fn probe_callback(phydev: *mut bindings::phy_device) -> core::ffi::c_int { + from_result(|| { + // SAFETY: This callback is called only in contexts + // where we can exclusively access `phy_device` because + // it's not published yet, so the accessors on `Device` are okay + // to call. + let dev = unsafe { Device::from_raw(phydev) }; + T::probe(dev)?; + Ok(0) + }) + } + + /// # Safety + /// + /// `phydev` must be passed by the corresponding callback in `phy_driver`. unsafe extern "C" fn get_features_callback( phydev: *mut bindings::phy_device, ) -> core::ffi::c_int { @@ -491,7 +493,7 @@ impl<T: Driver> Adapter<T> { pub struct DriverVTable(Opaque<bindings::phy_driver>); // SAFETY: `DriverVTable` doesn't expose any &self method to access internal data, so it's safe to -// share `&DriverVTable` across execution context boundries. +// share `&DriverVTable` across execution context boundaries. unsafe impl Sync for DriverVTable {} /// Creates a [`DriverVTable`] instance from [`Driver`]. @@ -511,6 +513,11 @@ pub const fn create_phy_driver<T: Driver>() -> DriverVTable { } else { None }, + probe: if T::HAS_PROBE { + Some(Adapter::<T>::probe_callback) + } else { + None + }, get_features: if T::HAS_GET_FEATURES { Some(Adapter::<T>::get_features_callback) } else { @@ -583,6 +590,11 @@ pub trait Driver { kernel::build_error(VTABLE_DEFAULT_ERROR) } + /// Sets up device-specific structures during discovery. + fn probe(_dev: &mut Device) -> Result { + kernel::build_error(VTABLE_DEFAULT_ERROR) + } + /// Probes the hardware to determine what abilities it has. fn get_features(_dev: &mut Device) -> Result { kernel::build_error(VTABLE_DEFAULT_ERROR) diff --git a/rust/kernel/net/phy/reg.rs b/rust/kernel/net/phy/reg.rs new file mode 100644 index 000000000000..a7db0064cb7d --- /dev/null +++ b/rust/kernel/net/phy/reg.rs @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 FUJITA Tomonori <fujita.tomonori@gmail.com> + +//! PHY register interfaces. +//! +//! This module provides support for accessing PHY registers in the +//! Ethernet management interface clauses 22 and 45 register namespaces, as +//! defined in IEEE 802.3. + +use super::Device; +use crate::build_assert; +use crate::error::*; +use crate::uapi; + +mod private { + /// Marker that a trait cannot be implemented outside of this crate + pub trait Sealed {} +} + +/// Accesses PHY registers. +/// +/// This trait is used to implement the unified interface to access +/// C22 and C45 PHY registers. +/// +/// # Examples +/// +/// ```ignore +/// fn link_change_notify(dev: &mut Device) { +/// // read C22 BMCR register +/// dev.read(C22::BMCR); +/// // read C45 PMA/PMD control 1 register +/// dev.read(C45::new(Mmd::PMAPMD, 0)); +/// +/// // Checks the link status as reported by registers in the C22 namespace +/// // and updates current link state. +/// dev.genphy_read_status::<phy::C22>(); +/// // Checks the link status as reported by registers in the C45 namespace +/// // and updates current link state. +/// dev.genphy_read_status::<phy::C45>(); +/// } +/// ``` +pub trait Register: private::Sealed { + /// Reads a PHY register. + fn read(&self, dev: &mut Device) -> Result<u16>; + + /// Writes a PHY register. + fn write(&self, dev: &mut Device, val: u16) -> Result; + + /// Checks the link status and updates current link state. + fn read_status(dev: &mut Device) -> Result<u16>; +} + +/// A single MDIO clause 22 register address (5 bits). +#[derive(Copy, Clone, Debug)] +pub struct C22(u8); + +impl C22 { + /// Basic mode control. + pub const BMCR: Self = C22(0x00); + /// Basic mode status. + pub const BMSR: Self = C22(0x01); + /// PHY identifier 1. + pub const PHYSID1: Self = C22(0x02); + /// PHY identifier 2. + pub const PHYSID2: Self = C22(0x03); + /// Auto-negotiation advertisement. + pub const ADVERTISE: Self = C22(0x04); + /// Auto-negotiation link partner base page ability. + pub const LPA: Self = C22(0x05); + /// Auto-negotiation expansion. + pub const EXPANSION: Self = C22(0x06); + /// Auto-negotiation next page transmit. + pub const NEXT_PAGE_TRANSMIT: Self = C22(0x07); + /// Auto-negotiation link partner received next page. + pub const LP_RECEIVED_NEXT_PAGE: Self = C22(0x08); + /// Master-slave control. + pub const MASTER_SLAVE_CONTROL: Self = C22(0x09); + /// Master-slave status. + pub const MASTER_SLAVE_STATUS: Self = C22(0x0a); + /// PSE Control. + pub const PSE_CONTROL: Self = C22(0x0b); + /// PSE Status. + pub const PSE_STATUS: Self = C22(0x0c); + /// MMD Register control. + pub const MMD_CONTROL: Self = C22(0x0d); + /// MMD Register address data. + pub const MMD_DATA: Self = C22(0x0e); + /// Extended status. + pub const EXTENDED_STATUS: Self = C22(0x0f); + + /// Creates a new instance of `C22` with a vendor specific register. + pub const fn vendor_specific<const N: u8>() -> Self { + build_assert!( + N > 0x0f && N < 0x20, + "Vendor-specific register address must be between 16 and 31" + ); + C22(N) + } +} + +impl private::Sealed for C22 {} + +impl Register for C22 { + fn read(&self, dev: &mut Device) -> Result<u16> { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Device`. + // So it's just an FFI call, open code of `phy_read()` with a valid `phy_device` pointer + // `phydev`. + let ret = unsafe { + bindings::mdiobus_read((*phydev).mdio.bus, (*phydev).mdio.addr, self.0.into()) + }; + to_result(ret)?; + Ok(ret as u16) + } + + fn write(&self, dev: &mut Device, val: u16) -> Result { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Device`. + // So it's just an FFI call, open code of `phy_write()` with a valid `phy_device` pointer + // `phydev`. + to_result(unsafe { + bindings::mdiobus_write((*phydev).mdio.bus, (*phydev).mdio.addr, self.0.into(), val) + }) + } + + fn read_status(dev: &mut Device) -> Result<u16> { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. + // So it's just an FFI call. + let ret = unsafe { bindings::genphy_read_status(phydev) }; + to_result(ret)?; + Ok(ret as u16) + } +} + +/// A single MDIO clause 45 register device and address. +#[derive(Copy, Clone, Debug)] +pub struct Mmd(u8); + +impl Mmd { + /// Physical Medium Attachment/Dependent. + pub const PMAPMD: Self = Mmd(uapi::MDIO_MMD_PMAPMD as u8); + /// WAN interface sublayer. + pub const WIS: Self = Mmd(uapi::MDIO_MMD_WIS as u8); + /// Physical coding sublayer. + pub const PCS: Self = Mmd(uapi::MDIO_MMD_PCS as u8); + /// PHY Extender sublayer. + pub const PHYXS: Self = Mmd(uapi::MDIO_MMD_PHYXS as u8); + /// DTE Extender sublayer. + pub const DTEXS: Self = Mmd(uapi::MDIO_MMD_DTEXS as u8); + /// Transmission convergence. + pub const TC: Self = Mmd(uapi::MDIO_MMD_TC as u8); + /// Auto negotiation. + pub const AN: Self = Mmd(uapi::MDIO_MMD_AN as u8); + /// Separated PMA (1). + pub const SEPARATED_PMA1: Self = Mmd(8); + /// Separated PMA (2). + pub const SEPARATED_PMA2: Self = Mmd(9); + /// Separated PMA (3). + pub const SEPARATED_PMA3: Self = Mmd(10); + /// Separated PMA (4). + pub const SEPARATED_PMA4: Self = Mmd(11); + /// OFDM PMA/PMD. + pub const OFDM_PMAPMD: Self = Mmd(12); + /// Power unit. + pub const POWER_UNIT: Self = Mmd(13); + /// Clause 22 extension. + pub const C22_EXT: Self = Mmd(uapi::MDIO_MMD_C22EXT as u8); + /// Vendor specific 1. + pub const VEND1: Self = Mmd(uapi::MDIO_MMD_VEND1 as u8); + /// Vendor specific 2. + pub const VEND2: Self = Mmd(uapi::MDIO_MMD_VEND2 as u8); +} + +/// A single MDIO clause 45 register device and address. +/// +/// Clause 45 uses a 5-bit device address to access a specific MMD within +/// a port, then a 16-bit register address to access a location within +/// that device. `C45` represents this by storing a [`Mmd`] and +/// a register number. +pub struct C45 { + devad: Mmd, + regnum: u16, +} + +impl C45 { + /// Creates a new instance of `C45`. + pub fn new(devad: Mmd, regnum: u16) -> Self { + Self { devad, regnum } + } +} + +impl private::Sealed for C45 {} + +impl Register for C45 { + fn read(&self, dev: &mut Device) -> Result<u16> { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Device`. + // So it's just an FFI call. + let ret = + unsafe { bindings::phy_read_mmd(phydev, self.devad.0.into(), self.regnum.into()) }; + to_result(ret)?; + Ok(ret as u16) + } + + fn write(&self, dev: &mut Device, val: u16) -> Result { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Device`. + // So it's just an FFI call. + to_result(unsafe { + bindings::phy_write_mmd(phydev, self.devad.0.into(), self.regnum.into(), val) + }) + } + + fn read_status(dev: &mut Device) -> Result<u16> { + let phydev = dev.0.get(); + // SAFETY: `phydev` is pointing to a valid object by the type invariant of `Self`. + // So it's just an FFI call. + let ret = unsafe { bindings::genphy_c45_read_status(phydev) }; + to_result(ret)?; + Ok(ret as u16) + } +} diff --git a/rust/kernel/page.rs b/rust/kernel/page.rs new file mode 100644 index 000000000000..208a006d587c --- /dev/null +++ b/rust/kernel/page.rs @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Kernel page allocation and management. + +use crate::{ + alloc::{AllocError, Flags}, + bindings, + error::code::*, + error::Result, + uaccess::UserSliceReader, +}; +use core::ptr::{self, NonNull}; + +/// A bitwise shift for the page size. +pub const PAGE_SHIFT: usize = bindings::PAGE_SHIFT as usize; + +/// The number of bytes in a page. +pub const PAGE_SIZE: usize = bindings::PAGE_SIZE; + +/// A bitmask that gives the page containing a given address. +pub const PAGE_MASK: usize = !(PAGE_SIZE - 1); + +/// A pointer to a page that owns the page allocation. +/// +/// # Invariants +/// +/// The pointer is valid, and has ownership over the page. +pub struct Page { + page: NonNull<bindings::page>, +} + +// SAFETY: Pages have no logic that relies on them staying on a given thread, so moving them across +// threads is safe. +unsafe impl Send for Page {} + +// SAFETY: Pages have no logic that relies on them not being accessed concurrently, so accessing +// them concurrently is safe. +unsafe impl Sync for Page {} + +impl Page { + /// Allocates a new page. + /// + /// # Examples + /// + /// Allocate memory for a page. + /// + /// ``` + /// use kernel::page::Page; + /// + /// # fn dox() -> Result<(), kernel::alloc::AllocError> { + /// let page = Page::alloc_page(GFP_KERNEL)?; + /// # Ok(()) } + /// ``` + /// + /// Allocate memory for a page and zero its contents. + /// + /// ``` + /// use kernel::page::Page; + /// + /// # fn dox() -> Result<(), kernel::alloc::AllocError> { + /// let page = Page::alloc_page(GFP_KERNEL | __GFP_ZERO)?; + /// # Ok(()) } + /// ``` + pub fn alloc_page(flags: Flags) -> Result<Self, AllocError> { + // SAFETY: Depending on the value of `gfp_flags`, this call may sleep. Other than that, it + // is always safe to call this method. + let page = unsafe { bindings::alloc_pages(flags.as_raw(), 0) }; + let page = NonNull::new(page).ok_or(AllocError)?; + // INVARIANT: We just successfully allocated a page, so we now have ownership of the newly + // allocated page. We transfer that ownership to the new `Page` object. + Ok(Self { page }) + } + + /// Returns a raw pointer to the page. + pub fn as_ptr(&self) -> *mut bindings::page { + self.page.as_ptr() + } + + /// Runs a piece of code with this page mapped to an address. + /// + /// The page is unmapped when this call returns. + /// + /// # Using the raw pointer + /// + /// It is up to the caller to use the provided raw pointer correctly. The pointer is valid for + /// `PAGE_SIZE` bytes and for the duration in which the closure is called. The pointer might + /// only be mapped on the current thread, and when that is the case, dereferencing it on other + /// threads is UB. Other than that, the usual rules for dereferencing a raw pointer apply: don't + /// cause data races, the memory may be uninitialized, and so on. + /// + /// If multiple threads map the same page at the same time, then they may reference with + /// different addresses. However, even if the addresses are different, the underlying memory is + /// still the same for these purposes (e.g., it's still a data race if they both write to the + /// same underlying byte at the same time). + fn with_page_mapped<T>(&self, f: impl FnOnce(*mut u8) -> T) -> T { + // SAFETY: `page` is valid due to the type invariants on `Page`. + let mapped_addr = unsafe { bindings::kmap_local_page(self.as_ptr()) }; + + let res = f(mapped_addr.cast()); + + // This unmaps the page mapped above. + // + // SAFETY: Since this API takes the user code as a closure, it can only be used in a manner + // where the pages are unmapped in reverse order. This is as required by `kunmap_local`. + // + // In other words, if this call to `kunmap_local` happens when a different page should be + // unmapped first, then there must necessarily be a call to `kmap_local_page` other than the + // call just above in `with_page_mapped` that made that possible. In this case, it is the + // unsafe block that wraps that other call that is incorrect. + unsafe { bindings::kunmap_local(mapped_addr) }; + + res + } + + /// Runs a piece of code with a raw pointer to a slice of this page, with bounds checking. + /// + /// If `f` is called, then it will be called with a pointer that points at `off` bytes into the + /// page, and the pointer will be valid for at least `len` bytes. The pointer is only valid on + /// this task, as this method uses a local mapping. + /// + /// If `off` and `len` refers to a region outside of this page, then this method returns + /// [`EINVAL`] and does not call `f`. + /// + /// # Using the raw pointer + /// + /// It is up to the caller to use the provided raw pointer correctly. The pointer is valid for + /// `len` bytes and for the duration in which the closure is called. The pointer might only be + /// mapped on the current thread, and when that is the case, dereferencing it on other threads + /// is UB. Other than that, the usual rules for dereferencing a raw pointer apply: don't cause + /// data races, the memory may be uninitialized, and so on. + /// + /// If multiple threads map the same page at the same time, then they may reference with + /// different addresses. However, even if the addresses are different, the underlying memory is + /// still the same for these purposes (e.g., it's still a data race if they both write to the + /// same underlying byte at the same time). + fn with_pointer_into_page<T>( + &self, + off: usize, + len: usize, + f: impl FnOnce(*mut u8) -> Result<T>, + ) -> Result<T> { + let bounds_ok = off <= PAGE_SIZE && len <= PAGE_SIZE && (off + len) <= PAGE_SIZE; + + if bounds_ok { + self.with_page_mapped(move |page_addr| { + // SAFETY: The `off` integer is at most `PAGE_SIZE`, so this pointer offset will + // result in a pointer that is in bounds or one off the end of the page. + f(unsafe { page_addr.add(off) }) + }) + } else { + Err(EINVAL) + } + } + + /// Maps the page and reads from it into the given buffer. + /// + /// This method will perform bounds checks on the page offset. If `offset .. offset+len` goes + /// outside of the page, then this call returns [`EINVAL`]. + /// + /// # Safety + /// + /// * Callers must ensure that `dst` is valid for writing `len` bytes. + /// * Callers must ensure that this call does not race with a write to the same page that + /// overlaps with this read. + pub unsafe fn read_raw(&self, dst: *mut u8, offset: usize, len: usize) -> Result { + self.with_pointer_into_page(offset, len, move |src| { + // SAFETY: If `with_pointer_into_page` calls into this closure, then + // it has performed a bounds check and guarantees that `src` is + // valid for `len` bytes. + // + // There caller guarantees that there is no data race. + unsafe { ptr::copy_nonoverlapping(src, dst, len) }; + Ok(()) + }) + } + + /// Maps the page and writes into it from the given buffer. + /// + /// This method will perform bounds checks on the page offset. If `offset .. offset+len` goes + /// outside of the page, then this call returns [`EINVAL`]. + /// + /// # Safety + /// + /// * Callers must ensure that `src` is valid for reading `len` bytes. + /// * Callers must ensure that this call does not race with a read or write to the same page + /// that overlaps with this write. + pub unsafe fn write_raw(&self, src: *const u8, offset: usize, len: usize) -> Result { + self.with_pointer_into_page(offset, len, move |dst| { + // SAFETY: If `with_pointer_into_page` calls into this closure, then it has performed a + // bounds check and guarantees that `dst` is valid for `len` bytes. + // + // There caller guarantees that there is no data race. + unsafe { ptr::copy_nonoverlapping(src, dst, len) }; + Ok(()) + }) + } + + /// Maps the page and zeroes the given slice. + /// + /// This method will perform bounds checks on the page offset. If `offset .. offset+len` goes + /// outside of the page, then this call returns [`EINVAL`]. + /// + /// # Safety + /// + /// Callers must ensure that this call does not race with a read or write to the same page that + /// overlaps with this write. + pub unsafe fn fill_zero_raw(&self, offset: usize, len: usize) -> Result { + self.with_pointer_into_page(offset, len, move |dst| { + // SAFETY: If `with_pointer_into_page` calls into this closure, then it has performed a + // bounds check and guarantees that `dst` is valid for `len` bytes. + // + // There caller guarantees that there is no data race. + unsafe { ptr::write_bytes(dst, 0u8, len) }; + Ok(()) + }) + } + + /// Copies data from userspace into this page. + /// + /// This method will perform bounds checks on the page offset. If `offset .. offset+len` goes + /// outside of the page, then this call returns [`EINVAL`]. + /// + /// Like the other `UserSliceReader` methods, data races are allowed on the userspace address. + /// However, they are not allowed on the page you are copying into. + /// + /// # Safety + /// + /// Callers must ensure that this call does not race with a read or write to the same page that + /// overlaps with this write. + pub unsafe fn copy_from_user_slice_raw( + &self, + reader: &mut UserSliceReader, + offset: usize, + len: usize, + ) -> Result { + self.with_pointer_into_page(offset, len, move |dst| { + // SAFETY: If `with_pointer_into_page` calls into this closure, then it has performed a + // bounds check and guarantees that `dst` is valid for `len` bytes. Furthermore, we have + // exclusive access to the slice since the caller guarantees that there are no races. + reader.read_raw(unsafe { core::slice::from_raw_parts_mut(dst.cast(), len) }) + }) + } +} + +impl Drop for Page { + fn drop(&mut self) { + // SAFETY: By the type invariants, we have ownership of the page and can free it. + unsafe { bindings::__free_pages(self.page.as_ptr(), 0) }; + } +} diff --git a/rust/kernel/prelude.rs b/rust/kernel/prelude.rs index b37a0b3180fb..4571daec0961 100644 --- a/rust/kernel/prelude.rs +++ b/rust/kernel/prelude.rs @@ -37,6 +37,6 @@ pub use super::error::{code::*, Error, Result}; pub use super::{str::CStr, ThisModule}; -pub use super::init::{InPlaceInit, Init, PinInit}; +pub use super::init::{InPlaceInit, InPlaceWrite, Init, PinInit}; pub use super::current; diff --git a/rust/kernel/print.rs b/rust/kernel/print.rs index a78aa3514a0a..508b0221256c 100644 --- a/rust/kernel/print.rs +++ b/rust/kernel/print.rs @@ -4,7 +4,7 @@ //! //! C header: [`include/linux/printk.h`](srctree/include/linux/printk.h) //! -//! Reference: <https://www.kernel.org/doc/html/latest/core-api/printk-basics.html> +//! Reference: <https://docs.kernel.org/core-api/printk-basics.html> use core::{ ffi::{c_char, c_void}, @@ -197,7 +197,7 @@ macro_rules! print_macro ( /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and /// `alloc::format!` for information about the formatting syntax. /// -/// [`pr_emerg`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_emerg +/// [`pr_emerg`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_emerg /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples @@ -221,7 +221,7 @@ macro_rules! pr_emerg ( /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and /// `alloc::format!` for information about the formatting syntax. /// -/// [`pr_alert`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_alert +/// [`pr_alert`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_alert /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples @@ -245,7 +245,7 @@ macro_rules! pr_alert ( /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and /// `alloc::format!` for information about the formatting syntax. /// -/// [`pr_crit`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_crit +/// [`pr_crit`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_crit /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples @@ -269,7 +269,7 @@ macro_rules! pr_crit ( /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and /// `alloc::format!` for information about the formatting syntax. /// -/// [`pr_err`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_err +/// [`pr_err`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_err /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples @@ -293,7 +293,7 @@ macro_rules! pr_err ( /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and /// `alloc::format!` for information about the formatting syntax. /// -/// [`pr_warn`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_warn +/// [`pr_warn`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_warn /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples @@ -317,7 +317,7 @@ macro_rules! pr_warn ( /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and /// `alloc::format!` for information about the formatting syntax. /// -/// [`pr_notice`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_notice +/// [`pr_notice`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_notice /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples @@ -341,7 +341,7 @@ macro_rules! pr_notice ( /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and /// `alloc::format!` for information about the formatting syntax. /// -/// [`pr_info`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_info +/// [`pr_info`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_info /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples @@ -367,7 +367,7 @@ macro_rules! pr_info ( /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and /// `alloc::format!` for information about the formatting syntax. /// -/// [`pr_debug`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_debug +/// [`pr_debug`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_debug /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples @@ -395,7 +395,7 @@ macro_rules! pr_debug ( /// `alloc::format!` for information about the formatting syntax. /// /// [`pr_info!`]: crate::pr_info! -/// [`pr_cont`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html#c.pr_cont +/// [`pr_cont`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_cont /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html /// /// # Examples diff --git a/rust/kernel/rbtree.rs b/rust/kernel/rbtree.rs new file mode 100644 index 000000000000..25eb36fd1cdc --- /dev/null +++ b/rust/kernel/rbtree.rs @@ -0,0 +1,1278 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Red-black trees. +//! +//! C header: [`include/linux/rbtree.h`](srctree/include/linux/rbtree.h) +//! +//! Reference: <https://docs.kernel.org/core-api/rbtree.html> + +use crate::{alloc::Flags, bindings, container_of, error::Result, prelude::*}; +use alloc::boxed::Box; +use core::{ + cmp::{Ord, Ordering}, + marker::PhantomData, + mem::MaybeUninit, + ptr::{addr_of_mut, from_mut, NonNull}, +}; + +/// A red-black tree with owned nodes. +/// +/// It is backed by the kernel C red-black trees. +/// +/// # Examples +/// +/// In the example below we do several operations on a tree. We note that insertions may fail if +/// the system is out of memory. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::{RBTree, RBTreeNode, RBTreeNodeReservation}}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Check the nodes we just inserted. +/// { +/// assert_eq!(tree.get(&10).unwrap(), &100); +/// assert_eq!(tree.get(&20).unwrap(), &200); +/// assert_eq!(tree.get(&30).unwrap(), &300); +/// } +/// +/// // Iterate over the nodes we just inserted. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next().unwrap(), (&10, &100)); +/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert_eq!(iter.next().unwrap(), (&30, &300)); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Print all elements. +/// for (key, value) in &tree { +/// pr_info!("{} = {}\n", key, value); +/// } +/// +/// // Replace one of the elements. +/// tree.try_create_and_insert(10, 1000, flags::GFP_KERNEL)?; +/// +/// // Check that the tree reflects the replacement. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next().unwrap(), (&10, &1000)); +/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert_eq!(iter.next().unwrap(), (&30, &300)); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Change the value of one of the elements. +/// *tree.get_mut(&30).unwrap() = 3000; +/// +/// // Check that the tree reflects the update. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next().unwrap(), (&10, &1000)); +/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert_eq!(iter.next().unwrap(), (&30, &3000)); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Remove an element. +/// tree.remove(&10); +/// +/// // Check that the tree reflects the removal. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert_eq!(iter.next().unwrap(), (&30, &3000)); +/// assert!(iter.next().is_none()); +/// } +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// In the example below, we first allocate a node, acquire a spinlock, then insert the node into +/// the tree. This is useful when the insertion context does not allow sleeping, for example, when +/// holding a spinlock. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::{RBTree, RBTreeNode}, sync::SpinLock}; +/// +/// fn insert_test(tree: &SpinLock<RBTree<u32, u32>>) -> Result { +/// // Pre-allocate node. This may fail (as it allocates memory). +/// let node = RBTreeNode::new(10, 100, flags::GFP_KERNEL)?; +/// +/// // Insert node while holding the lock. It is guaranteed to succeed with no allocation +/// // attempts. +/// let mut guard = tree.lock(); +/// guard.insert(node); +/// Ok(()) +/// } +/// ``` +/// +/// In the example below, we reuse an existing node allocation from an element we removed. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::{RBTree, RBTreeNodeReservation}}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Check the nodes we just inserted. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next().unwrap(), (&10, &100)); +/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert_eq!(iter.next().unwrap(), (&30, &300)); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Remove a node, getting back ownership of it. +/// let existing = tree.remove(&30).unwrap(); +/// +/// // Check that the tree reflects the removal. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next().unwrap(), (&10, &100)); +/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert!(iter.next().is_none()); +/// } +/// +/// // Create a preallocated reservation that we can re-use later. +/// let reservation = RBTreeNodeReservation::new(flags::GFP_KERNEL)?; +/// +/// // Insert a new node into the tree, reusing the previous allocation. This is guaranteed to +/// // succeed (no memory allocations). +/// tree.insert(reservation.into_node(15, 150)); +/// +/// // Check that the tree reflect the new insertion. +/// { +/// let mut iter = tree.iter(); +/// assert_eq!(iter.next().unwrap(), (&10, &100)); +/// assert_eq!(iter.next().unwrap(), (&15, &150)); +/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert!(iter.next().is_none()); +/// } +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// # Invariants +/// +/// Non-null parent/children pointers stored in instances of the `rb_node` C struct are always +/// valid, and pointing to a field of our internal representation of a node. +pub struct RBTree<K, V> { + root: bindings::rb_root, + _p: PhantomData<Node<K, V>>, +} + +// SAFETY: An [`RBTree`] allows the same kinds of access to its values that a struct allows to its +// fields, so we use the same Send condition as would be used for a struct with K and V fields. +unsafe impl<K: Send, V: Send> Send for RBTree<K, V> {} + +// SAFETY: An [`RBTree`] allows the same kinds of access to its values that a struct allows to its +// fields, so we use the same Sync condition as would be used for a struct with K and V fields. +unsafe impl<K: Sync, V: Sync> Sync for RBTree<K, V> {} + +impl<K, V> RBTree<K, V> { + /// Creates a new and empty tree. + pub fn new() -> Self { + Self { + // INVARIANT: There are no nodes in the tree, so the invariant holds vacuously. + root: bindings::rb_root::default(), + _p: PhantomData, + } + } + + /// Returns an iterator over the tree nodes, sorted by key. + pub fn iter(&self) -> Iter<'_, K, V> { + Iter { + _tree: PhantomData, + // INVARIANT: + // - `self.root` is a valid pointer to a tree root. + // - `bindings::rb_first` produces a valid pointer to a node given `root` is valid. + iter_raw: IterRaw { + // SAFETY: by the invariants, all pointers are valid. + next: unsafe { bindings::rb_first(&self.root) }, + _phantom: PhantomData, + }, + } + } + + /// Returns a mutable iterator over the tree nodes, sorted by key. + pub fn iter_mut(&mut self) -> IterMut<'_, K, V> { + IterMut { + _tree: PhantomData, + // INVARIANT: + // - `self.root` is a valid pointer to a tree root. + // - `bindings::rb_first` produces a valid pointer to a node given `root` is valid. + iter_raw: IterRaw { + // SAFETY: by the invariants, all pointers are valid. + next: unsafe { bindings::rb_first(from_mut(&mut self.root)) }, + _phantom: PhantomData, + }, + } + } + + /// Returns an iterator over the keys of the nodes in the tree, in sorted order. + pub fn keys(&self) -> impl Iterator<Item = &'_ K> { + self.iter().map(|(k, _)| k) + } + + /// Returns an iterator over the values of the nodes in the tree, sorted by key. + pub fn values(&self) -> impl Iterator<Item = &'_ V> { + self.iter().map(|(_, v)| v) + } + + /// Returns a mutable iterator over the values of the nodes in the tree, sorted by key. + pub fn values_mut(&mut self) -> impl Iterator<Item = &'_ mut V> { + self.iter_mut().map(|(_, v)| v) + } + + /// Returns a cursor over the tree nodes, starting with the smallest key. + pub fn cursor_front(&mut self) -> Option<Cursor<'_, K, V>> { + let root = addr_of_mut!(self.root); + // SAFETY: `self.root` is always a valid root node + let current = unsafe { bindings::rb_first(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + tree: self, + } + }) + } + + /// Returns a cursor over the tree nodes, starting with the largest key. + pub fn cursor_back(&mut self) -> Option<Cursor<'_, K, V>> { + let root = addr_of_mut!(self.root); + // SAFETY: `self.root` is always a valid root node + let current = unsafe { bindings::rb_last(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + tree: self, + } + }) + } +} + +impl<K, V> RBTree<K, V> +where + K: Ord, +{ + /// Tries to insert a new value into the tree. + /// + /// It overwrites a node if one already exists with the same key and returns it (containing the + /// key/value pair). Returns [`None`] if a node with the same key didn't already exist. + /// + /// Returns an error if it cannot allocate memory for the new node. + pub fn try_create_and_insert( + &mut self, + key: K, + value: V, + flags: Flags, + ) -> Result<Option<RBTreeNode<K, V>>> { + Ok(self.insert(RBTreeNode::new(key, value, flags)?)) + } + + /// Inserts a new node into the tree. + /// + /// It overwrites a node if one already exists with the same key and returns it (containing the + /// key/value pair). Returns [`None`] if a node with the same key didn't already exist. + /// + /// This function always succeeds. + pub fn insert(&mut self, node: RBTreeNode<K, V>) -> Option<RBTreeNode<K, V>> { + match self.raw_entry(&node.node.key) { + RawEntry::Occupied(entry) => Some(entry.replace(node)), + RawEntry::Vacant(entry) => { + entry.insert(node); + None + } + } + } + + fn raw_entry(&mut self, key: &K) -> RawEntry<'_, K, V> { + let raw_self: *mut RBTree<K, V> = self; + // The returned `RawEntry` is used to call either `rb_link_node` or `rb_replace_node`. + // The parameters of `bindings::rb_link_node` are as follows: + // - `node`: A pointer to an uninitialized node being inserted. + // - `parent`: A pointer to an existing node in the tree. One of its child pointers must be + // null, and `node` will become a child of `parent` by replacing that child pointer + // with a pointer to `node`. + // - `rb_link`: A pointer to either the left-child or right-child field of `parent`. This + // specifies which child of `parent` should hold `node` after this call. The + // value of `*rb_link` must be null before the call to `rb_link_node`. If the + // red/black tree is empty, then it’s also possible for `parent` to be null. In + // this case, `rb_link` is a pointer to the `root` field of the red/black tree. + // + // We will traverse the tree looking for a node that has a null pointer as its child, + // representing an empty subtree where we can insert our new node. We need to make sure + // that we preserve the ordering of the nodes in the tree. In each iteration of the loop + // we store `parent` and `child_field_of_parent`, and the new `node` will go somewhere + // in the subtree of `parent` that `child_field_of_parent` points at. Once + // we find an empty subtree, we can insert the new node using `rb_link_node`. + let mut parent = core::ptr::null_mut(); + let mut child_field_of_parent: &mut *mut bindings::rb_node = + // SAFETY: `raw_self` is a valid pointer to the `RBTree` (created from `self` above). + unsafe { &mut (*raw_self).root.rb_node }; + while !(*child_field_of_parent).is_null() { + let curr = *child_field_of_parent; + // SAFETY: All links fields we create are in a `Node<K, V>`. + let node = unsafe { container_of!(curr, Node<K, V>, links) }; + + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + match key.cmp(unsafe { &(*node).key }) { + // SAFETY: `curr` is a non-null node so it is valid by the type invariants. + Ordering::Less => child_field_of_parent = unsafe { &mut (*curr).rb_left }, + // SAFETY: `curr` is a non-null node so it is valid by the type invariants. + Ordering::Greater => child_field_of_parent = unsafe { &mut (*curr).rb_right }, + Ordering::Equal => { + return RawEntry::Occupied(OccupiedEntry { + rbtree: self, + node_links: curr, + }) + } + } + parent = curr; + } + + RawEntry::Vacant(RawVacantEntry { + rbtree: raw_self, + parent, + child_field_of_parent, + _phantom: PhantomData, + }) + } + + /// Gets the given key's corresponding entry in the map for in-place manipulation. + pub fn entry(&mut self, key: K) -> Entry<'_, K, V> { + match self.raw_entry(&key) { + RawEntry::Occupied(entry) => Entry::Occupied(entry), + RawEntry::Vacant(entry) => Entry::Vacant(VacantEntry { raw: entry, key }), + } + } + + /// Used for accessing the given node, if it exists. + pub fn find_mut(&mut self, key: &K) -> Option<OccupiedEntry<'_, K, V>> { + match self.raw_entry(key) { + RawEntry::Occupied(entry) => Some(entry), + RawEntry::Vacant(_entry) => None, + } + } + + /// Returns a reference to the value corresponding to the key. + pub fn get(&self, key: &K) -> Option<&V> { + let mut node = self.root.rb_node; + while !node.is_null() { + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(node, Node<K, V>, links) }; + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + node = match key.cmp(unsafe { &(*this).key }) { + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + Ordering::Less => unsafe { (*node).rb_left }, + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + Ordering::Greater => unsafe { (*node).rb_right }, + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + Ordering::Equal => return Some(unsafe { &(*this).value }), + } + } + None + } + + /// Returns a mutable reference to the value corresponding to the key. + pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { + self.find_mut(key).map(|node| node.into_mut()) + } + + /// Removes the node with the given key from the tree. + /// + /// It returns the node that was removed if one exists, or [`None`] otherwise. + pub fn remove_node(&mut self, key: &K) -> Option<RBTreeNode<K, V>> { + self.find_mut(key).map(OccupiedEntry::remove_node) + } + + /// Removes the node with the given key from the tree. + /// + /// It returns the value that was removed if one exists, or [`None`] otherwise. + pub fn remove(&mut self, key: &K) -> Option<V> { + self.find_mut(key).map(OccupiedEntry::remove) + } + + /// Returns a cursor over the tree nodes based on the given key. + /// + /// If the given key exists, the cursor starts there. + /// Otherwise it starts with the first larger key in sort order. + /// If there is no larger key, it returns [`None`]. + pub fn cursor_lower_bound(&mut self, key: &K) -> Option<Cursor<'_, K, V>> + where + K: Ord, + { + let mut node = self.root.rb_node; + let mut best_match: Option<NonNull<Node<K, V>>> = None; + while !node.is_null() { + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(node, Node<K, V>, links) }.cast_mut(); + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + let this_key = unsafe { &(*this).key }; + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + let left_child = unsafe { (*node).rb_left }; + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + let right_child = unsafe { (*node).rb_right }; + match key.cmp(this_key) { + Ordering::Equal => { + best_match = NonNull::new(this); + break; + } + Ordering::Greater => { + node = right_child; + } + Ordering::Less => { + let is_better_match = match best_match { + None => true, + Some(best) => { + // SAFETY: `best` is a non-null node so it is valid by the type invariants. + let best_key = unsafe { &(*best.as_ptr()).key }; + best_key > this_key + } + }; + if is_better_match { + best_match = NonNull::new(this); + } + node = left_child; + } + }; + } + + let best = best_match?; + + // SAFETY: `best` is a non-null node so it is valid by the type invariants. + let links = unsafe { addr_of_mut!((*best.as_ptr()).links) }; + + NonNull::new(links).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + tree: self, + } + }) + } +} + +impl<K, V> Default for RBTree<K, V> { + fn default() -> Self { + Self::new() + } +} + +impl<K, V> Drop for RBTree<K, V> { + fn drop(&mut self) { + // SAFETY: `root` is valid as it's embedded in `self` and we have a valid `self`. + let mut next = unsafe { bindings::rb_first_postorder(&self.root) }; + + // INVARIANT: The loop invariant is that all tree nodes from `next` in postorder are valid. + while !next.is_null() { + // SAFETY: All links fields we create are in a `Node<K, V>`. + let this = unsafe { container_of!(next, Node<K, V>, links) }; + + // Find out what the next node is before disposing of the current one. + // SAFETY: `next` and all nodes in postorder are still valid. + next = unsafe { bindings::rb_next_postorder(next) }; + + // INVARIANT: This is the destructor, so we break the type invariant during clean-up, + // but it is not observable. The loop invariant is still maintained. + + // SAFETY: `this` is valid per the loop invariant. + unsafe { drop(Box::from_raw(this.cast_mut())) }; + } + } +} + +/// A bidirectional cursor over the tree nodes, sorted by key. +/// +/// # Examples +/// +/// In the following example, we obtain a cursor to the first element in the tree. +/// The cursor allows us to iterate bidirectionally over key/value pairs in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Get a cursor to the first element. +/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// +/// // Move the cursor, updating it to the 2nd element. +/// cursor = cursor.move_next().unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // Peek at the next element without impacting the cursor. +/// let next = cursor.peek_next().unwrap(); +/// assert_eq!(next, (&30, &300)); +/// current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // Moving past the last element causes the cursor to return [`None`]. +/// cursor = cursor.move_next().unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// let cursor = cursor.move_next(); +/// assert!(cursor.is_none()); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// A cursor can also be obtained at the last element in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// let mut cursor = tree.cursor_back().unwrap(); +/// let current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Obtaining a cursor returns [`None`] if the tree is empty. +/// +/// ``` +/// use kernel::rbtree::RBTree; +/// +/// let mut tree: RBTree<u16, u16> = RBTree::new(); +/// assert!(tree.cursor_front().is_none()); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// [`RBTree::cursor_lower_bound`] can be used to start at an arbitrary node in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert five elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(40, 400, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(50, 500, flags::GFP_KERNEL)?; +/// +/// // If the provided key exists, a cursor to that key is returned. +/// let cursor = tree.cursor_lower_bound(&20).unwrap(); +/// let current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // If the provided key doesn't exist, a cursor to the first larger element in sort order is returned. +/// let cursor = tree.cursor_lower_bound(&25).unwrap(); +/// let current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// +/// // If there is no larger key, [`None`] is returned. +/// let cursor = tree.cursor_lower_bound(&55); +/// assert!(cursor.is_none()); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// The cursor allows mutation of values in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Retrieve a cursor. +/// let mut cursor = tree.cursor_front().unwrap(); +/// +/// // Get a mutable reference to the current value. +/// let (k, v) = cursor.current_mut(); +/// *v = 1000; +/// +/// // The updated value is reflected in the tree. +/// let updated = tree.get(&10).unwrap(); +/// assert_eq!(updated, &1000); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// It also allows node removal. The following examples demonstrate the behavior of removing the current node. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Remove the first element. +/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// cursor = cursor.remove_current().0.unwrap(); +/// +/// // If a node exists after the current element, it is returned. +/// current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // Get a cursor to the last element, and remove it. +/// cursor = tree.cursor_back().unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// +/// // Since there is no next node, the previous node is returned. +/// cursor = cursor.remove_current().0.unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&20, &200)); +/// +/// // Removing the last element in the tree returns [`None`]. +/// assert!(cursor.remove_current().0.is_none()); +/// +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Nodes adjacent to the current node can also be removed. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Get a cursor to the first element. +/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// +/// // Calling `remove_prev` from the first element returns [`None`]. +/// assert!(cursor.remove_prev().is_none()); +/// +/// // Get a cursor to the last element. +/// cursor = tree.cursor_back().unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&30, &300)); +/// +/// // Calling `remove_prev` removes and returns the middle element. +/// assert_eq!(cursor.remove_prev().unwrap().to_key_value(), (20, 200)); +/// +/// // Calling `remove_next` from the last element returns [`None`]. +/// assert!(cursor.remove_next().is_none()); +/// +/// // Move to the first element +/// cursor = cursor.move_prev().unwrap(); +/// current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// +/// // Calling `remove_next` removes and returns the last element. +/// assert_eq!(cursor.remove_next().unwrap().to_key_value(), (30, 300)); +/// +/// # Ok::<(), Error>(()) +/// +/// ``` +/// +/// # Invariants +/// - `current` points to a node that is in the same [`RBTree`] as `tree`. +pub struct Cursor<'a, K, V> { + tree: &'a mut RBTree<K, V>, + current: NonNull<bindings::rb_node>, +} + +// SAFETY: The [`Cursor`] has exclusive access to both `K` and `V`, so it is sufficient to require them to be `Send`. +// The cursor only gives out immutable references to the keys, but since it has excusive access to those same +// keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the user. +unsafe impl<'a, K: Send, V: Send> Send for Cursor<'a, K, V> {} + +// SAFETY: The [`Cursor`] gives out immutable references to K and mutable references to V, +// so it has the same thread safety requirements as mutable references. +unsafe impl<'a, K: Sync, V: Sync> Sync for Cursor<'a, K, V> {} + +impl<'a, K, V> Cursor<'a, K, V> { + /// The current node + pub fn current(&self) -> (&K, &V) { + // SAFETY: + // - `self.current` is a valid node by the type invariants. + // - We have an immutable reference by the function signature. + unsafe { Self::to_key_value(self.current) } + } + + /// The current node, with a mutable value + pub fn current_mut(&mut self) -> (&K, &mut V) { + // SAFETY: + // - `self.current` is a valid node by the type invariants. + // - We have an mutable reference by the function signature. + unsafe { Self::to_key_value_mut(self.current) } + } + + /// Remove the current node from the tree. + /// + /// Returns a tuple where the first element is a cursor to the next node, if it exists, + /// else the previous node, else [`None`] (if the tree becomes empty). The second element + /// is the removed node. + pub fn remove_current(self) -> (Option<Self>, RBTreeNode<K, V>) { + let prev = self.get_neighbor_raw(Direction::Prev); + let next = self.get_neighbor_raw(Direction::Next); + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(self.current.as_ptr(), Node<K, V>, links) }.cast_mut(); + // SAFETY: `this` is valid by the type invariants as described above. + let node = unsafe { Box::from_raw(this) }; + let node = RBTreeNode { node }; + // SAFETY: The reference to the tree used to create the cursor outlives the cursor, so + // the tree cannot change. By the tree invariant, all nodes are valid. + unsafe { bindings::rb_erase(&mut (*this).links, addr_of_mut!(self.tree.root)) }; + + let current = match (prev, next) { + (_, Some(next)) => next, + (Some(prev), None) => prev, + (None, None) => { + return (None, node); + } + }; + + ( + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self.tree`. + Some(Self { + current, + tree: self.tree, + }), + node, + ) + } + + /// Remove the previous node, returning it if it exists. + pub fn remove_prev(&mut self) -> Option<RBTreeNode<K, V>> { + self.remove_neighbor(Direction::Prev) + } + + /// Remove the next node, returning it if it exists. + pub fn remove_next(&mut self) -> Option<RBTreeNode<K, V>> { + self.remove_neighbor(Direction::Next) + } + + fn remove_neighbor(&mut self, direction: Direction) -> Option<RBTreeNode<K, V>> { + if let Some(neighbor) = self.get_neighbor_raw(direction) { + let neighbor = neighbor.as_ptr(); + // SAFETY: The reference to the tree used to create the cursor outlives the cursor, so + // the tree cannot change. By the tree invariant, all nodes are valid. + unsafe { bindings::rb_erase(neighbor, addr_of_mut!(self.tree.root)) }; + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(neighbor, Node<K, V>, links) }.cast_mut(); + // SAFETY: `this` is valid by the type invariants as described above. + let node = unsafe { Box::from_raw(this) }; + return Some(RBTreeNode { node }); + } + None + } + + /// Move the cursor to the previous node, returning [`None`] if it doesn't exist. + pub fn move_prev(self) -> Option<Self> { + self.mv(Direction::Prev) + } + + /// Move the cursor to the next node, returning [`None`] if it doesn't exist. + pub fn move_next(self) -> Option<Self> { + self.mv(Direction::Next) + } + + fn mv(self, direction: Direction) -> Option<Self> { + // INVARIANT: + // - `neighbor` is a valid node in the [`RBTree`] pointed to by `self.tree`. + self.get_neighbor_raw(direction).map(|neighbor| Self { + tree: self.tree, + current: neighbor, + }) + } + + /// Access the previous node without moving the cursor. + pub fn peek_prev(&self) -> Option<(&K, &V)> { + self.peek(Direction::Prev) + } + + /// Access the previous node without moving the cursor. + pub fn peek_next(&self) -> Option<(&K, &V)> { + self.peek(Direction::Next) + } + + fn peek(&self, direction: Direction) -> Option<(&K, &V)> { + self.get_neighbor_raw(direction).map(|neighbor| { + // SAFETY: + // - `neighbor` is a valid tree node. + // - By the function signature, we have an immutable reference to `self`. + unsafe { Self::to_key_value(neighbor) } + }) + } + + /// Access the previous node mutably without moving the cursor. + pub fn peek_prev_mut(&mut self) -> Option<(&K, &mut V)> { + self.peek_mut(Direction::Prev) + } + + /// Access the next node mutably without moving the cursor. + pub fn peek_next_mut(&mut self) -> Option<(&K, &mut V)> { + self.peek_mut(Direction::Next) + } + + fn peek_mut(&mut self, direction: Direction) -> Option<(&K, &mut V)> { + self.get_neighbor_raw(direction).map(|neighbor| { + // SAFETY: + // - `neighbor` is a valid tree node. + // - By the function signature, we have a mutable reference to `self`. + unsafe { Self::to_key_value_mut(neighbor) } + }) + } + + fn get_neighbor_raw(&self, direction: Direction) -> Option<NonNull<bindings::rb_node>> { + // SAFETY: `self.current` is valid by the type invariants. + let neighbor = unsafe { + match direction { + Direction::Prev => bindings::rb_prev(self.current.as_ptr()), + Direction::Next => bindings::rb_next(self.current.as_ptr()), + } + }; + + NonNull::new(neighbor) + } + + /// SAFETY: + /// - `node` must be a valid pointer to a node in an [`RBTree`]. + /// - The caller has immutable access to `node` for the duration of 'b. + unsafe fn to_key_value<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b V) { + // SAFETY: the caller guarantees that `node` is a valid pointer in an `RBTree`. + let (k, v) = unsafe { Self::to_key_value_raw(node) }; + // SAFETY: the caller guarantees immutable access to `node`. + (k, unsafe { &*v }) + } + + /// SAFETY: + /// - `node` must be a valid pointer to a node in an [`RBTree`]. + /// - The caller has mutable access to `node` for the duration of 'b. + unsafe fn to_key_value_mut<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b mut V) { + // SAFETY: the caller guarantees that `node` is a valid pointer in an `RBTree`. + let (k, v) = unsafe { Self::to_key_value_raw(node) }; + // SAFETY: the caller guarantees mutable access to `node`. + (k, unsafe { &mut *v }) + } + + /// SAFETY: + /// - `node` must be a valid pointer to a node in an [`RBTree`]. + /// - The caller has immutable access to the key for the duration of 'b. + unsafe fn to_key_value_raw<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, *mut V) { + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(node.as_ptr(), Node<K, V>, links) }.cast_mut(); + // SAFETY: The passed `node` is the current node or a non-null neighbor, + // thus `this` is valid by the type invariants. + let k = unsafe { &(*this).key }; + // SAFETY: The passed `node` is the current node or a non-null neighbor, + // thus `this` is valid by the type invariants. + let v = unsafe { addr_of_mut!((*this).value) }; + (k, v) + } +} + +/// Direction for [`Cursor`] operations. +enum Direction { + /// the node immediately before, in sort order + Prev, + /// the node immediately after, in sort order + Next, +} + +impl<'a, K, V> IntoIterator for &'a RBTree<K, V> { + type Item = (&'a K, &'a V); + type IntoIter = Iter<'a, K, V>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +/// An iterator over the nodes of a [`RBTree`]. +/// +/// Instances are created by calling [`RBTree::iter`]. +pub struct Iter<'a, K, V> { + _tree: PhantomData<&'a RBTree<K, V>>, + iter_raw: IterRaw<K, V>, +} + +// SAFETY: The [`Iter`] gives out immutable references to K and V, so it has the same +// thread safety requirements as immutable references. +unsafe impl<'a, K: Sync, V: Sync> Send for Iter<'a, K, V> {} + +// SAFETY: The [`Iter`] gives out immutable references to K and V, so it has the same +// thread safety requirements as immutable references. +unsafe impl<'a, K: Sync, V: Sync> Sync for Iter<'a, K, V> {} + +impl<'a, K, V> Iterator for Iter<'a, K, V> { + type Item = (&'a K, &'a V); + + fn next(&mut self) -> Option<Self::Item> { + // SAFETY: Due to `self._tree`, `k` and `v` are valid for the lifetime of `'a`. + self.iter_raw.next().map(|(k, v)| unsafe { (&*k, &*v) }) + } +} + +impl<'a, K, V> IntoIterator for &'a mut RBTree<K, V> { + type Item = (&'a K, &'a mut V); + type IntoIter = IterMut<'a, K, V>; + + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } +} + +/// A mutable iterator over the nodes of a [`RBTree`]. +/// +/// Instances are created by calling [`RBTree::iter_mut`]. +pub struct IterMut<'a, K, V> { + _tree: PhantomData<&'a mut RBTree<K, V>>, + iter_raw: IterRaw<K, V>, +} + +// SAFETY: The [`IterMut`] has exclusive access to both `K` and `V`, so it is sufficient to require them to be `Send`. +// The iterator only gives out immutable references to the keys, but since the iterator has excusive access to those same +// keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the user. +unsafe impl<'a, K: Send, V: Send> Send for IterMut<'a, K, V> {} + +// SAFETY: The [`IterMut`] gives out immutable references to K and mutable references to V, so it has the same +// thread safety requirements as mutable references. +unsafe impl<'a, K: Sync, V: Sync> Sync for IterMut<'a, K, V> {} + +impl<'a, K, V> Iterator for IterMut<'a, K, V> { + type Item = (&'a K, &'a mut V); + + fn next(&mut self) -> Option<Self::Item> { + self.iter_raw.next().map(|(k, v)| + // SAFETY: Due to `&mut self`, we have exclusive access to `k` and `v`, for the lifetime of `'a`. + unsafe { (&*k, &mut *v) }) + } +} + +/// A raw iterator over the nodes of a [`RBTree`]. +/// +/// # Invariants +/// - `self.next` is a valid pointer. +/// - `self.next` points to a node stored inside of a valid `RBTree`. +struct IterRaw<K, V> { + next: *mut bindings::rb_node, + _phantom: PhantomData<fn() -> (K, V)>, +} + +impl<K, V> Iterator for IterRaw<K, V> { + type Item = (*mut K, *mut V); + + fn next(&mut self) -> Option<Self::Item> { + if self.next.is_null() { + return None; + } + + // SAFETY: By the type invariant of `IterRaw`, `self.next` is a valid node in an `RBTree`, + // and by the type invariant of `RBTree`, all nodes point to the links field of `Node<K, V>` objects. + let cur = unsafe { container_of!(self.next, Node<K, V>, links) }.cast_mut(); + + // SAFETY: `self.next` is a valid tree node by the type invariants. + self.next = unsafe { bindings::rb_next(self.next) }; + + // SAFETY: By the same reasoning above, it is safe to dereference the node. + Some(unsafe { (addr_of_mut!((*cur).key), addr_of_mut!((*cur).value)) }) + } +} + +/// A memory reservation for a red-black tree node. +/// +/// +/// It contains the memory needed to hold a node that can be inserted into a red-black tree. One +/// can be obtained by directly allocating it ([`RBTreeNodeReservation::new`]). +pub struct RBTreeNodeReservation<K, V> { + node: Box<MaybeUninit<Node<K, V>>>, +} + +impl<K, V> RBTreeNodeReservation<K, V> { + /// Allocates memory for a node to be eventually initialised and inserted into the tree via a + /// call to [`RBTree::insert`]. + pub fn new(flags: Flags) -> Result<RBTreeNodeReservation<K, V>> { + Ok(RBTreeNodeReservation { + node: <Box<_> as BoxExt<_>>::new_uninit(flags)?, + }) + } +} + +// SAFETY: This doesn't actually contain K or V, and is just a memory allocation. Those can always +// be moved across threads. +unsafe impl<K, V> Send for RBTreeNodeReservation<K, V> {} + +// SAFETY: This doesn't actually contain K or V, and is just a memory allocation. +unsafe impl<K, V> Sync for RBTreeNodeReservation<K, V> {} + +impl<K, V> RBTreeNodeReservation<K, V> { + /// Initialises a node reservation. + /// + /// It then becomes an [`RBTreeNode`] that can be inserted into a tree. + pub fn into_node(mut self, key: K, value: V) -> RBTreeNode<K, V> { + self.node.write(Node { + key, + value, + links: bindings::rb_node::default(), + }); + // SAFETY: We just wrote to it. + let node = unsafe { self.node.assume_init() }; + RBTreeNode { node } + } +} + +/// A red-black tree node. +/// +/// The node is fully initialised (with key and value) and can be inserted into a tree without any +/// extra allocations or failure paths. +pub struct RBTreeNode<K, V> { + node: Box<Node<K, V>>, +} + +impl<K, V> RBTreeNode<K, V> { + /// Allocates and initialises a node that can be inserted into the tree via + /// [`RBTree::insert`]. + pub fn new(key: K, value: V, flags: Flags) -> Result<RBTreeNode<K, V>> { + Ok(RBTreeNodeReservation::new(flags)?.into_node(key, value)) + } + + /// Get the key and value from inside the node. + pub fn to_key_value(self) -> (K, V) { + (self.node.key, self.node.value) + } +} + +// SAFETY: If K and V can be sent across threads, then it's also okay to send [`RBTreeNode`] across +// threads. +unsafe impl<K: Send, V: Send> Send for RBTreeNode<K, V> {} + +// SAFETY: If K and V can be accessed without synchronization, then it's also okay to access +// [`RBTreeNode`] without synchronization. +unsafe impl<K: Sync, V: Sync> Sync for RBTreeNode<K, V> {} + +impl<K, V> RBTreeNode<K, V> { + /// Drop the key and value, but keep the allocation. + /// + /// It then becomes a reservation that can be re-initialised into a different node (i.e., with + /// a different key and/or value). + /// + /// The existing key and value are dropped in-place as part of this operation, that is, memory + /// may be freed (but only for the key/value; memory for the node itself is kept for reuse). + pub fn into_reservation(self) -> RBTreeNodeReservation<K, V> { + RBTreeNodeReservation { + node: Box::drop_contents(self.node), + } + } +} + +/// A view into a single entry in a map, which may either be vacant or occupied. +/// +/// This enum is constructed from the [`RBTree::entry`]. +/// +/// [`entry`]: fn@RBTree::entry +pub enum Entry<'a, K, V> { + /// This [`RBTree`] does not have a node with this key. + Vacant(VacantEntry<'a, K, V>), + /// This [`RBTree`] already has a node with this key. + Occupied(OccupiedEntry<'a, K, V>), +} + +/// Like [`Entry`], except that it doesn't have ownership of the key. +enum RawEntry<'a, K, V> { + Vacant(RawVacantEntry<'a, K, V>), + Occupied(OccupiedEntry<'a, K, V>), +} + +/// A view into a vacant entry in a [`RBTree`]. It is part of the [`Entry`] enum. +pub struct VacantEntry<'a, K, V> { + key: K, + raw: RawVacantEntry<'a, K, V>, +} + +/// Like [`VacantEntry`], but doesn't hold on to the key. +/// +/// # Invariants +/// - `parent` may be null if the new node becomes the root. +/// - `child_field_of_parent` is a valid pointer to the left-child or right-child of `parent`. If `parent` is +/// null, it is a pointer to the root of the [`RBTree`]. +struct RawVacantEntry<'a, K, V> { + rbtree: *mut RBTree<K, V>, + /// The node that will become the parent of the new node if we insert one. + parent: *mut bindings::rb_node, + /// This points to the left-child or right-child field of `parent`, or `root` if `parent` is + /// null. + child_field_of_parent: *mut *mut bindings::rb_node, + _phantom: PhantomData<&'a mut RBTree<K, V>>, +} + +impl<'a, K, V> RawVacantEntry<'a, K, V> { + /// Inserts the given node into the [`RBTree`] at this entry. + /// + /// The `node` must have a key such that inserting it here does not break the ordering of this + /// [`RBTree`]. + fn insert(self, node: RBTreeNode<K, V>) -> &'a mut V { + let node = Box::into_raw(node.node); + + // SAFETY: `node` is valid at least until we call `Box::from_raw`, which only happens when + // the node is removed or replaced. + let node_links = unsafe { addr_of_mut!((*node).links) }; + + // INVARIANT: We are linking in a new node, which is valid. It remains valid because we + // "forgot" it with `Box::into_raw`. + // SAFETY: The type invariants of `RawVacantEntry` are exactly the safety requirements of `rb_link_node`. + unsafe { bindings::rb_link_node(node_links, self.parent, self.child_field_of_parent) }; + + // SAFETY: All pointers are valid. `node` has just been inserted into the tree. + unsafe { bindings::rb_insert_color(node_links, addr_of_mut!((*self.rbtree).root)) }; + + // SAFETY: The node is valid until we remove it from the tree. + unsafe { &mut (*node).value } + } +} + +impl<'a, K, V> VacantEntry<'a, K, V> { + /// Inserts the given node into the [`RBTree`] at this entry. + pub fn insert(self, value: V, reservation: RBTreeNodeReservation<K, V>) -> &'a mut V { + self.raw.insert(reservation.into_node(self.key, value)) + } +} + +/// A view into an occupied entry in a [`RBTree`]. It is part of the [`Entry`] enum. +/// +/// # Invariants +/// - `node_links` is a valid, non-null pointer to a tree node in `self.rbtree` +pub struct OccupiedEntry<'a, K, V> { + rbtree: &'a mut RBTree<K, V>, + /// The node that this entry corresponds to. + node_links: *mut bindings::rb_node, +} + +impl<'a, K, V> OccupiedEntry<'a, K, V> { + /// Gets a reference to the value in the entry. + pub fn get(&self) -> &V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - We have shared access to the underlying tree, and can thus give out a shared reference. + unsafe { &(*container_of!(self.node_links, Node<K, V>, links)).value } + } + + /// Gets a mutable reference to the value in the entry. + pub fn get_mut(&mut self) -> &mut V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - We have exclusive access to the underlying tree, and can thus give out a mutable reference. + unsafe { &mut (*(container_of!(self.node_links, Node<K, V>, links).cast_mut())).value } + } + + /// Converts the entry into a mutable reference to its value. + /// + /// If you need multiple references to the `OccupiedEntry`, see [`self#get_mut`]. + pub fn into_mut(self) -> &'a mut V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - This consumes the `&'a mut RBTree<K, V>`, therefore it can give out a mutable reference that lives for `'a`. + unsafe { &mut (*(container_of!(self.node_links, Node<K, V>, links).cast_mut())).value } + } + + /// Remove this entry from the [`RBTree`]. + pub fn remove_node(self) -> RBTreeNode<K, V> { + // SAFETY: The node is a node in the tree, so it is valid. + unsafe { bindings::rb_erase(self.node_links, &mut self.rbtree.root) }; + + // INVARIANT: The node is being returned and the caller may free it, however, it was + // removed from the tree. So the invariants still hold. + RBTreeNode { + // SAFETY: The node was a node in the tree, but we removed it, so we can convert it + // back into a box. + node: unsafe { + Box::from_raw(container_of!(self.node_links, Node<K, V>, links).cast_mut()) + }, + } + } + + /// Takes the value of the entry out of the map, and returns it. + pub fn remove(self) -> V { + self.remove_node().node.value + } + + /// Swap the current node for the provided node. + /// + /// The key of both nodes must be equal. + fn replace(self, node: RBTreeNode<K, V>) -> RBTreeNode<K, V> { + let node = Box::into_raw(node.node); + + // SAFETY: `node` is valid at least until we call `Box::from_raw`, which only happens when + // the node is removed or replaced. + let new_node_links = unsafe { addr_of_mut!((*node).links) }; + + // SAFETY: This updates the pointers so that `new_node_links` is in the tree where + // `self.node_links` used to be. + unsafe { + bindings::rb_replace_node(self.node_links, new_node_links, &mut self.rbtree.root) + }; + + // SAFETY: + // - `self.node_ptr` produces a valid pointer to a node in the tree. + // - Now that we removed this entry from the tree, we can convert the node to a box. + let old_node = + unsafe { Box::from_raw(container_of!(self.node_links, Node<K, V>, links).cast_mut()) }; + + RBTreeNode { node: old_node } + } +} + +struct Node<K, V> { + links: bindings::rb_node, + key: K, + value: V, +} diff --git a/rust/kernel/sizes.rs b/rust/kernel/sizes.rs new file mode 100644 index 000000000000..834c343e4170 --- /dev/null +++ b/rust/kernel/sizes.rs @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Commonly used sizes. +//! +//! C headers: [`include/linux/sizes.h`](srctree/include/linux/sizes.h). + +/// 0x00000400 +pub const SZ_1K: usize = bindings::SZ_1K as usize; +/// 0x00000800 +pub const SZ_2K: usize = bindings::SZ_2K as usize; +/// 0x00001000 +pub const SZ_4K: usize = bindings::SZ_4K as usize; +/// 0x00002000 +pub const SZ_8K: usize = bindings::SZ_8K as usize; +/// 0x00004000 +pub const SZ_16K: usize = bindings::SZ_16K as usize; +/// 0x00008000 +pub const SZ_32K: usize = bindings::SZ_32K as usize; +/// 0x00010000 +pub const SZ_64K: usize = bindings::SZ_64K as usize; +/// 0x00020000 +pub const SZ_128K: usize = bindings::SZ_128K as usize; +/// 0x00040000 +pub const SZ_256K: usize = bindings::SZ_256K as usize; +/// 0x00080000 +pub const SZ_512K: usize = bindings::SZ_512K as usize; diff --git a/rust/kernel/std_vendor.rs b/rust/kernel/std_vendor.rs index 39679a960c1a..67bf9d37ddb5 100644 --- a/rust/kernel/std_vendor.rs +++ b/rust/kernel/std_vendor.rs @@ -136,7 +136,7 @@ /// /// [`std::dbg`]: https://doc.rust-lang.org/std/macro.dbg.html /// [`eprintln`]: https://doc.rust-lang.org/std/macro.eprintln.html -/// [`printk`]: https://www.kernel.org/doc/html/latest/core-api/printk-basics.html +/// [`printk`]: https://docs.kernel.org/core-api/printk-basics.html /// [`pr_info`]: crate::pr_info! /// [`pr_debug`]: crate::pr_debug! #[macro_export] diff --git a/rust/kernel/sync/arc.rs b/rust/kernel/sync/arc.rs index 3673496c2363..3021f30fd822 100644 --- a/rust/kernel/sync/arc.rs +++ b/rust/kernel/sync/arc.rs @@ -12,12 +12,13 @@ //! 2. It does not support weak references, which allows it to be half the size. //! 3. It saturates the reference count instead of aborting when it goes over a threshold. //! 4. It does not provide a `get_mut` method, so the ref counted object is pinned. +//! 5. The object in [`Arc`] is pinned implicitly. //! //! [`Arc`]: https://doc.rust-lang.org/std/sync/struct.Arc.html use crate::{ alloc::{box_ext::BoxExt, AllocError, Flags}, - error::{self, Error}, + bindings, init::{self, InPlaceInit, Init, PinInit}, try_init, types::{ForeignOwnable, Opaque}, @@ -209,28 +210,6 @@ impl<T> Arc<T> { // `Arc` object. Ok(unsafe { Self::from_inner(Box::leak(inner).into()) }) } - - /// Use the given initializer to in-place initialize a `T`. - /// - /// If `T: !Unpin` it will not be able to move afterwards. - #[inline] - pub fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> error::Result<Self> - where - Error: From<E>, - { - UniqueArc::pin_init(init, flags).map(|u| u.into()) - } - - /// Use the given initializer to in-place initialize a `T`. - /// - /// This is equivalent to [`Arc<T>::pin_init`], since an [`Arc`] is always pinned. - #[inline] - pub fn init<E>(init: impl Init<T, E>, flags: Flags) -> error::Result<Self> - where - Error: From<E>, - { - UniqueArc::init(init, flags).map(|u| u.into()) - } } impl<T: ?Sized> Arc<T> { diff --git a/rust/kernel/sync/locked_by.rs b/rust/kernel/sync/locked_by.rs index babc731bd5f6..ce2ee8d87865 100644 --- a/rust/kernel/sync/locked_by.rs +++ b/rust/kernel/sync/locked_by.rs @@ -83,8 +83,12 @@ pub struct LockedBy<T: ?Sized, U: ?Sized> { // SAFETY: `LockedBy` can be transferred across thread boundaries iff the data it protects can. unsafe impl<T: ?Sized + Send, U: ?Sized> Send for LockedBy<T, U> {} -// SAFETY: `LockedBy` serialises the interior mutability it provides, so it is `Sync` as long as the -// data it protects is `Send`. +// SAFETY: If `T` is not `Sync`, then parallel shared access to this `LockedBy` allows you to use +// `access_mut` to hand out `&mut T` on one thread at the time. The requirement that `T: Send` is +// sufficient to allow that. +// +// If `T` is `Sync`, then the `access` method also becomes available, which allows you to obtain +// several `&T` from several threads at once. However, this is okay as `T` is `Sync`. unsafe impl<T: ?Sized + Send, U: ?Sized> Sync for LockedBy<T, U> {} impl<T, U> LockedBy<T, U> { @@ -118,7 +122,10 @@ impl<T: ?Sized, U> LockedBy<T, U> { /// /// Panics if `owner` is different from the data protected by the lock used in /// [`new`](LockedBy::new). - pub fn access<'a>(&'a self, owner: &'a U) -> &'a T { + pub fn access<'a>(&'a self, owner: &'a U) -> &'a T + where + T: Sync, + { build_assert!( size_of::<U>() > 0, "`U` cannot be a ZST because `owner` wouldn't be unique" @@ -127,7 +134,10 @@ impl<T: ?Sized, U> LockedBy<T, U> { panic!("mismatched owners"); } - // SAFETY: `owner` is evidence that the owner is locked. + // SAFETY: `owner` is evidence that there are only shared references to the owner for the + // duration of 'a, so it's not possible to use `Self::access_mut` to obtain a mutable + // reference to the inner value that aliases with this shared reference. The type is `Sync` + // so there are no other requirements. unsafe { &*self.data.get() } } diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs index 2e7c9008621f..9e7ca066355c 100644 --- a/rust/kernel/types.rs +++ b/rust/kernel/types.rs @@ -7,8 +7,9 @@ use alloc::boxed::Box; use core::{ cell::UnsafeCell, marker::{PhantomData, PhantomPinned}, - mem::MaybeUninit, + mem::{ManuallyDrop, MaybeUninit}, ops::{Deref, DerefMut}, + pin::Pin, ptr::NonNull, }; @@ -26,7 +27,10 @@ pub trait ForeignOwnable: Sized { /// Converts a Rust-owned object to a foreign-owned one. /// - /// The foreign representation is a pointer to void. + /// The foreign representation is a pointer to void. There are no guarantees for this pointer. + /// For example, it might be invalid, dangling or pointing to uninitialized memory. Using it in + /// any way except for [`ForeignOwnable::from_foreign`], [`ForeignOwnable::borrow`], + /// [`ForeignOwnable::try_from_foreign`] can result in undefined behavior. fn into_foreign(self) -> *const core::ffi::c_void; /// Borrows a foreign-owned object. @@ -89,6 +93,32 @@ impl<T: 'static> ForeignOwnable for Box<T> { } } +impl<T: 'static> ForeignOwnable for Pin<Box<T>> { + type Borrowed<'a> = Pin<&'a T>; + + fn into_foreign(self) -> *const core::ffi::c_void { + // SAFETY: We are still treating the box as pinned. + Box::into_raw(unsafe { Pin::into_inner_unchecked(self) }) as _ + } + + unsafe fn borrow<'a>(ptr: *const core::ffi::c_void) -> Pin<&'a T> { + // SAFETY: The safety requirements for this function ensure that the object is still alive, + // so it is safe to dereference the raw pointer. + // The safety requirements of `from_foreign` also ensure that the object remains alive for + // the lifetime of the returned value. + let r = unsafe { &*ptr.cast() }; + + // SAFETY: This pointer originates from a `Pin<Box<T>>`. + unsafe { Pin::new_unchecked(r) } + } + + unsafe fn from_foreign(ptr: *const core::ffi::c_void) -> Self { + // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous + // call to `Self::into_foreign`. + unsafe { Pin::new_unchecked(Box::from_raw(ptr as _)) } + } +} + impl ForeignOwnable for () { type Borrowed<'a> = (); @@ -366,6 +396,35 @@ impl<T: AlwaysRefCounted> ARef<T> { _p: PhantomData, } } + + /// Consumes the `ARef`, returning a raw pointer. + /// + /// This function does not change the refcount. After calling this function, the caller is + /// responsible for the refcount previously managed by the `ARef`. + /// + /// # Examples + /// + /// ``` + /// use core::ptr::NonNull; + /// use kernel::types::{ARef, AlwaysRefCounted}; + /// + /// struct Empty {} + /// + /// unsafe impl AlwaysRefCounted for Empty { + /// fn inc_ref(&self) {} + /// unsafe fn dec_ref(_obj: NonNull<Self>) {} + /// } + /// + /// let mut data = Empty {}; + /// let ptr = NonNull::<Empty>::new(&mut data as *mut _).unwrap(); + /// let data_ref: ARef<Empty> = unsafe { ARef::from_raw(ptr) }; + /// let raw_ptr: NonNull<Empty> = ARef::into_raw(data_ref); + /// + /// assert_eq!(ptr, raw_ptr); + /// ``` + pub fn into_raw(me: Self) -> NonNull<T> { + ManuallyDrop::new(me).ptr + } } impl<T: AlwaysRefCounted> Clone for ARef<T> { @@ -409,3 +468,67 @@ pub enum Either<L, R> { /// Constructs an instance of [`Either`] containing a value of type `R`. Right(R), } + +/// Types for which any bit pattern is valid. +/// +/// Not all types are valid for all values. For example, a `bool` must be either zero or one, so +/// reading arbitrary bytes into something that contains a `bool` is not okay. +/// +/// It's okay for the type to have padding, as initializing those bytes has no effect. +/// +/// # Safety +/// +/// All bit-patterns must be valid for this type. This type must not have interior mutability. +pub unsafe trait FromBytes {} + +// SAFETY: All bit patterns are acceptable values of the types below. +unsafe impl FromBytes for u8 {} +unsafe impl FromBytes for u16 {} +unsafe impl FromBytes for u32 {} +unsafe impl FromBytes for u64 {} +unsafe impl FromBytes for usize {} +unsafe impl FromBytes for i8 {} +unsafe impl FromBytes for i16 {} +unsafe impl FromBytes for i32 {} +unsafe impl FromBytes for i64 {} +unsafe impl FromBytes for isize {} +// SAFETY: If all bit patterns are acceptable for individual values in an array, then all bit +// patterns are also acceptable for arrays of that type. +unsafe impl<T: FromBytes> FromBytes for [T] {} +unsafe impl<T: FromBytes, const N: usize> FromBytes for [T; N] {} + +/// Types that can be viewed as an immutable slice of initialized bytes. +/// +/// If a struct implements this trait, then it is okay to copy it byte-for-byte to userspace. This +/// means that it should not have any padding, as padding bytes are uninitialized. Reading +/// uninitialized memory is not just undefined behavior, it may even lead to leaking sensitive +/// information on the stack to userspace. +/// +/// The struct should also not hold kernel pointers, as kernel pointer addresses are also considered +/// sensitive. However, leaking kernel pointers is not considered undefined behavior by Rust, so +/// this is a correctness requirement, but not a safety requirement. +/// +/// # Safety +/// +/// Values of this type may not contain any uninitialized bytes. This type must not have interior +/// mutability. +pub unsafe trait AsBytes {} + +// SAFETY: Instances of the following types have no uninitialized portions. +unsafe impl AsBytes for u8 {} +unsafe impl AsBytes for u16 {} +unsafe impl AsBytes for u32 {} +unsafe impl AsBytes for u64 {} +unsafe impl AsBytes for usize {} +unsafe impl AsBytes for i8 {} +unsafe impl AsBytes for i16 {} +unsafe impl AsBytes for i32 {} +unsafe impl AsBytes for i64 {} +unsafe impl AsBytes for isize {} +unsafe impl AsBytes for bool {} +unsafe impl AsBytes for char {} +unsafe impl AsBytes for str {} +// SAFETY: If individual values in an array have no uninitialized portions, then the array itself +// does not have any uninitialized portions either. +unsafe impl<T: AsBytes> AsBytes for [T] {} +unsafe impl<T: AsBytes, const N: usize> AsBytes for [T; N] {} diff --git a/rust/kernel/uaccess.rs b/rust/kernel/uaccess.rs new file mode 100644 index 000000000000..e9347cff99ab --- /dev/null +++ b/rust/kernel/uaccess.rs @@ -0,0 +1,388 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Slices to user space memory regions. +//! +//! C header: [`include/linux/uaccess.h`](srctree/include/linux/uaccess.h) + +use crate::{ + alloc::Flags, + bindings, + error::Result, + prelude::*, + types::{AsBytes, FromBytes}, +}; +use alloc::vec::Vec; +use core::ffi::{c_ulong, c_void}; +use core::mem::{size_of, MaybeUninit}; + +/// The type used for userspace addresses. +pub type UserPtr = usize; + +/// A pointer to an area in userspace memory, which can be either read-only or read-write. +/// +/// All methods on this struct are safe: attempting to read or write on bad addresses (either out of +/// the bound of the slice or unmapped addresses) will return [`EFAULT`]. Concurrent access, +/// *including data races to/from userspace memory*, is permitted, because fundamentally another +/// userspace thread/process could always be modifying memory at the same time (in the same way that +/// userspace Rust's [`std::io`] permits data races with the contents of files on disk). In the +/// presence of a race, the exact byte values read/written are unspecified but the operation is +/// well-defined. Kernelspace code should validate its copy of data after completing a read, and not +/// expect that multiple reads of the same address will return the same value. +/// +/// These APIs are designed to make it difficult to accidentally write TOCTOU (time-of-check to +/// time-of-use) bugs. Every time a memory location is read, the reader's position is advanced by +/// the read length and the next read will start from there. This helps prevent accidentally reading +/// the same location twice and causing a TOCTOU bug. +/// +/// Creating a [`UserSliceReader`] and/or [`UserSliceWriter`] consumes the `UserSlice`, helping +/// ensure that there aren't multiple readers or writers to the same location. +/// +/// If double-fetching a memory location is necessary for some reason, then that is done by creating +/// multiple readers to the same memory location, e.g. using [`clone_reader`]. +/// +/// # Examples +/// +/// Takes a region of userspace memory from the current process, and modify it by adding one to +/// every byte in the region. +/// +/// ```no_run +/// use alloc::vec::Vec; +/// use core::ffi::c_void; +/// use kernel::error::Result; +/// use kernel::uaccess::{UserPtr, UserSlice}; +/// +/// fn bytes_add_one(uptr: UserPtr, len: usize) -> Result<()> { +/// let (read, mut write) = UserSlice::new(uptr, len).reader_writer(); +/// +/// let mut buf = Vec::new(); +/// read.read_all(&mut buf, GFP_KERNEL)?; +/// +/// for b in &mut buf { +/// *b = b.wrapping_add(1); +/// } +/// +/// write.write_slice(&buf)?; +/// Ok(()) +/// } +/// ``` +/// +/// Example illustrating a TOCTOU (time-of-check to time-of-use) bug. +/// +/// ```no_run +/// use alloc::vec::Vec; +/// use core::ffi::c_void; +/// use kernel::error::{code::EINVAL, Result}; +/// use kernel::uaccess::{UserPtr, UserSlice}; +/// +/// /// Returns whether the data in this region is valid. +/// fn is_valid(uptr: UserPtr, len: usize) -> Result<bool> { +/// let read = UserSlice::new(uptr, len).reader(); +/// +/// let mut buf = Vec::new(); +/// read.read_all(&mut buf, GFP_KERNEL)?; +/// +/// todo!() +/// } +/// +/// /// Returns the bytes behind this user pointer if they are valid. +/// fn get_bytes_if_valid(uptr: UserPtr, len: usize) -> Result<Vec<u8>> { +/// if !is_valid(uptr, len)? { +/// return Err(EINVAL); +/// } +/// +/// let read = UserSlice::new(uptr, len).reader(); +/// +/// let mut buf = Vec::new(); +/// read.read_all(&mut buf, GFP_KERNEL)?; +/// +/// // THIS IS A BUG! The bytes could have changed since we checked them. +/// // +/// // To avoid this kind of bug, don't call `UserSlice::new` multiple +/// // times with the same address. +/// Ok(buf) +/// } +/// ``` +/// +/// [`std::io`]: https://doc.rust-lang.org/std/io/index.html +/// [`clone_reader`]: UserSliceReader::clone_reader +pub struct UserSlice { + ptr: UserPtr, + length: usize, +} + +impl UserSlice { + /// Constructs a user slice from a raw pointer and a length in bytes. + /// + /// Constructing a [`UserSlice`] performs no checks on the provided address and length, it can + /// safely be constructed inside a kernel thread with no current userspace process. Reads and + /// writes wrap the kernel APIs `copy_from_user` and `copy_to_user`, which check the memory map + /// of the current process and enforce that the address range is within the user range (no + /// additional calls to `access_ok` are needed). Validity of the pointer is checked when you + /// attempt to read or write, not in the call to `UserSlice::new`. + /// + /// Callers must be careful to avoid time-of-check-time-of-use (TOCTOU) issues. The simplest way + /// is to create a single instance of [`UserSlice`] per user memory block as it reads each byte + /// at most once. + pub fn new(ptr: UserPtr, length: usize) -> Self { + UserSlice { ptr, length } + } + + /// Reads the entirety of the user slice, appending it to the end of the provided buffer. + /// + /// Fails with [`EFAULT`] if the read happens on a bad address. + pub fn read_all(self, buf: &mut Vec<u8>, flags: Flags) -> Result { + self.reader().read_all(buf, flags) + } + + /// Constructs a [`UserSliceReader`]. + pub fn reader(self) -> UserSliceReader { + UserSliceReader { + ptr: self.ptr, + length: self.length, + } + } + + /// Constructs a [`UserSliceWriter`]. + pub fn writer(self) -> UserSliceWriter { + UserSliceWriter { + ptr: self.ptr, + length: self.length, + } + } + + /// Constructs both a [`UserSliceReader`] and a [`UserSliceWriter`]. + /// + /// Usually when this is used, you will first read the data, and then overwrite it afterwards. + pub fn reader_writer(self) -> (UserSliceReader, UserSliceWriter) { + ( + UserSliceReader { + ptr: self.ptr, + length: self.length, + }, + UserSliceWriter { + ptr: self.ptr, + length: self.length, + }, + ) + } +} + +/// A reader for [`UserSlice`]. +/// +/// Used to incrementally read from the user slice. +pub struct UserSliceReader { + ptr: UserPtr, + length: usize, +} + +impl UserSliceReader { + /// Skip the provided number of bytes. + /// + /// Returns an error if skipping more than the length of the buffer. + pub fn skip(&mut self, num_skip: usize) -> Result { + // Update `self.length` first since that's the fallible part of this operation. + self.length = self.length.checked_sub(num_skip).ok_or(EFAULT)?; + self.ptr = self.ptr.wrapping_add(num_skip); + Ok(()) + } + + /// Create a reader that can access the same range of data. + /// + /// Reading from the clone does not advance the current reader. + /// + /// The caller should take care to not introduce TOCTOU issues, as described in the + /// documentation for [`UserSlice`]. + pub fn clone_reader(&self) -> UserSliceReader { + UserSliceReader { + ptr: self.ptr, + length: self.length, + } + } + + /// Returns the number of bytes left to be read from this reader. + /// + /// Note that even reading less than this number of bytes may fail. + pub fn len(&self) -> usize { + self.length + } + + /// Returns `true` if no data is available in the io buffer. + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Reads raw data from the user slice into a kernel buffer. + /// + /// For a version that uses `&mut [u8]`, please see [`UserSliceReader::read_slice`]. + /// + /// Fails with [`EFAULT`] if the read happens on a bad address, or if the read goes out of + /// bounds of this [`UserSliceReader`]. This call may modify `out` even if it returns an error. + /// + /// # Guarantees + /// + /// After a successful call to this method, all bytes in `out` are initialized. + pub fn read_raw(&mut self, out: &mut [MaybeUninit<u8>]) -> Result { + let len = out.len(); + let out_ptr = out.as_mut_ptr().cast::<c_void>(); + if len > self.length { + return Err(EFAULT); + } + let Ok(len_ulong) = c_ulong::try_from(len) else { + return Err(EFAULT); + }; + // SAFETY: `out_ptr` points into a mutable slice of length `len_ulong`, so we may write + // that many bytes to it. + let res = + unsafe { bindings::copy_from_user(out_ptr, self.ptr as *const c_void, len_ulong) }; + if res != 0 { + return Err(EFAULT); + } + self.ptr = self.ptr.wrapping_add(len); + self.length -= len; + Ok(()) + } + + /// Reads raw data from the user slice into a kernel buffer. + /// + /// Fails with [`EFAULT`] if the read happens on a bad address, or if the read goes out of + /// bounds of this [`UserSliceReader`]. This call may modify `out` even if it returns an error. + pub fn read_slice(&mut self, out: &mut [u8]) -> Result { + // SAFETY: The types are compatible and `read_raw` doesn't write uninitialized bytes to + // `out`. + let out = unsafe { &mut *(out as *mut [u8] as *mut [MaybeUninit<u8>]) }; + self.read_raw(out) + } + + /// Reads a value of the specified type. + /// + /// Fails with [`EFAULT`] if the read happens on a bad address, or if the read goes out of + /// bounds of this [`UserSliceReader`]. + pub fn read<T: FromBytes>(&mut self) -> Result<T> { + let len = size_of::<T>(); + if len > self.length { + return Err(EFAULT); + } + let Ok(len_ulong) = c_ulong::try_from(len) else { + return Err(EFAULT); + }; + let mut out: MaybeUninit<T> = MaybeUninit::uninit(); + // SAFETY: The local variable `out` is valid for writing `size_of::<T>()` bytes. + // + // By using the _copy_from_user variant, we skip the check_object_size check that verifies + // the kernel pointer. This mirrors the logic on the C side that skips the check when the + // length is a compile-time constant. + let res = unsafe { + bindings::_copy_from_user( + out.as_mut_ptr().cast::<c_void>(), + self.ptr as *const c_void, + len_ulong, + ) + }; + if res != 0 { + return Err(EFAULT); + } + self.ptr = self.ptr.wrapping_add(len); + self.length -= len; + // SAFETY: The read above has initialized all bytes in `out`, and since `T` implements + // `FromBytes`, any bit-pattern is a valid value for this type. + Ok(unsafe { out.assume_init() }) + } + + /// Reads the entirety of the user slice, appending it to the end of the provided buffer. + /// + /// Fails with [`EFAULT`] if the read happens on a bad address. + pub fn read_all(mut self, buf: &mut Vec<u8>, flags: Flags) -> Result { + let len = self.length; + VecExt::<u8>::reserve(buf, len, flags)?; + + // The call to `try_reserve` was successful, so the spare capacity is at least `len` bytes + // long. + self.read_raw(&mut buf.spare_capacity_mut()[..len])?; + + // SAFETY: Since the call to `read_raw` was successful, so the next `len` bytes of the + // vector have been initialized. + unsafe { buf.set_len(buf.len() + len) }; + Ok(()) + } +} + +/// A writer for [`UserSlice`]. +/// +/// Used to incrementally write into the user slice. +pub struct UserSliceWriter { + ptr: UserPtr, + length: usize, +} + +impl UserSliceWriter { + /// Returns the amount of space remaining in this buffer. + /// + /// Note that even writing less than this number of bytes may fail. + pub fn len(&self) -> usize { + self.length + } + + /// Returns `true` if no more data can be written to this buffer. + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Writes raw data to this user pointer from a kernel buffer. + /// + /// Fails with [`EFAULT`] if the write happens on a bad address, or if the write goes out of + /// bounds of this [`UserSliceWriter`]. This call may modify the associated userspace slice even + /// if it returns an error. + pub fn write_slice(&mut self, data: &[u8]) -> Result { + let len = data.len(); + let data_ptr = data.as_ptr().cast::<c_void>(); + if len > self.length { + return Err(EFAULT); + } + let Ok(len_ulong) = c_ulong::try_from(len) else { + return Err(EFAULT); + }; + // SAFETY: `data_ptr` points into an immutable slice of length `len_ulong`, so we may read + // that many bytes from it. + let res = unsafe { bindings::copy_to_user(self.ptr as *mut c_void, data_ptr, len_ulong) }; + if res != 0 { + return Err(EFAULT); + } + self.ptr = self.ptr.wrapping_add(len); + self.length -= len; + Ok(()) + } + + /// Writes the provided Rust value to this userspace pointer. + /// + /// Fails with [`EFAULT`] if the write happens on a bad address, or if the write goes out of + /// bounds of this [`UserSliceWriter`]. This call may modify the associated userspace slice even + /// if it returns an error. + pub fn write<T: AsBytes>(&mut self, value: &T) -> Result { + let len = size_of::<T>(); + if len > self.length { + return Err(EFAULT); + } + let Ok(len_ulong) = c_ulong::try_from(len) else { + return Err(EFAULT); + }; + // SAFETY: The reference points to a value of type `T`, so it is valid for reading + // `size_of::<T>()` bytes. + // + // By using the _copy_to_user variant, we skip the check_object_size check that verifies the + // kernel pointer. This mirrors the logic on the C side that skips the check when the length + // is a compile-time constant. + let res = unsafe { + bindings::_copy_to_user( + self.ptr as *mut c_void, + (value as *const T).cast::<c_void>(), + len_ulong, + ) + }; + if res != 0 { + return Err(EFAULT); + } + self.ptr = self.ptr.wrapping_add(len); + self.length -= len; + Ok(()) + } +} diff --git a/rust/kernel/workqueue.rs b/rust/kernel/workqueue.rs index 1cec63a2aea8..553a5cba2adc 100644 --- a/rust/kernel/workqueue.rs +++ b/rust/kernel/workqueue.rs @@ -482,24 +482,26 @@ pub unsafe trait HasWork<T, const ID: u64 = 0> { /// use kernel::sync::Arc; /// use kernel::workqueue::{self, impl_has_work, Work}; /// -/// struct MyStruct { -/// work_field: Work<MyStruct, 17>, +/// struct MyStruct<'a, T, const N: usize> { +/// work_field: Work<MyStruct<'a, T, N>, 17>, +/// f: fn(&'a [T; N]), /// } /// /// impl_has_work! { -/// impl HasWork<MyStruct, 17> for MyStruct { self.work_field } +/// impl{'a, T, const N: usize} HasWork<MyStruct<'a, T, N>, 17> +/// for MyStruct<'a, T, N> { self.work_field } /// } /// ``` #[macro_export] macro_rules! impl_has_work { - ($(impl$(<$($implarg:ident),*>)? + ($(impl$({$($generics:tt)*})? HasWork<$work_type:ty $(, $id:tt)?> - for $self:ident $(<$($selfarg:ident),*>)? + for $self:ty { self.$field:ident } )*) => {$( // SAFETY: The implementation of `raw_get_work` only compiles if the field has the right // type. - unsafe impl$(<$($implarg),*>)? $crate::workqueue::HasWork<$work_type $(, $id)?> for $self $(<$($selfarg),*>)? { + unsafe impl$(<$($generics)+>)? $crate::workqueue::HasWork<$work_type $(, $id)?> for $self { const OFFSET: usize = ::core::mem::offset_of!(Self, $field) as usize; #[inline] @@ -515,7 +517,7 @@ macro_rules! impl_has_work { pub use impl_has_work; impl_has_work! { - impl<T> HasWork<Self> for ClosureWork<T> { self.work } + impl{T} HasWork<Self> for ClosureWork<T> { self.work } } unsafe impl<T, const ID: u64> WorkItemPointer<ID> for Arc<T> diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index 520eae5fd792..a626b1145e5c 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -2,6 +2,10 @@ //! Crate for all kernel procedural macros. +// When fixdep scans this, it will find this string `CONFIG_RUSTC_VERSION_TEXT` +// and thus add a dependency on `include/config/RUSTC_VERSION_TEXT`, which is +// touched by Kconfig when the version string from the compiler changes. + #[macro_use] mod quote; mod concat_idents; @@ -35,6 +39,7 @@ use proc_macro::TokenStream; /// author: "Rust for Linux Contributors", /// description: "My very own kernel module!", /// license: "GPL", +/// alias: ["alternate_module_name"], /// } /// /// struct MyModule; @@ -55,13 +60,45 @@ use proc_macro::TokenStream; /// } /// ``` /// +/// ## Firmware +/// +/// The following example shows how to declare a kernel module that needs +/// to load binary firmware files. You need to specify the file names of +/// the firmware in the `firmware` field. The information is embedded +/// in the `modinfo` section of the kernel module. For example, a tool to +/// build an initramfs uses this information to put the firmware files into +/// the initramfs image. +/// +/// ```ignore +/// use kernel::prelude::*; +/// +/// module!{ +/// type: MyDeviceDriverModule, +/// name: "my_device_driver_module", +/// author: "Rust for Linux Contributors", +/// description: "My device driver requires firmware", +/// license: "GPL", +/// firmware: ["my_device_firmware1.bin", "my_device_firmware2.bin"], +/// } +/// +/// struct MyDeviceDriverModule; +/// +/// impl kernel::Module for MyDeviceDriverModule { +/// fn init() -> Result<Self> { +/// Ok(Self) +/// } +/// } +/// ``` +/// /// # Supported argument types /// - `type`: type which implements the [`Module`] trait (required). -/// - `name`: byte array of the name of the kernel module (required). -/// - `author`: byte array of the author of the kernel module. -/// - `description`: byte array of the description of the kernel module. -/// - `license`: byte array of the license of the kernel module (required). -/// - `alias`: byte array of alias name of the kernel module. +/// - `name`: ASCII string literal of the name of the kernel module (required). +/// - `author`: string literal of the author of the kernel module. +/// - `description`: string literal of the description of the kernel module. +/// - `license`: ASCII string literal of the license of the kernel module (required). +/// - `alias`: array of ASCII string literals of the alias names of the kernel module. +/// - `firmware`: array of ASCII string literals of the firmware files of +/// the kernel module. #[proc_macro] pub fn module(ts: TokenStream) -> TokenStream { module::module(ts) @@ -312,7 +349,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// /// Currently supported modifiers are: /// * `span`: change the span of concatenated identifier to the span of the specified token. By -/// default the span of the `[< >]` group is used. +/// default the span of the `[< >]` group is used. /// * `lower`: change the identifier to lower case. /// * `upper`: change the identifier to upper case. /// diff --git a/rust/macros/module.rs b/rust/macros/module.rs index acd0393b5095..aef3b132f32b 100644 --- a/rust/macros/module.rs +++ b/rust/macros/module.rs @@ -97,14 +97,22 @@ struct ModuleInfo { author: Option<String>, description: Option<String>, alias: Option<Vec<String>>, + firmware: Option<Vec<String>>, } impl ModuleInfo { fn parse(it: &mut token_stream::IntoIter) -> Self { let mut info = ModuleInfo::default(); - const EXPECTED_KEYS: &[&str] = - &["type", "name", "author", "description", "license", "alias"]; + const EXPECTED_KEYS: &[&str] = &[ + "type", + "name", + "author", + "description", + "license", + "alias", + "firmware", + ]; const REQUIRED_KEYS: &[&str] = &["type", "name", "license"]; let mut seen_keys = Vec::new(); @@ -131,6 +139,7 @@ impl ModuleInfo { "description" => info.description = Some(expect_string(it)), "license" => info.license = expect_string_ascii(it), "alias" => info.alias = Some(expect_string_array(it)), + "firmware" => info.firmware = Some(expect_string_array(it)), _ => panic!( "Unknown key \"{}\". Valid keys are: {:?}.", key, EXPECTED_KEYS @@ -186,6 +195,11 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { modinfo.emit("alias", &alias); } } + if let Some(firmware) = info.firmware { + for fw in firmware { + modinfo.emit("firmware", &fw); + } + } // Built-in modules also export the `file` modinfo string. let file = @@ -203,7 +217,11 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { // freed until the module is unloaded. #[cfg(MODULE)] static THIS_MODULE: kernel::ThisModule = unsafe {{ - kernel::ThisModule::from_ptr(&kernel::bindings::__this_module as *const _ as *mut _) + extern \"C\" {{ + static __this_module: kernel::types::Opaque<kernel::bindings::module>; + }} + + kernel::ThisModule::from_ptr(__this_module.get()) }}; #[cfg(not(MODULE))] static THIS_MODULE: kernel::ThisModule = unsafe {{ @@ -244,6 +262,12 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { #[cfg(MODULE)] #[doc(hidden)] + #[used] + #[link_section = \".init.data\"] + static __UNIQUE_ID___addressable_init_module: unsafe extern \"C\" fn() -> i32 = init_module; + + #[cfg(MODULE)] + #[doc(hidden)] #[no_mangle] pub extern \"C\" fn cleanup_module() {{ // SAFETY: @@ -255,6 +279,12 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { unsafe {{ __exit() }} }} + #[cfg(MODULE)] + #[doc(hidden)] + #[used] + #[link_section = \".exit.data\"] + static __UNIQUE_ID___addressable_cleanup_module: extern \"C\" fn() = cleanup_module; + // Built-in modules are initialized through an initcall pointer // and the identifiers need to be unique. #[cfg(not(MODULE))] diff --git a/rust/uapi/lib.rs b/rust/uapi/lib.rs index 0caad902ba40..80a00260e3e7 100644 --- a/rust/uapi/lib.rs +++ b/rust/uapi/lib.rs @@ -14,6 +14,7 @@ #![cfg_attr(test, allow(unsafe_op_in_unsafe_fn))] #![allow( clippy::all, + dead_code, missing_docs, non_camel_case_types, non_upper_case_globals, diff --git a/rust/uapi/uapi_helper.h b/rust/uapi/uapi_helper.h index 08f5e9334c9e..76d3f103e764 100644 --- a/rust/uapi/uapi_helper.h +++ b/rust/uapi/uapi_helper.h @@ -7,5 +7,6 @@ */ #include <uapi/asm-generic/ioctl.h> +#include <uapi/linux/mdio.h> #include <uapi/linux/mii.h> #include <uapi/linux/ethtool.h> |