diff options
Diffstat (limited to 'rust')
156 files changed, 23024 insertions, 3044 deletions
diff --git a/rust/.kunitconfig b/rust/.kunitconfig new file mode 100644 index 000000000000..9e72a5ab03c9 --- /dev/null +++ b/rust/.kunitconfig @@ -0,0 +1,3 @@ +CONFIG_KUNIT=y +CONFIG_RUST=y +CONFIG_RUST_KERNEL_DOCTESTS=y diff --git a/rust/Makefile b/rust/Makefile index 43cd7f845a9a..bfa915b0e588 100644 --- a/rust/Makefile +++ b/rust/Makefile @@ -11,11 +11,8 @@ always-$(CONFIG_RUST) += exports_core_generated.h obj-$(CONFIG_RUST) += helpers/helpers.o CFLAGS_REMOVE_helpers/helpers.o = -Wmissing-prototypes -Wmissing-declarations -always-$(CONFIG_RUST) += libmacros.so -no-clean-files += libmacros.so - always-$(CONFIG_RUST) += bindings/bindings_generated.rs bindings/bindings_helpers_generated.rs -obj-$(CONFIG_RUST) += bindings.o kernel.o +obj-$(CONFIG_RUST) += bindings.o pin_init.o kernel.o always-$(CONFIG_RUST) += exports_helpers_generated.h \ exports_bindings_generated.h exports_kernel_generated.h @@ -37,10 +34,21 @@ obj-$(CONFIG_RUST_KERNEL_DOCTESTS) += doctests_kernel_generated.o obj-$(CONFIG_RUST_KERNEL_DOCTESTS) += doctests_kernel_generated_kunit.o always-$(subst y,$(CONFIG_RUST),$(CONFIG_JUMP_LABEL)) += kernel/generated_arch_static_branch_asm.rs +ifndef CONFIG_UML +always-$(subst y,$(CONFIG_RUST),$(CONFIG_BUG)) += kernel/generated_arch_warn_asm.rs kernel/generated_arch_reachable_asm.rs +endif -# Avoids running `$(RUSTC)` for the sysroot when it may not be available. +# Avoids running `$(RUSTC)` when it may not be available. ifdef CONFIG_RUST +libmacros_name := $(shell MAKEFLAGS= $(RUSTC) --print file-names --crate-name macros --crate-type proc-macro - </dev/null) +libmacros_extension := $(patsubst libmacros.%,%,$(libmacros_name)) + +libpin_init_internal_name := $(shell MAKEFLAGS= $(RUSTC) --print file-names --crate-name pin_init_internal --crate-type proc-macro - </dev/null) +libpin_init_internal_extension := $(patsubst libpin_init_internal.%,%,$(libpin_init_internal_name)) + +always-$(CONFIG_RUST) += $(libmacros_name) $(libpin_init_internal_name) + # `$(rust_flags)` is passed in case the user added `--sysroot`. rustc_sysroot := $(shell MAKEFLAGS= $(RUSTC) $(rust_flags) --print sysroot) rustc_host_target := $(shell $(RUSTC) --version --verbose | grep -F 'host: ' | cut -d' ' -f2) @@ -55,14 +63,25 @@ endif core-cfgs = \ --cfg no_fp_fmt_parse +core-edition := $(if $(call rustc-min-version,108700),2024,2021) + +# `rustdoc` did not save the target modifiers, thus workaround for +# the time being (https://github.com/rust-lang/rust/issues/144521). +rustdoc_modifiers_workaround := $(if $(call rustc-min-version,108800),-Cunsafe-allow-abi-mismatch=fixed-x18) + +# `rustc` recognizes `--remap-path-prefix` since 1.26.0, but `rustdoc` only +# since Rust 1.81.0. Moreover, `rustdoc` ICEs on out-of-tree builds since Rust +# 1.82.0 (https://github.com/rust-lang/rust/issues/138520). Thus workaround both +# issues skipping the flag. The former also applies to `RUSTDOC TK`. quiet_cmd_rustdoc = RUSTDOC $(if $(rustdoc_host),H, ) $< cmd_rustdoc = \ OBJTREE=$(abspath $(objtree)) \ - $(RUSTDOC) $(filter-out $(skip_flags),$(if $(rustdoc_host),$(rust_common_flags),$(rust_flags))) \ + $(RUSTDOC) $(filter-out $(skip_flags) --remap-path-prefix=%,$(if $(rustdoc_host),$(rust_common_flags),$(rust_flags))) \ $(rustc_target_flags) -L$(objtree)/$(obj) \ -Zunstable-options --generate-link-to-definition \ --output $(rustdoc_output) \ --crate-name $(subst rustdoc-,,$@) \ + $(rustdoc_modifiers_workaround) \ $(if $(rustdoc_host),,--sysroot=/dev/null) \ @$(objtree)/include/generated/rustc_cfg $< @@ -78,7 +97,7 @@ quiet_cmd_rustdoc = RUSTDOC $(if $(rustdoc_host),H, ) $< # command-like flags to solve the issue. Meanwhile, we use the non-custom case # and then retouch the generated files. rustdoc: rustdoc-core rustdoc-macros rustdoc-compiler_builtins \ - rustdoc-kernel + rustdoc-kernel rustdoc-pin_init $(Q)cp $(srctree)/Documentation/images/logo.svg $(rustdoc_output)/static.files/ $(Q)cp $(srctree)/Documentation/images/COPYING-logo $(rustdoc_output)/static.files/ $(Q)find $(rustdoc_output) -name '*.html' -type f -print0 | xargs -0 sed -Ei \ @@ -92,14 +111,14 @@ rustdoc: rustdoc-core rustdoc-macros rustdoc-compiler_builtins \ rustdoc-macros: private rustdoc_host = yes rustdoc-macros: private rustc_target_flags = --crate-type proc-macro \ --extern proc_macro -rustdoc-macros: $(src)/macros/lib.rs FORCE +rustdoc-macros: $(src)/macros/lib.rs rustdoc-clean FORCE +$(call if_changed,rustdoc) # Starting with Rust 1.82.0, skipping `-Wrustdoc::unescaped_backticks` should # not be needed -- see https://github.com/rust-lang/rust/pull/128307. -rustdoc-core: private skip_flags = -Wrustdoc::unescaped_backticks -rustdoc-core: private rustc_target_flags = $(core-cfgs) -rustdoc-core: $(RUST_LIB_SRC)/core/src/lib.rs FORCE +rustdoc-core: private skip_flags = --edition=2021 -Wrustdoc::unescaped_backticks +rustdoc-core: private rustc_target_flags = --edition=$(core-edition) $(core-cfgs) +rustdoc-core: $(RUST_LIB_SRC)/core/src/lib.rs rustdoc-clean FORCE +$(call if_changed,rustdoc) rustdoc-compiler_builtins: $(src)/compiler_builtins.rs rustdoc-core FORCE @@ -108,18 +127,35 @@ rustdoc-compiler_builtins: $(src)/compiler_builtins.rs rustdoc-core FORCE rustdoc-ffi: $(src)/ffi.rs rustdoc-core FORCE +$(call if_changed,rustdoc) -rustdoc-kernel: private rustc_target_flags = --extern ffi \ - --extern build_error --extern macros=$(objtree)/$(obj)/libmacros.so \ +rustdoc-pin_init_internal: private rustdoc_host = yes +rustdoc-pin_init_internal: private rustc_target_flags = --cfg kernel \ + --extern proc_macro --crate-type proc-macro +rustdoc-pin_init_internal: $(src)/pin-init/internal/src/lib.rs \ + rustdoc-clean FORCE + +$(call if_changed,rustdoc) + +rustdoc-pin_init: private rustdoc_host = yes +rustdoc-pin_init: private rustc_target_flags = --extern pin_init_internal \ + --extern macros --extern alloc --cfg kernel --cfg feature=\"alloc\" +rustdoc-pin_init: $(src)/pin-init/src/lib.rs rustdoc-pin_init_internal \ + rustdoc-macros FORCE + +$(call if_changed,rustdoc) + +rustdoc-kernel: private rustc_target_flags = --extern ffi --extern pin_init \ + --extern build_error --extern macros \ --extern bindings --extern uapi rustdoc-kernel: $(src)/kernel/lib.rs rustdoc-core rustdoc-ffi rustdoc-macros \ - rustdoc-compiler_builtins $(obj)/libmacros.so \ + rustdoc-pin_init rustdoc-compiler_builtins $(obj)/$(libmacros_name) \ $(obj)/bindings.o FORCE +$(call if_changed,rustdoc) -quiet_cmd_rustc_test_library = RUSTC TL $< +rustdoc-clean: FORCE + $(Q)rm -rf $(rustdoc_output) + +quiet_cmd_rustc_test_library = $(RUSTC_OR_CLIPPY_QUIET) TL $< cmd_rustc_test_library = \ OBJTREE=$(abspath $(objtree)) \ - $(RUSTC) $(rust_common_flags) \ + $(RUSTC_OR_CLIPPY) $(rust_common_flags) \ @$(objtree)/include/generated/rustc_cfg $(rustc_target_flags) \ --crate-type $(if $(rustc_test_library_proc),proc-macro,rlib) \ --out-dir $(objtree)/$(obj)/test --cfg testlib \ @@ -137,12 +173,24 @@ rusttestlib-macros: private rustc_test_library_proc = yes rusttestlib-macros: $(src)/macros/lib.rs FORCE +$(call if_changed,rustc_test_library) +rusttestlib-pin_init_internal: private rustc_target_flags = --cfg kernel \ + --extern proc_macro +rusttestlib-pin_init_internal: private rustc_test_library_proc = yes +rusttestlib-pin_init_internal: $(src)/pin-init/internal/src/lib.rs FORCE + +$(call if_changed,rustc_test_library) + +rusttestlib-pin_init: private rustc_target_flags = --extern pin_init_internal \ + --extern macros --cfg kernel +rusttestlib-pin_init: $(src)/pin-init/src/lib.rs rusttestlib-macros \ + rusttestlib-pin_init_internal $(obj)/$(libpin_init_internal_name) FORCE + +$(call if_changed,rustc_test_library) + rusttestlib-kernel: private rustc_target_flags = --extern ffi \ - --extern build_error --extern macros \ + --extern build_error --extern macros --extern pin_init \ --extern bindings --extern uapi -rusttestlib-kernel: $(src)/kernel/lib.rs \ - rusttestlib-bindings rusttestlib-uapi rusttestlib-build_error \ - $(obj)/libmacros.so $(obj)/bindings.o FORCE +rusttestlib-kernel: $(src)/kernel/lib.rs rusttestlib-bindings rusttestlib-uapi \ + rusttestlib-build_error rusttestlib-pin_init $(obj)/$(libmacros_name) \ + $(obj)/bindings.o FORCE +$(call if_changed,rustc_test_library) rusttestlib-bindings: private rustc_target_flags = --extern ffi @@ -158,6 +206,7 @@ quiet_cmd_rustdoc_test = RUSTDOC T $< RUST_MODFILE=test.rs \ OBJTREE=$(abspath $(objtree)) \ $(RUSTDOC) --test $(rust_common_flags) \ + -Zcrate-attr='feature(used_with_arg)' \ @$(objtree)/include/generated/rustc_cfg \ $(rustc_target_flags) $(rustdoc_test_target_flags) \ $(rustdoc_test_quiet) \ @@ -169,12 +218,13 @@ quiet_cmd_rustdoc_test_kernel = RUSTDOC TK $< rm -rf $(objtree)/$(obj)/test/doctests/kernel; \ mkdir -p $(objtree)/$(obj)/test/doctests/kernel; \ OBJTREE=$(abspath $(objtree)) \ - $(RUSTDOC) --test $(rust_flags) \ - -L$(objtree)/$(obj) --extern ffi --extern kernel \ - --extern build_error --extern macros \ + $(RUSTDOC) --test $(filter-out --remap-path-prefix=%,$(rust_flags)) \ + -L$(objtree)/$(obj) --extern ffi --extern pin_init \ + --extern kernel --extern build_error --extern macros \ --extern bindings --extern uapi \ --no-run --crate-name kernel -Zunstable-options \ --sysroot=/dev/null \ + $(rustdoc_modifiers_workaround) \ --test-builder $(objtree)/scripts/rustdoc_test_builder \ $< $(rustdoc_test_kernel_quiet); \ $(objtree)/scripts/rustdoc_test_gen @@ -187,10 +237,10 @@ quiet_cmd_rustdoc_test_kernel = RUSTDOC TK $< # We cannot use `-Zpanic-abort-tests` because some tests are dynamic, # so for the moment we skip `-Cpanic=abort`. -quiet_cmd_rustc_test = RUSTC T $< +quiet_cmd_rustc_test = $(RUSTC_OR_CLIPPY_QUIET) T $< cmd_rustc_test = \ OBJTREE=$(abspath $(objtree)) \ - $(RUSTC) --test $(rust_common_flags) \ + $(RUSTC_OR_CLIPPY) --test $(rust_common_flags) \ @$(objtree)/include/generated/rustc_cfg \ $(rustc_target_flags) --out-dir $(objtree)/$(obj)/test \ -L$(objtree)/$(obj)/test \ @@ -201,18 +251,18 @@ quiet_cmd_rustc_test = RUSTC T $< rusttest: rusttest-macros rusttest-kernel rusttest-macros: private rustc_target_flags = --extern proc_macro \ - --extern macros --extern kernel + --extern macros --extern kernel --extern pin_init rusttest-macros: private rustdoc_test_target_flags = --crate-type proc-macro rusttest-macros: $(src)/macros/lib.rs \ - rusttestlib-macros rusttestlib-kernel FORCE + rusttestlib-macros rusttestlib-kernel rusttestlib-pin_init FORCE +$(call if_changed,rustc_test) +$(call if_changed,rustdoc_test) -rusttest-kernel: private rustc_target_flags = --extern ffi \ +rusttest-kernel: private rustc_target_flags = --extern ffi --extern pin_init \ --extern build_error --extern macros --extern bindings --extern uapi rusttest-kernel: $(src)/kernel/lib.rs rusttestlib-ffi rusttestlib-kernel \ rusttestlib-build_error rusttestlib-macros rusttestlib-bindings \ - rusttestlib-uapi FORCE + rusttestlib-uapi rusttestlib-pin_init FORCE +$(call if_changed,rustc_test) ifdef CONFIG_CC_IS_CLANG @@ -230,7 +280,8 @@ bindgen_skip_c_flags := -mno-fp-ret-in-387 -mpreferred-stack-boundary=% \ -mfunction-return=thunk-extern -mrecord-mcount -mabi=lp64 \ -mindirect-branch-cs-prefix -mstack-protector-guard% -mtraceback=no \ -mno-pointers-to-nested-functions -mno-string \ - -mno-strict-align -mstrict-align \ + -mno-strict-align -mstrict-align -mdirect-extern-access \ + -mexplicit-relocs -mno-check-zero-division \ -fconserve-stack -falign-jumps=% -falign-loops=% \ -femit-struct-debug-baseonly -fno-ipa-cp-clone -fno-ipa-sra \ -fno-partial-inlining -fplugin-arg-arm_ssp_per_task_plugin-% \ @@ -238,12 +289,15 @@ bindgen_skip_c_flags := -mno-fp-ret-in-387 -mpreferred-stack-boundary=% \ -fzero-call-used-regs=% -fno-stack-clash-protection \ -fno-inline-functions-called-once -fsanitize=bounds-strict \ -fstrict-flex-arrays=% -fmin-function-alignment=% \ - -fzero-init-padding-bits=% \ + -fzero-init-padding-bits=% -mno-fdpic \ --param=% --param asan-% # Derived from `scripts/Makefile.clang`. BINDGEN_TARGET_x86 := x86_64-linux-gnu BINDGEN_TARGET_arm64 := aarch64-linux-gnu +BINDGEN_TARGET_arm := arm-linux-gnueabi +BINDGEN_TARGET_loongarch := loongarch64-linux-gnusf +BINDGEN_TARGET_um := $(BINDGEN_TARGET_$(SUBARCH)) BINDGEN_TARGET := $(BINDGEN_TARGET_$(SRCARCH)) # All warnings are inhibited since GCC builds are very experimental, @@ -330,10 +384,11 @@ $(obj)/bindings/bindings_helpers_generated.rs: private bindgen_target_extra = ; $(obj)/bindings/bindings_helpers_generated.rs: $(src)/helpers/helpers.c FORCE $(call if_changed_dep,bindgen) +rust_exports = $(NM) -p --defined-only $(1) | awk '$$2~/(T|R|D|B)/ && $$3!~/__(pfx|cfi|odr_asan)/ { printf $(2),$$3 }' + quiet_cmd_exports = EXPORTS $@ cmd_exports = \ - $(NM) -p --defined-only $< \ - | awk '$$2~/(T|R|D|B)/ && $$3!~/__cfi/ {printf "EXPORT_SYMBOL_RUST_GPL(%s);\n",$$3}' > $@ + $(call rust_exports,$<,"EXPORT_SYMBOL_RUST_GPL(%s);\n") > $@ $(obj)/exports_core_generated.h: $(obj)/core.o FORCE $(call if_changed,exports) @@ -358,22 +413,27 @@ $(obj)/exports_kernel_generated.h: $(obj)/kernel.o FORCE quiet_cmd_rustc_procmacro = $(RUSTC_OR_CLIPPY_QUIET) P $@ cmd_rustc_procmacro = \ - $(RUSTC_OR_CLIPPY) $(rust_common_flags) \ + $(RUSTC_OR_CLIPPY) $(rust_common_flags) $(rustc_target_flags) \ -Clinker-flavor=gcc -Clinker=$(HOSTCC) \ - -Clink-args='$(call escsq,$(KBUILD_HOSTLDFLAGS))' \ + -Clink-args='$(call escsq,$(KBUILD_PROCMACROLDFLAGS))' \ --emit=dep-info=$(depfile) --emit=link=$@ --extern proc_macro \ --crate-type proc-macro \ - --crate-name $(patsubst lib%.so,%,$(notdir $@)) $< + --crate-name $(patsubst lib%.$(libmacros_extension),%,$(notdir $@)) \ + @$(objtree)/include/generated/rustc_cfg $< # Procedural macros can only be used with the `rustc` that compiled it. -$(obj)/libmacros.so: $(src)/macros/lib.rs FORCE +$(obj)/$(libmacros_name): $(src)/macros/lib.rs FORCE + +$(call if_changed_dep,rustc_procmacro) + +$(obj)/$(libpin_init_internal_name): private rustc_target_flags = --cfg kernel +$(obj)/$(libpin_init_internal_name): $(src)/pin-init/internal/src/lib.rs FORCE +$(call if_changed_dep,rustc_procmacro) quiet_cmd_rustc_library = $(if $(skip_clippy),RUSTC,$(RUSTC_OR_CLIPPY_QUIET)) L $@ cmd_rustc_library = \ OBJTREE=$(abspath $(objtree)) \ $(if $(skip_clippy),$(RUSTC),$(RUSTC_OR_CLIPPY)) \ - $(filter-out $(skip_flags),$(rust_flags) $(rustc_target_flags)) \ + $(filter-out $(skip_flags),$(rust_flags)) $(rustc_target_flags) \ --emit=dep-info=$(depfile) --emit=obj=$@ \ --emit=metadata=$(dir $@)$(patsubst %.o,lib%.rmeta,$(notdir $@)) \ --crate-type rlib -L$(objtree)/$(obj) \ @@ -383,8 +443,8 @@ quiet_cmd_rustc_library = $(if $(skip_clippy),RUSTC,$(RUSTC_OR_CLIPPY_QUIET)) L $(cmd_objtool) rust-analyzer: - $(Q)$(srctree)/scripts/generate_rust_analyzer.py \ - --cfgs='core=$(core-cfgs)' \ + $(Q)MAKEFLAGS= $(srctree)/scripts/generate_rust_analyzer.py \ + --cfgs='core=$(core-cfgs)' $(core-edition) \ $(realpath $(srctree)) $(realpath $(objtree)) \ $(rustc_sysroot) $(RUST_LIB_SRC) $(if $(KBUILD_EXTMOD),$(srcroot)) \ > rust-project.json @@ -395,6 +455,13 @@ redirect-intrinsics = \ __muloti4 __multi3 \ __udivmodti4 __udivti3 __umodti3 +ifdef CONFIG_ARM + # Add eabi initrinsics for ARM 32-bit + redirect-intrinsics += \ + __aeabi_fadd __aeabi_fmul __aeabi_fcmpeq __aeabi_fcmple __aeabi_fcmplt __aeabi_fcmpun \ + __aeabi_dadd __aeabi_dmul __aeabi_dcmple __aeabi_dcmplt __aeabi_dcmpun \ + __aeabi_uldivmod +endif ifneq ($(or $(CONFIG_ARM64),$(and $(CONFIG_RISCV),$(CONFIG_64BIT))),) # These intrinsics are defined for ARM64 and RISCV64 redirect-intrinsics += \ @@ -402,29 +469,65 @@ ifneq ($(or $(CONFIG_ARM64),$(and $(CONFIG_RISCV),$(CONFIG_64BIT))),) __ashlti3 __lshrti3 endif +ifdef CONFIG_MODVERSIONS +cmd_gendwarfksyms = $(if $(skip_gendwarfksyms),, \ + $(call rust_exports,$@,"%s\n") | \ + scripts/gendwarfksyms/gendwarfksyms \ + $(if $(KBUILD_GENDWARFKSYMS_STABLE), --stable) \ + $(if $(KBUILD_SYMTYPES), --symtypes $(@:.o=.symtypes),) \ + $@ >> $(dot-target).cmd) +endif + define rule_rustc_library $(call cmd_and_fixdep,rustc_library) $(call cmd,gen_objtooldep) + $(call cmd,gendwarfksyms) +endef + +define rule_rust_cc_library + $(call if_changed_rule,cc_o_c) + $(call cmd,force_checksrc) + $(call cmd,gendwarfksyms) endef +# helpers.o uses the same export mechanism as Rust libraries, so ensure symbol +# versions are calculated for the helpers too. +$(obj)/helpers/helpers.o: $(src)/helpers/helpers.c $(recordmcount_source) FORCE + +$(call if_changed_rule,rust_cc_library) + +# Disable symbol versioning for exports.o to avoid conflicts with the actual +# symbol versions generated from Rust objects. +$(obj)/exports.o: private skip_gendwarfksyms = 1 + $(obj)/core.o: private skip_clippy = 1 -$(obj)/core.o: private skip_flags = -Wunreachable_pub +$(obj)/core.o: private skip_flags = --edition=2021 -Wunreachable_pub $(obj)/core.o: private rustc_objcopy = $(foreach sym,$(redirect-intrinsics),--redefine-sym $(sym)=__rust$(sym)) -$(obj)/core.o: private rustc_target_flags = $(core-cfgs) +$(obj)/core.o: private rustc_target_flags = --edition=$(core-edition) $(core-cfgs) $(obj)/core.o: $(RUST_LIB_SRC)/core/src/lib.rs \ $(wildcard $(objtree)/include/config/RUSTC_VERSION_TEXT) FORCE +$(call if_changed_rule,rustc_library) ifneq ($(or $(CONFIG_X86_64),$(CONFIG_X86_32)),) $(obj)/core.o: scripts/target.json endif +KCOV_INSTRUMENT_core.o := n +$(obj)/compiler_builtins.o: private skip_gendwarfksyms = 1 $(obj)/compiler_builtins.o: private rustc_objcopy = -w -W '__*' $(obj)/compiler_builtins.o: $(src)/compiler_builtins.rs $(obj)/core.o FORCE +$(call if_changed_rule,rustc_library) +$(obj)/pin_init.o: private skip_gendwarfksyms = 1 +$(obj)/pin_init.o: private rustc_target_flags = --extern pin_init_internal \ + --extern macros --cfg kernel +$(obj)/pin_init.o: $(src)/pin-init/src/lib.rs $(obj)/compiler_builtins.o \ + $(obj)/$(libpin_init_internal_name) $(obj)/$(libmacros_name) FORCE + +$(call if_changed_rule,rustc_library) + +$(obj)/build_error.o: private skip_gendwarfksyms = 1 $(obj)/build_error.o: $(src)/build_error.rs $(obj)/compiler_builtins.o FORCE +$(call if_changed_rule,rustc_library) +$(obj)/ffi.o: private skip_gendwarfksyms = 1 $(obj)/ffi.o: $(src)/ffi.rs $(obj)/compiler_builtins.o FORCE +$(call if_changed_rule,rustc_library) @@ -436,19 +539,25 @@ $(obj)/bindings.o: $(src)/bindings/lib.rs \ +$(call if_changed_rule,rustc_library) $(obj)/uapi.o: private rustc_target_flags = --extern ffi +$(obj)/uapi.o: private skip_gendwarfksyms = 1 $(obj)/uapi.o: $(src)/uapi/lib.rs \ $(obj)/ffi.o \ $(obj)/uapi/uapi_generated.rs FORCE +$(call if_changed_rule,rustc_library) -$(obj)/kernel.o: private rustc_target_flags = --extern ffi \ +$(obj)/kernel.o: private rustc_target_flags = --extern ffi --extern pin_init \ --extern build_error --extern macros --extern bindings --extern uapi -$(obj)/kernel.o: $(src)/kernel/lib.rs $(obj)/build_error.o \ - $(obj)/libmacros.so $(obj)/bindings.o $(obj)/uapi.o FORCE +$(obj)/kernel.o: $(src)/kernel/lib.rs $(obj)/build_error.o $(obj)/pin_init.o \ + $(obj)/$(libmacros_name) $(obj)/bindings.o $(obj)/uapi.o FORCE +$(call if_changed_rule,rustc_library) ifdef CONFIG_JUMP_LABEL $(obj)/kernel.o: $(obj)/kernel/generated_arch_static_branch_asm.rs endif +ifndef CONFIG_UML +ifdef CONFIG_BUG +$(obj)/kernel.o: $(obj)/kernel/generated_arch_warn_asm.rs $(obj)/kernel/generated_arch_reachable_asm.rs +endif +endif endif # CONFIG_RUST diff --git a/rust/bindings/bindings_helper.h b/rust/bindings/bindings_helper.h index 5c4dfe22f41a..7c47875cd2ae 100644 --- a/rust/bindings/bindings_helper.h +++ b/rust/bindings/bindings_helper.h @@ -6,32 +6,82 @@ * Sorted alphabetically. */ +/* + * First, avoid forward references to `enum` types. + * + * This workarounds a `bindgen` issue with them: + * <https://github.com/rust-lang/rust-bindgen/issues/3179>. + * + * Without this, the generated Rust type may be the wrong one (`i32`) or + * the proper one (typically `c_uint`) depending on how the headers are + * included, which in turn may depend on the particular kernel configuration + * or the architecture. + * + * The alternative would be to use casts and likely an + * `#[allow(clippy::unnecessary_cast)]` in the Rust source files. Instead, + * this approach allows us to keep the correct code in the source files and + * simply remove this section when the issue is fixed upstream and we bump + * the minimum `bindgen` version. + * + * This workaround may not be possible in some cases, depending on how the C + * headers are set up. + */ +#include <linux/hrtimer_types.h> + +#include <linux/acpi.h> +#include <drm/drm_device.h> +#include <drm/drm_drv.h> +#include <drm/drm_file.h> +#include <drm/drm_gem.h> +#include <drm/drm_ioctl.h> #include <kunit/test.h> +#include <linux/auxiliary_bus.h> #include <linux/blk-mq.h> #include <linux/blk_types.h> #include <linux/blkdev.h> +#include <linux/clk.h> +#include <linux/completion.h> +#include <linux/configfs.h> +#include <linux/cpu.h> +#include <linux/cpufreq.h> +#include <linux/cpumask.h> #include <linux/cred.h> +#include <linux/device/faux.h> +#include <linux/dma-mapping.h> #include <linux/errname.h> #include <linux/ethtool.h> #include <linux/file.h> #include <linux/firmware.h> #include <linux/fs.h> +#include <linux/ioport.h> #include <linux/jiffies.h> #include <linux/jump_label.h> #include <linux/mdio.h> #include <linux/miscdevice.h> +#include <linux/of_device.h> +#include <linux/pci.h> #include <linux/phy.h> #include <linux/pid_namespace.h> +#include <linux/platform_device.h> +#include <linux/pm_opp.h> #include <linux/poll.h> +#include <linux/property.h> #include <linux/refcount.h> +#include <linux/regulator/consumer.h> #include <linux/sched.h> #include <linux/security.h> #include <linux/slab.h> #include <linux/tracepoint.h> #include <linux/wait.h> #include <linux/workqueue.h> +#include <linux/xarray.h> #include <trace/events/rust_sample.h> +#if defined(CONFIG_DRM_PANIC_SCREEN_QR_CODE) +// Used by `#[export]` in `drivers/gpu/drm/drm_panic_qr.rs`. +#include <drm/drm_panic.h> +#endif + /* `bindgen` gets confused at certain things. */ const size_t RUST_CONST_HELPER_ARCH_SLAB_MINALIGN = ARCH_SLAB_MINALIGN; const size_t RUST_CONST_HELPER_PAGE_SIZE = PAGE_SIZE; @@ -43,3 +93,10 @@ const gfp_t RUST_CONST_HELPER___GFP_ZERO = __GFP_ZERO; const gfp_t RUST_CONST_HELPER___GFP_HIGHMEM = ___GFP_HIGHMEM; const gfp_t RUST_CONST_HELPER___GFP_NOWARN = ___GFP_NOWARN; const blk_features_t RUST_CONST_HELPER_BLK_FEAT_ROTATIONAL = BLK_FEAT_ROTATIONAL; +const fop_flags_t RUST_CONST_HELPER_FOP_UNSIGNED_OFFSET = FOP_UNSIGNED_OFFSET; + +const xa_mark_t RUST_CONST_HELPER_XA_PRESENT = XA_PRESENT; + +const gfp_t RUST_CONST_HELPER_XA_FLAGS_ALLOC = XA_FLAGS_ALLOC; +const gfp_t RUST_CONST_HELPER_XA_FLAGS_ALLOC1 = XA_FLAGS_ALLOC1; +const vm_flags_t RUST_CONST_HELPER_VM_MERGEABLE = VM_MERGEABLE; diff --git a/rust/bindings/lib.rs b/rust/bindings/lib.rs index 014af0d1fc70..474cc98c48a3 100644 --- a/rust/bindings/lib.rs +++ b/rust/bindings/lib.rs @@ -25,7 +25,11 @@ )] #[allow(dead_code)] +#[allow(clippy::cast_lossless)] +#[allow(clippy::ptr_as_ptr)] +#[allow(clippy::ref_as_ptr)] #[allow(clippy::undocumented_unsafe_blocks)] +#[cfg_attr(CONFIG_RUSTC_HAS_UNNECESSARY_TRANSMUTES, allow(unnecessary_transmutes))] mod bindings_raw { // Manual definition for blocklisted types. type __kernel_size_t = usize; diff --git a/rust/compiler_builtins.rs b/rust/compiler_builtins.rs index f14b8d7caf89..dd16c1dc899c 100644 --- a/rust/compiler_builtins.rs +++ b/rust/compiler_builtins.rs @@ -73,5 +73,29 @@ define_panicking_intrinsics!("`u128` should not be used", { __umodti3, }); +#[cfg(target_arch = "arm")] +define_panicking_intrinsics!("`f32` should not be used", { + __aeabi_fadd, + __aeabi_fmul, + __aeabi_fcmpeq, + __aeabi_fcmple, + __aeabi_fcmplt, + __aeabi_fcmpun, +}); + +#[cfg(target_arch = "arm")] +define_panicking_intrinsics!("`f64` should not be used", { + __aeabi_dadd, + __aeabi_dmul, + __aeabi_dcmple, + __aeabi_dcmplt, + __aeabi_dcmpun, +}); + +#[cfg(target_arch = "arm")] +define_panicking_intrinsics!("`u64` division/modulo should not be used", { + __aeabi_uldivmod, +}); + // NOTE: if you are adding a new intrinsic here, you should also add it to // `redirect-intrinsics` in `rust/Makefile`. diff --git a/rust/ffi.rs b/rust/ffi.rs index 584f75b49862..d60aad792af4 100644 --- a/rust/ffi.rs +++ b/rust/ffi.rs @@ -17,7 +17,7 @@ macro_rules! alias { // Check size compatibility with `core`. const _: () = assert!( - core::mem::size_of::<$name>() == core::mem::size_of::<core::ffi::$name>() + ::core::mem::size_of::<$name>() == ::core::mem::size_of::<::core::ffi::$name>() ); )*} } diff --git a/rust/helpers/auxiliary.c b/rust/helpers/auxiliary.c new file mode 100644 index 000000000000..8b5e0fea4493 --- /dev/null +++ b/rust/helpers/auxiliary.c @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/auxiliary_bus.h> + +void rust_helper_auxiliary_device_uninit(struct auxiliary_device *adev) +{ + return auxiliary_device_uninit(adev); +} + +void rust_helper_auxiliary_device_delete(struct auxiliary_device *adev) +{ + return auxiliary_device_delete(adev); +} diff --git a/rust/helpers/bug.c b/rust/helpers/bug.c index e2d13babc737..a62c96f507d1 100644 --- a/rust/helpers/bug.c +++ b/rust/helpers/bug.c @@ -6,3 +6,8 @@ __noreturn void rust_helper_BUG(void) { BUG(); } + +bool rust_helper_WARN_ON(bool cond) +{ + return WARN_ON(cond); +} diff --git a/rust/helpers/clk.c b/rust/helpers/clk.c new file mode 100644 index 000000000000..6d04372c9f3b --- /dev/null +++ b/rust/helpers/clk.c @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/clk.h> + +/* + * The "inline" implementation of below helpers are only available when + * CONFIG_HAVE_CLK or CONFIG_HAVE_CLK_PREPARE aren't set. + */ +#ifndef CONFIG_HAVE_CLK +struct clk *rust_helper_clk_get(struct device *dev, const char *id) +{ + return clk_get(dev, id); +} + +void rust_helper_clk_put(struct clk *clk) +{ + clk_put(clk); +} + +int rust_helper_clk_enable(struct clk *clk) +{ + return clk_enable(clk); +} + +void rust_helper_clk_disable(struct clk *clk) +{ + clk_disable(clk); +} + +unsigned long rust_helper_clk_get_rate(struct clk *clk) +{ + return clk_get_rate(clk); +} + +int rust_helper_clk_set_rate(struct clk *clk, unsigned long rate) +{ + return clk_set_rate(clk, rate); +} +#endif + +#ifndef CONFIG_HAVE_CLK_PREPARE +int rust_helper_clk_prepare(struct clk *clk) +{ + return clk_prepare(clk); +} + +void rust_helper_clk_unprepare(struct clk *clk) +{ + clk_unprepare(clk); +} +#endif + +struct clk *rust_helper_clk_get_optional(struct device *dev, const char *id) +{ + return clk_get_optional(dev, id); +} + +int rust_helper_clk_prepare_enable(struct clk *clk) +{ + return clk_prepare_enable(clk); +} + +void rust_helper_clk_disable_unprepare(struct clk *clk) +{ + clk_disable_unprepare(clk); +} diff --git a/rust/helpers/completion.c b/rust/helpers/completion.c new file mode 100644 index 000000000000..b2443262a2ae --- /dev/null +++ b/rust/helpers/completion.c @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/completion.h> + +void rust_helper_init_completion(struct completion *x) +{ + init_completion(x); +} diff --git a/rust/helpers/cpu.c b/rust/helpers/cpu.c new file mode 100644 index 000000000000..824e0adb19d4 --- /dev/null +++ b/rust/helpers/cpu.c @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/smp.h> + +unsigned int rust_helper_raw_smp_processor_id(void) +{ + return raw_smp_processor_id(); +} diff --git a/rust/helpers/cpufreq.c b/rust/helpers/cpufreq.c new file mode 100644 index 000000000000..7c1343c4d65e --- /dev/null +++ b/rust/helpers/cpufreq.c @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/cpufreq.h> + +#ifdef CONFIG_CPU_FREQ +void rust_helper_cpufreq_register_em_with_opp(struct cpufreq_policy *policy) +{ + cpufreq_register_em_with_opp(policy); +} +#endif diff --git a/rust/helpers/cpumask.c b/rust/helpers/cpumask.c new file mode 100644 index 000000000000..eb10598a0242 --- /dev/null +++ b/rust/helpers/cpumask.c @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/cpumask.h> + +void rust_helper_cpumask_set_cpu(unsigned int cpu, struct cpumask *dstp) +{ + cpumask_set_cpu(cpu, dstp); +} + +void rust_helper___cpumask_set_cpu(unsigned int cpu, struct cpumask *dstp) +{ + __cpumask_set_cpu(cpu, dstp); +} + +void rust_helper_cpumask_clear_cpu(int cpu, struct cpumask *dstp) +{ + cpumask_clear_cpu(cpu, dstp); +} + +void rust_helper___cpumask_clear_cpu(int cpu, struct cpumask *dstp) +{ + __cpumask_clear_cpu(cpu, dstp); +} + +bool rust_helper_cpumask_test_cpu(int cpu, struct cpumask *srcp) +{ + return cpumask_test_cpu(cpu, srcp); +} + +void rust_helper_cpumask_setall(struct cpumask *dstp) +{ + cpumask_setall(dstp); +} + +bool rust_helper_cpumask_empty(struct cpumask *srcp) +{ + return cpumask_empty(srcp); +} + +bool rust_helper_cpumask_full(struct cpumask *srcp) +{ + return cpumask_full(srcp); +} + +unsigned int rust_helper_cpumask_weight(struct cpumask *srcp) +{ + return cpumask_weight(srcp); +} + +void rust_helper_cpumask_copy(struct cpumask *dstp, const struct cpumask *srcp) +{ + cpumask_copy(dstp, srcp); +} + +bool rust_helper_alloc_cpumask_var(cpumask_var_t *mask, gfp_t flags) +{ + return alloc_cpumask_var(mask, flags); +} + +bool rust_helper_zalloc_cpumask_var(cpumask_var_t *mask, gfp_t flags) +{ + return zalloc_cpumask_var(mask, flags); +} + +#ifndef CONFIG_CPUMASK_OFFSTACK +void rust_helper_free_cpumask_var(cpumask_var_t mask) +{ + free_cpumask_var(mask); +} +#endif diff --git a/rust/helpers/device.c b/rust/helpers/device.c new file mode 100644 index 000000000000..9a4316bafedf --- /dev/null +++ b/rust/helpers/device.c @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/device.h> + +int rust_helper_devm_add_action(struct device *dev, + void (*action)(void *), + void *data) +{ + return devm_add_action(dev, action, data); +} + +int rust_helper_devm_add_action_or_reset(struct device *dev, + void (*action)(void *), + void *data) +{ + return devm_add_action_or_reset(dev, action, data); +} + +void *rust_helper_dev_get_drvdata(const struct device *dev) +{ + return dev_get_drvdata(dev); +} + +void rust_helper_dev_set_drvdata(struct device *dev, void *data) +{ + dev_set_drvdata(dev, data); +} diff --git a/rust/helpers/dma.c b/rust/helpers/dma.c new file mode 100644 index 000000000000..6e741c197242 --- /dev/null +++ b/rust/helpers/dma.c @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/dma-mapping.h> + +void *rust_helper_dma_alloc_attrs(struct device *dev, size_t size, + dma_addr_t *dma_handle, gfp_t flag, + unsigned long attrs) +{ + return dma_alloc_attrs(dev, size, dma_handle, flag, attrs); +} + +void rust_helper_dma_free_attrs(struct device *dev, size_t size, void *cpu_addr, + dma_addr_t dma_handle, unsigned long attrs) +{ + dma_free_attrs(dev, size, cpu_addr, dma_handle, attrs); +} + +int rust_helper_dma_set_mask_and_coherent(struct device *dev, u64 mask) +{ + return dma_set_mask_and_coherent(dev, mask); +} diff --git a/rust/helpers/drm.c b/rust/helpers/drm.c new file mode 100644 index 000000000000..450b406c6f27 --- /dev/null +++ b/rust/helpers/drm.c @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <drm/drm_gem.h> +#include <drm/drm_vma_manager.h> + +#ifdef CONFIG_DRM + +void rust_helper_drm_gem_object_get(struct drm_gem_object *obj) +{ + drm_gem_object_get(obj); +} + +void rust_helper_drm_gem_object_put(struct drm_gem_object *obj) +{ + drm_gem_object_put(obj); +} + +__u64 rust_helper_drm_vma_node_offset_addr(struct drm_vma_offset_node *node) +{ + return drm_vma_node_offset_addr(node); +} + +#endif diff --git a/rust/helpers/helpers.c b/rust/helpers/helpers.c index dcf827a61b52..7cf7fe95e41d 100644 --- a/rust/helpers/helpers.c +++ b/rust/helpers/helpers.c @@ -7,26 +7,47 @@ * Sorted alphabetically. */ +#include "auxiliary.c" #include "blk.c" #include "bug.c" #include "build_assert.c" #include "build_bug.c" +#include "clk.c" +#include "completion.c" +#include "cpu.c" +#include "cpufreq.c" +#include "cpumask.c" #include "cred.c" +#include "device.c" +#include "dma.c" +#include "drm.c" #include "err.c" #include "fs.c" +#include "io.c" #include "jump_label.c" #include "kunit.c" +#include "mm.c" #include "mutex.c" +#include "of.c" #include "page.c" +#include "pci.c" #include "pid_namespace.c" +#include "platform.c" +#include "poll.c" +#include "property.c" #include "rbtree.c" +#include "rcu.c" #include "refcount.c" +#include "regulator.c" #include "security.c" #include "signal.c" #include "slab.c" #include "spinlock.c" +#include "sync.c" #include "task.c" +#include "time.c" #include "uaccess.c" #include "vmalloc.c" #include "wait.c" #include "workqueue.c" +#include "xarray.c" diff --git a/rust/helpers/io.c b/rust/helpers/io.c new file mode 100644 index 000000000000..c475913c69e6 --- /dev/null +++ b/rust/helpers/io.c @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/io.h> +#include <linux/ioport.h> + +void __iomem *rust_helper_ioremap(phys_addr_t offset, size_t size) +{ + return ioremap(offset, size); +} + +void __iomem *rust_helper_ioremap_np(phys_addr_t offset, size_t size) +{ + return ioremap_np(offset, size); +} + +void rust_helper_iounmap(void __iomem *addr) +{ + iounmap(addr); +} + +u8 rust_helper_readb(const void __iomem *addr) +{ + return readb(addr); +} + +u16 rust_helper_readw(const void __iomem *addr) +{ + return readw(addr); +} + +u32 rust_helper_readl(const void __iomem *addr) +{ + return readl(addr); +} + +#ifdef CONFIG_64BIT +u64 rust_helper_readq(const void __iomem *addr) +{ + return readq(addr); +} +#endif + +void rust_helper_writeb(u8 value, void __iomem *addr) +{ + writeb(value, addr); +} + +void rust_helper_writew(u16 value, void __iomem *addr) +{ + writew(value, addr); +} + +void rust_helper_writel(u32 value, void __iomem *addr) +{ + writel(value, addr); +} + +#ifdef CONFIG_64BIT +void rust_helper_writeq(u64 value, void __iomem *addr) +{ + writeq(value, addr); +} +#endif + +u8 rust_helper_readb_relaxed(const void __iomem *addr) +{ + return readb_relaxed(addr); +} + +u16 rust_helper_readw_relaxed(const void __iomem *addr) +{ + return readw_relaxed(addr); +} + +u32 rust_helper_readl_relaxed(const void __iomem *addr) +{ + return readl_relaxed(addr); +} + +#ifdef CONFIG_64BIT +u64 rust_helper_readq_relaxed(const void __iomem *addr) +{ + return readq_relaxed(addr); +} +#endif + +void rust_helper_writeb_relaxed(u8 value, void __iomem *addr) +{ + writeb_relaxed(value, addr); +} + +void rust_helper_writew_relaxed(u16 value, void __iomem *addr) +{ + writew_relaxed(value, addr); +} + +void rust_helper_writel_relaxed(u32 value, void __iomem *addr) +{ + writel_relaxed(value, addr); +} + +#ifdef CONFIG_64BIT +void rust_helper_writeq_relaxed(u64 value, void __iomem *addr) +{ + writeq_relaxed(value, addr); +} +#endif + +resource_size_t rust_helper_resource_size(struct resource *res) +{ + return resource_size(res); +} + +struct resource *rust_helper_request_mem_region(resource_size_t start, + resource_size_t n, + const char *name) +{ + return request_mem_region(start, n, name); +} + +void rust_helper_release_mem_region(resource_size_t start, resource_size_t n) +{ + release_mem_region(start, n); +} + +struct resource *rust_helper_request_region(resource_size_t start, + resource_size_t n, const char *name) +{ + return request_region(start, n, name); +} + +struct resource *rust_helper_request_muxed_region(resource_size_t start, + resource_size_t n, + const char *name) +{ + return request_muxed_region(start, n, name); +} + +void rust_helper_release_region(resource_size_t start, resource_size_t n) +{ + release_region(start, n); +} diff --git a/rust/helpers/mm.c b/rust/helpers/mm.c new file mode 100644 index 000000000000..81b510c96fd2 --- /dev/null +++ b/rust/helpers/mm.c @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/mm.h> +#include <linux/sched/mm.h> + +void rust_helper_mmgrab(struct mm_struct *mm) +{ + mmgrab(mm); +} + +void rust_helper_mmdrop(struct mm_struct *mm) +{ + mmdrop(mm); +} + +void rust_helper_mmget(struct mm_struct *mm) +{ + mmget(mm); +} + +bool rust_helper_mmget_not_zero(struct mm_struct *mm) +{ + return mmget_not_zero(mm); +} + +void rust_helper_mmap_read_lock(struct mm_struct *mm) +{ + mmap_read_lock(mm); +} + +bool rust_helper_mmap_read_trylock(struct mm_struct *mm) +{ + return mmap_read_trylock(mm); +} + +void rust_helper_mmap_read_unlock(struct mm_struct *mm) +{ + mmap_read_unlock(mm); +} + +struct vm_area_struct *rust_helper_vma_lookup(struct mm_struct *mm, + unsigned long addr) +{ + return vma_lookup(mm, addr); +} + +void rust_helper_vma_end_read(struct vm_area_struct *vma) +{ + vma_end_read(vma); +} diff --git a/rust/helpers/mutex.c b/rust/helpers/mutex.c index 7e00680958ef..e487819125f0 100644 --- a/rust/helpers/mutex.c +++ b/rust/helpers/mutex.c @@ -7,8 +7,23 @@ void rust_helper_mutex_lock(struct mutex *lock) mutex_lock(lock); } +int rust_helper_mutex_trylock(struct mutex *lock) +{ + return mutex_trylock(lock); +} + void rust_helper___mutex_init(struct mutex *mutex, const char *name, struct lock_class_key *key) { __mutex_init(mutex, name, key); } + +void rust_helper_mutex_assert_is_held(struct mutex *mutex) +{ + lockdep_assert_held(mutex); +} + +void rust_helper_mutex_destroy(struct mutex *lock) +{ + mutex_destroy(lock); +} diff --git a/rust/helpers/of.c b/rust/helpers/of.c new file mode 100644 index 000000000000..86b51167c913 --- /dev/null +++ b/rust/helpers/of.c @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/of.h> + +bool rust_helper_is_of_node(const struct fwnode_handle *fwnode) +{ + return is_of_node(fwnode); +} diff --git a/rust/helpers/pci.c b/rust/helpers/pci.c new file mode 100644 index 000000000000..ef9cb38c81a6 --- /dev/null +++ b/rust/helpers/pci.c @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/pci.h> + +resource_size_t rust_helper_pci_resource_len(struct pci_dev *pdev, int bar) +{ + return pci_resource_len(pdev, bar); +} + +bool rust_helper_dev_is_pci(const struct device *dev) +{ + return dev_is_pci(dev); +} diff --git a/rust/helpers/platform.c b/rust/helpers/platform.c new file mode 100644 index 000000000000..1ce89c1a36f7 --- /dev/null +++ b/rust/helpers/platform.c @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/platform_device.h> + +bool rust_helper_dev_is_platform(const struct device *dev) +{ + return dev_is_platform(dev); +} diff --git a/rust/helpers/poll.c b/rust/helpers/poll.c new file mode 100644 index 000000000000..7e5b1751c2d5 --- /dev/null +++ b/rust/helpers/poll.c @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/export.h> +#include <linux/poll.h> + +void rust_helper_poll_wait(struct file *filp, wait_queue_head_t *wait_address, + poll_table *p) +{ + poll_wait(filp, wait_address, p); +} diff --git a/rust/helpers/property.c b/rust/helpers/property.c new file mode 100644 index 000000000000..08f68e2dac4a --- /dev/null +++ b/rust/helpers/property.c @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/property.h> + +void rust_helper_fwnode_handle_put(struct fwnode_handle *fwnode) +{ + fwnode_handle_put(fwnode); +} diff --git a/rust/helpers/rcu.c b/rust/helpers/rcu.c new file mode 100644 index 000000000000..f1cec6583513 --- /dev/null +++ b/rust/helpers/rcu.c @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/rcupdate.h> + +void rust_helper_rcu_read_lock(void) +{ + rcu_read_lock(); +} + +void rust_helper_rcu_read_unlock(void) +{ + rcu_read_unlock(); +} diff --git a/rust/helpers/regulator.c b/rust/helpers/regulator.c new file mode 100644 index 000000000000..cd8b7ba648ee --- /dev/null +++ b/rust/helpers/regulator.c @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/regulator/consumer.h> + +#ifndef CONFIG_REGULATOR + +void rust_helper_regulator_put(struct regulator *regulator) +{ + regulator_put(regulator); +} + +int rust_helper_regulator_set_voltage(struct regulator *regulator, int min_uV, + int max_uV) +{ + return regulator_set_voltage(regulator, min_uV, max_uV); +} + +int rust_helper_regulator_get_voltage(struct regulator *regulator) +{ + return regulator_get_voltage(regulator); +} + +struct regulator *rust_helper_regulator_get(struct device *dev, const char *id) +{ + return regulator_get(dev, id); +} + +int rust_helper_regulator_enable(struct regulator *regulator) +{ + return regulator_enable(regulator); +} + +int rust_helper_regulator_disable(struct regulator *regulator) +{ + return regulator_disable(regulator); +} + +int rust_helper_regulator_is_enabled(struct regulator *regulator) +{ + return regulator_is_enabled(regulator); +} + +#endif diff --git a/rust/helpers/security.c b/rust/helpers/security.c index 239e5b4745fe..0c4c2065df28 100644 --- a/rust/helpers/security.c +++ b/rust/helpers/security.c @@ -8,13 +8,13 @@ void rust_helper_security_cred_getsecid(const struct cred *c, u32 *secid) security_cred_getsecid(c, secid); } -int rust_helper_security_secid_to_secctx(u32 secid, char **secdata, u32 *seclen) +int rust_helper_security_secid_to_secctx(u32 secid, struct lsm_context *cp) { - return security_secid_to_secctx(secid, secdata, seclen); + return security_secid_to_secctx(secid, cp); } -void rust_helper_security_release_secctx(char *secdata, u32 seclen) +void rust_helper_security_release_secctx(struct lsm_context *cp) { - security_release_secctx(secdata, seclen); + security_release_secctx(cp); } #endif diff --git a/rust/helpers/spinlock.c b/rust/helpers/spinlock.c index 5971fdf6f755..42c4bf01a23e 100644 --- a/rust/helpers/spinlock.c +++ b/rust/helpers/spinlock.c @@ -30,3 +30,8 @@ int rust_helper_spin_trylock(spinlock_t *lock) { return spin_trylock(lock); } + +void rust_helper_spin_assert_is_held(spinlock_t *lock) +{ + lockdep_assert_held(lock); +} diff --git a/rust/helpers/sync.c b/rust/helpers/sync.c new file mode 100644 index 000000000000..ff7e68b48810 --- /dev/null +++ b/rust/helpers/sync.c @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/lockdep.h> + +void rust_helper_lockdep_register_key(struct lock_class_key *k) +{ + lockdep_register_key(k); +} + +void rust_helper_lockdep_unregister_key(struct lock_class_key *k) +{ + lockdep_unregister_key(k); +} diff --git a/rust/helpers/task.c b/rust/helpers/task.c index 31c33ea2dce6..2c85bbc2727e 100644 --- a/rust/helpers/task.c +++ b/rust/helpers/task.c @@ -1,7 +1,13 @@ // SPDX-License-Identifier: GPL-2.0 +#include <linux/kernel.h> #include <linux/sched/task.h> +void rust_helper_might_resched(void) +{ + might_resched(); +} + struct task_struct *rust_helper_get_current(void) { return current; diff --git a/rust/helpers/time.c b/rust/helpers/time.c new file mode 100644 index 000000000000..a318e9fa4408 --- /dev/null +++ b/rust/helpers/time.c @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/delay.h> +#include <linux/ktime.h> +#include <linux/timekeeping.h> + +void rust_helper_fsleep(unsigned long usecs) +{ + fsleep(usecs); +} + +ktime_t rust_helper_ktime_get_real(void) +{ + return ktime_get_real(); +} + +ktime_t rust_helper_ktime_get_boottime(void) +{ + return ktime_get_boottime(); +} + +ktime_t rust_helper_ktime_get_clocktai(void) +{ + return ktime_get_clocktai(); +} + +s64 rust_helper_ktime_to_us(const ktime_t kt) +{ + return ktime_to_us(kt); +} + +s64 rust_helper_ktime_to_ms(const ktime_t kt) +{ + return ktime_to_ms(kt); +} diff --git a/rust/helpers/xarray.c b/rust/helpers/xarray.c new file mode 100644 index 000000000000..60b299f11451 --- /dev/null +++ b/rust/helpers/xarray.c @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/xarray.h> + +int rust_helper_xa_err(void *entry) +{ + return xa_err(entry); +} + +void rust_helper_xa_init_flags(struct xarray *xa, gfp_t flags) +{ + return xa_init_flags(xa, flags); +} + +int rust_helper_xa_trylock(struct xarray *xa) +{ + return xa_trylock(xa); +} + +void rust_helper_xa_lock(struct xarray *xa) +{ + return xa_lock(xa); +} + +void rust_helper_xa_unlock(struct xarray *xa) +{ + return xa_unlock(xa); +} diff --git a/rust/kernel/.gitignore b/rust/kernel/.gitignore index 6ba39a178f30..f636ad95aaf3 100644 --- a/rust/kernel/.gitignore +++ b/rust/kernel/.gitignore @@ -1,3 +1,5 @@ # SPDX-License-Identifier: GPL-2.0 /generated_arch_static_branch_asm.rs +/generated_arch_warn_asm.rs +/generated_arch_reachable_asm.rs diff --git a/rust/kernel/acpi.rs b/rust/kernel/acpi.rs new file mode 100644 index 000000000000..7ae317368b00 --- /dev/null +++ b/rust/kernel/acpi.rs @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Advanced Configuration and Power Interface abstractions. + +use crate::{ + bindings, + device_id::{RawDeviceId, RawDeviceIdIndex}, + prelude::*, +}; + +/// IdTable type for ACPI drivers. +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// An ACPI device id. +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::acpi_device_id); + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `acpi_device_id` and does not add +// additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::acpi_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. +unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::acpi_device_id, driver_data); + + fn index(&self) -> usize { + self.0.driver_data + } +} + +impl DeviceId { + const ACPI_ID_LEN: usize = 16; + + /// Create a new device id from an ACPI 'id' string. + #[inline(always)] + pub const fn new(id: &'static CStr) -> Self { + build_assert!( + id.len_with_nul() <= Self::ACPI_ID_LEN, + "ID exceeds 16 bytes" + ); + let src = id.as_bytes_with_nul(); + // Replace with `bindings::acpi_device_id::default()` once stabilized for `const`. + // SAFETY: FFI type is valid to be zero-initialized. + let mut acpi: bindings::acpi_device_id = unsafe { core::mem::zeroed() }; + let mut i = 0; + while i < src.len() { + acpi.id[i] = src[i]; + i += 1; + } + + Self(acpi) + } +} + +/// Create an ACPI `IdTable` with an "alias" for modpost. +#[macro_export] +macro_rules! acpi_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::acpi::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("acpi", $module_table_name, $table_name); + }; +} diff --git a/rust/kernel/alloc.rs b/rust/kernel/alloc.rs index f2f7f3a53d29..a2c49e5494d3 100644 --- a/rust/kernel/alloc.rs +++ b/rust/kernel/alloc.rs @@ -94,10 +94,10 @@ pub mod flags { /// /// A lower watermark is applied to allow access to "atomic reserves". The current /// implementation doesn't support NMI and few other strict non-preemptive contexts (e.g. - /// raw_spin_lock). The same applies to [`GFP_NOWAIT`]. + /// `raw_spin_lock`). The same applies to [`GFP_NOWAIT`]. pub const GFP_ATOMIC: Flags = Flags(bindings::GFP_ATOMIC); - /// Typical for kernel-internal allocations. The caller requires ZONE_NORMAL or a lower zone + /// Typical for kernel-internal allocations. The caller requires `ZONE_NORMAL` or a lower zone /// for direct access but can direct reclaim. pub const GFP_KERNEL: Flags = Flags(bindings::GFP_KERNEL); @@ -123,7 +123,7 @@ pub mod flags { /// [`Allocator`] is designed to be implemented as a ZST; [`Allocator`] functions do not operate on /// an object instance. /// -/// In order to be able to support `#[derive(SmartPointer)]` later on, we need to avoid a design +/// In order to be able to support `#[derive(CoercePointee)]` later on, we need to avoid a design /// that requires an `Allocator` to be instantiated, hence its functions must not contain any kind /// of `self` parameter. /// diff --git a/rust/kernel/alloc/allocator.rs b/rust/kernel/alloc/allocator.rs index 439985e29fbc..2692cf90c948 100644 --- a/rust/kernel/alloc/allocator.rs +++ b/rust/kernel/alloc/allocator.rs @@ -43,17 +43,6 @@ pub struct Vmalloc; /// For more details see [self]. pub struct KVmalloc; -/// Returns a proper size to alloc a new object aligned to `new_layout`'s alignment. -fn aligned_size(new_layout: Layout) -> usize { - // Customized layouts from `Layout::from_size_align()` can have size < align, so pad first. - let layout = new_layout.pad_to_align(); - - // Note that `layout.size()` (after padding) is guaranteed to be a multiple of `layout.align()` - // which together with the slab guarantees means the `krealloc` will return a properly aligned - // object (see comments in `kmalloc()` for more information). - layout.size() -} - /// # Invariants /// /// One of the following: `krealloc`, `vrealloc`, `kvrealloc`. @@ -80,6 +69,7 @@ impl ReallocFunc { /// This method has the same guarantees as `Allocator::realloc`. Additionally /// - it accepts any pointer to a valid memory allocation allocated by this function. /// - memory allocated by this function remains valid until it is passed to this function. + #[inline] unsafe fn call( &self, ptr: Option<NonNull<u8>>, @@ -87,7 +77,7 @@ impl ReallocFunc { old_layout: Layout, flags: Flags, ) -> Result<NonNull<[u8]>, AllocError> { - let size = aligned_size(layout); + let size = layout.size(); let ptr = match ptr { Some(ptr) => { if old_layout.size() == 0 { @@ -122,6 +112,17 @@ impl ReallocFunc { } } +impl Kmalloc { + /// Returns a [`Layout`] that makes [`Kmalloc`] fulfill the requested size and alignment of + /// `layout`. + pub fn aligned_layout(layout: Layout) -> Layout { + // Note that `layout.size()` (after padding) is guaranteed to be a multiple of + // `layout.align()` which together with the slab guarantees means that `Kmalloc` will return + // a properly aligned object (see comments in `kmalloc()` for more information). + layout.pad_to_align() + } +} + // SAFETY: `realloc` delegates to `ReallocFunc::call`, which guarantees that // - memory remains valid until it is explicitly freed, // - passing a pointer to a valid memory allocation is OK, @@ -134,6 +135,8 @@ unsafe impl Allocator for Kmalloc { old_layout: Layout, flags: Flags, ) -> Result<NonNull<[u8]>, AllocError> { + let layout = Kmalloc::aligned_layout(layout); + // SAFETY: `ReallocFunc::call` has the same safety requirements as `Allocator::realloc`. unsafe { ReallocFunc::KREALLOC.call(ptr, layout, old_layout, flags) } } @@ -175,6 +178,10 @@ unsafe impl Allocator for KVmalloc { old_layout: Layout, flags: Flags, ) -> Result<NonNull<[u8]>, AllocError> { + // `KVmalloc` may use the `Kmalloc` backend, hence we have to enforce a `Kmalloc` + // compatible layout. + let layout = Kmalloc::aligned_layout(layout); + // TODO: Support alignments larger than PAGE_SIZE. if layout.align() > bindings::PAGE_SIZE { pr_warn!("KVmalloc does not support alignments larger than PAGE_SIZE yet.\n"); diff --git a/rust/kernel/alloc/allocator_test.rs b/rust/kernel/alloc/allocator_test.rs index e3240d16040b..90dd987d40e4 100644 --- a/rust/kernel/alloc/allocator_test.rs +++ b/rust/kernel/alloc/allocator_test.rs @@ -4,7 +4,7 @@ //! of those types (e.g. `CString`) use kernel allocators for instantiation. //! //! In order to allow userspace test cases to make use of such types as well, implement the -//! `Cmalloc` allocator within the allocator_test module and type alias all kernel allocators to +//! `Cmalloc` allocator within the `allocator_test` module and type alias all kernel allocators to //! `Cmalloc`. The `Cmalloc` allocator uses libc's `realloc()` function as allocator backend. #![allow(missing_docs)] @@ -22,6 +22,17 @@ pub type Kmalloc = Cmalloc; pub type Vmalloc = Kmalloc; pub type KVmalloc = Kmalloc; +impl Cmalloc { + /// Returns a [`Layout`] that makes [`Kmalloc`] fulfill the requested size and alignment of + /// `layout`. + pub fn aligned_layout(layout: Layout) -> Layout { + // Note that `layout.size()` (after padding) is guaranteed to be a multiple of + // `layout.align()` which together with the slab guarantees means that `Kmalloc` will return + // a properly aligned object (see comments in `kmalloc()` for more information). + layout.pad_to_align() + } +} + extern "C" { #[link_name = "aligned_alloc"] fn libc_aligned_alloc(align: usize, size: usize) -> *mut crate::ffi::c_void; @@ -62,9 +73,27 @@ unsafe impl Allocator for Cmalloc { )); } + // ISO C (ISO/IEC 9899:2011) defines `aligned_alloc`: + // + // > The value of alignment shall be a valid alignment supported by the implementation + // [...]. + // + // As an example of the "supported by the implementation" requirement, POSIX.1-2001 (IEEE + // 1003.1-2001) defines `posix_memalign`: + // + // > The value of alignment shall be a power of two multiple of sizeof (void *). + // + // and POSIX-based implementations of `aligned_alloc` inherit this requirement. At the time + // of writing, this is known to be the case on macOS (but not in glibc). + // + // Satisfy the stricter requirement to avoid spurious test failures on some platforms. + let min_align = core::mem::size_of::<*const crate::ffi::c_void>(); + let layout = layout.align_to(min_align).map_err(|_| AllocError)?; + let layout = layout.pad_to_align(); + // SAFETY: Returns either NULL or a pointer to a memory allocation that satisfies or // exceeds the given size and alignment requirements. - let dst = unsafe { libc_aligned_alloc(layout.align(), layout.size()) } as *mut u8; + let dst = unsafe { libc_aligned_alloc(layout.align(), layout.size()) }.cast::<u8>(); let dst = NonNull::new(dst).ok_or(AllocError)?; if flags.contains(__GFP_ZERO) { diff --git a/rust/kernel/alloc/kbox.rs b/rust/kernel/alloc/kbox.rs index 9ce414361c2c..856d05aa60f1 100644 --- a/rust/kernel/alloc/kbox.rs +++ b/rust/kernel/alloc/kbox.rs @@ -6,6 +6,7 @@ use super::allocator::{KVmalloc, Kmalloc, Vmalloc}; use super::{AllocError, Allocator, Flags}; use core::alloc::Layout; +use core::borrow::{Borrow, BorrowMut}; use core::fmt; use core::marker::PhantomData; use core::mem::ManuallyDrop; @@ -15,8 +16,10 @@ use core::pin::Pin; use core::ptr::NonNull; use core::result::Result; -use crate::init::{InPlaceInit, InPlaceWrite, Init, PinInit}; +use crate::ffi::c_void; +use crate::init::InPlaceInit; use crate::types::ForeignOwnable; +use pin_init::{InPlaceWrite, Init, PinInit, ZeroableOption}; /// The kernel's [`Box`] type -- a heap allocation for a single value of type `T`. /// @@ -56,12 +59,50 @@ use crate::types::ForeignOwnable; /// assert!(KVBox::<Huge>::new_uninit(GFP_KERNEL).is_ok()); /// ``` /// +/// [`Box`]es can also be used to store trait objects by coercing their type: +/// +/// ``` +/// trait FooTrait {} +/// +/// struct FooStruct; +/// impl FooTrait for FooStruct {} +/// +/// let _ = KBox::new(FooStruct, GFP_KERNEL)? as KBox<dyn FooTrait>; +/// # Ok::<(), Error>(()) +/// ``` +/// /// # Invariants /// /// `self.0` is always properly aligned and either points to memory allocated with `A` or, for /// zero-sized types, is a dangling, well aligned pointer. #[repr(transparent)] -pub struct Box<T: ?Sized, A: Allocator>(NonNull<T>, PhantomData<A>); +#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, derive(core::marker::CoercePointee))] +pub struct Box<#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, pointee)] T: ?Sized, A: Allocator>( + NonNull<T>, + PhantomData<A>, +); + +// This is to allow coercion from `Box<T, A>` to `Box<U, A>` if `T` can be converted to the +// dynamically-sized type (DST) `U`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T, U, A> core::ops::CoerceUnsized<Box<U, A>> for Box<T, A> +where + T: ?Sized + core::marker::Unsize<U>, + U: ?Sized, + A: Allocator, +{ +} + +// This is to allow `Box<U, A>` to be dispatched on when `Box<T, A>` can be coerced into `Box<U, +// A>`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T, U, A> core::ops::DispatchFromDyn<Box<U, A>> for Box<T, A> +where + T: ?Sized + core::marker::Unsize<U>, + U: ?Sized, + A: Allocator, +{ +} /// Type alias for [`Box`] with a [`Kmalloc`] allocator. /// @@ -99,6 +140,10 @@ pub type VBox<T> = Box<T, super::allocator::Vmalloc>; /// ``` pub type KVBox<T> = Box<T, super::allocator::KVmalloc>; +// SAFETY: All zeros is equivalent to `None` (option layout optimization guarantee: +// <https://doc.rust-lang.org/stable/std/option/index.html#representation>). +unsafe impl<T, A: Allocator> ZeroableOption for Box<T, A> {} + // SAFETY: `Box` is `Send` if `T` is `Send` because the `Box` owns a `T`. unsafe impl<T, A> Send for Box<T, A> where @@ -245,6 +290,12 @@ where Ok(Self::new(x, flags)?.into()) } + /// Convert a [`Box<T,A>`] to a [`Pin<Box<T,A>>`]. If `T` does not implement + /// [`Unpin`], then `x` will be pinned in memory and can't be moved. + pub fn into_pin(this: Self) -> Pin<Self> { + this.into() + } + /// Forgets the contents (does not run the destructor), but keeps the allocation. fn forget_contents(this: Self) -> Box<MaybeUninit<T>, A> { let ptr = Self::into_raw(this); @@ -349,47 +400,62 @@ where } } -impl<T: 'static, A> ForeignOwnable for Box<T, A> +// SAFETY: The pointer returned by `into_foreign` comes from a well aligned +// pointer to `T`. +unsafe impl<T: 'static, A> ForeignOwnable for Box<T, A> where A: Allocator, { + const FOREIGN_ALIGN: usize = core::mem::align_of::<T>(); type Borrowed<'a> = &'a T; + type BorrowedMut<'a> = &'a mut T; - fn into_foreign(self) -> *const crate::ffi::c_void { - Box::into_raw(self) as _ + fn into_foreign(self) -> *mut c_void { + Box::into_raw(self).cast() } - unsafe fn from_foreign(ptr: *const crate::ffi::c_void) -> Self { + unsafe fn from_foreign(ptr: *mut c_void) -> Self { // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous // call to `Self::into_foreign`. - unsafe { Box::from_raw(ptr as _) } + unsafe { Box::from_raw(ptr.cast()) } } - unsafe fn borrow<'a>(ptr: *const crate::ffi::c_void) -> &'a T { + unsafe fn borrow<'a>(ptr: *mut c_void) -> &'a T { // SAFETY: The safety requirements of this method ensure that the object remains alive and // immutable for the duration of 'a. unsafe { &*ptr.cast() } } + + unsafe fn borrow_mut<'a>(ptr: *mut c_void) -> &'a mut T { + let ptr = ptr.cast(); + // SAFETY: The safety requirements of this method ensure that the pointer is valid and that + // nothing else will access the value for the duration of 'a. + unsafe { &mut *ptr } + } } -impl<T: 'static, A> ForeignOwnable for Pin<Box<T, A>> +// SAFETY: The pointer returned by `into_foreign` comes from a well aligned +// pointer to `T`. +unsafe impl<T: 'static, A> ForeignOwnable for Pin<Box<T, A>> where A: Allocator, { + const FOREIGN_ALIGN: usize = core::mem::align_of::<T>(); type Borrowed<'a> = Pin<&'a T>; + type BorrowedMut<'a> = Pin<&'a mut T>; - fn into_foreign(self) -> *const crate::ffi::c_void { + fn into_foreign(self) -> *mut c_void { // SAFETY: We are still treating the box as pinned. - Box::into_raw(unsafe { Pin::into_inner_unchecked(self) }) as _ + Box::into_raw(unsafe { Pin::into_inner_unchecked(self) }).cast() } - unsafe fn from_foreign(ptr: *const crate::ffi::c_void) -> Self { + unsafe fn from_foreign(ptr: *mut c_void) -> Self { // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous // call to `Self::into_foreign`. - unsafe { Pin::new_unchecked(Box::from_raw(ptr as _)) } + unsafe { Pin::new_unchecked(Box::from_raw(ptr.cast())) } } - unsafe fn borrow<'a>(ptr: *const crate::ffi::c_void) -> Pin<&'a T> { + unsafe fn borrow<'a>(ptr: *mut c_void) -> Pin<&'a T> { // SAFETY: The safety requirements for this function ensure that the object is still alive, // so it is safe to dereference the raw pointer. // The safety requirements of `from_foreign` also ensure that the object remains alive for @@ -399,6 +465,18 @@ where // SAFETY: This pointer originates from a `Pin<Box<T>>`. unsafe { Pin::new_unchecked(r) } } + + unsafe fn borrow_mut<'a>(ptr: *mut c_void) -> Pin<&'a mut T> { + let ptr = ptr.cast(); + // SAFETY: The safety requirements for this function ensure that the object is still alive, + // so it is safe to dereference the raw pointer. + // The safety requirements of `from_foreign` also ensure that the object remains alive for + // the lifetime of the returned value. + let r = unsafe { &mut *ptr }; + + // SAFETY: This pointer originates from a `Pin<Box<T>>`. + unsafe { Pin::new_unchecked(r) } + } } impl<T, A> Deref for Box<T, A> @@ -427,13 +505,79 @@ where } } +/// # Examples +/// +/// ``` +/// # use core::borrow::Borrow; +/// # use kernel::alloc::KBox; +/// struct Foo<B: Borrow<u32>>(B); +/// +/// // Owned instance. +/// let owned = Foo(1); +/// +/// // Owned instance using `KBox`. +/// let owned_kbox = Foo(KBox::new(1, GFP_KERNEL)?); +/// +/// let i = 1; +/// // Borrowed from `i`. +/// let borrowed = Foo(&i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T, A> Borrow<T> for Box<T, A> +where + T: ?Sized, + A: Allocator, +{ + fn borrow(&self) -> &T { + self.deref() + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::BorrowMut; +/// # use kernel::alloc::KBox; +/// struct Foo<B: BorrowMut<u32>>(B); +/// +/// // Owned instance. +/// let owned = Foo(1); +/// +/// // Owned instance using `KBox`. +/// let owned_kbox = Foo(KBox::new(1, GFP_KERNEL)?); +/// +/// let mut i = 1; +/// // Borrowed from `i`. +/// let borrowed = Foo(&mut i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T, A> BorrowMut<T> for Box<T, A> +where + T: ?Sized, + A: Allocator, +{ + fn borrow_mut(&mut self) -> &mut T { + self.deref_mut() + } +} + +impl<T, A> fmt::Display for Box<T, A> +where + T: ?Sized + fmt::Display, + A: Allocator, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <T as fmt::Display>::fmt(&**self, f) + } +} + impl<T, A> fmt::Debug for Box<T, A> where T: ?Sized + fmt::Debug, A: Allocator, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Debug::fmt(&**self, f) + <T as fmt::Debug>::fmt(&**self, f) } } diff --git a/rust/kernel/alloc/kvec.rs b/rust/kernel/alloc/kvec.rs index ae9d072741ce..3c72e0bdddb8 100644 --- a/rust/kernel/alloc/kvec.rs +++ b/rust/kernel/alloc/kvec.rs @@ -8,6 +8,7 @@ use super::{ AllocError, Allocator, Box, Flags, }; use core::{ + borrow::{Borrow, BorrowMut}, fmt, marker::PhantomData, mem::{ManuallyDrop, MaybeUninit}, @@ -21,6 +22,9 @@ use core::{ slice::SliceIndex, }; +mod errors; +pub use self::errors::{InsertError, PushError, RemoveError}; + /// Create a [`KVec`] containing the arguments. /// /// New memory is allocated with `GFP_KERNEL`. @@ -90,6 +94,8 @@ macro_rules! kvec { /// without re-allocation. For ZSTs `self.layout`'s capacity is zero. However, it is legal for the /// backing buffer to be larger than `layout`. /// +/// - `self.len()` is always less than or equal to `self.capacity()`. +/// /// - The `Allocator` type `A` of the vector is the exact same `Allocator` type the backing buffer /// was allocated with (and must be freed with). pub struct Vec<T, A: Allocator> { @@ -183,17 +189,38 @@ where self.len } - /// Forcefully sets `self.len` to `new_len`. + /// Increments `self.len` by `additional`. /// /// # Safety /// - /// - `new_len` must be less than or equal to [`Self::capacity`]. - /// - If `new_len` is greater than `self.len`, all elements within the interval - /// [`self.len`,`new_len`) must be initialized. + /// - `additional` must be less than or equal to `self.capacity - self.len`. + /// - All elements within the interval [`self.len`,`self.len + additional`) must be initialized. #[inline] - pub unsafe fn set_len(&mut self, new_len: usize) { - debug_assert!(new_len <= self.capacity()); - self.len = new_len; + pub unsafe fn inc_len(&mut self, additional: usize) { + // Guaranteed by the type invariant to never underflow. + debug_assert!(additional <= self.capacity() - self.len()); + // INVARIANT: By the safety requirements of this method this represents the exact number of + // elements stored within `self`. + self.len += additional; + } + + /// Decreases `self.len` by `count`. + /// + /// Returns a mutable slice to the elements forgotten by the vector. It is the caller's + /// responsibility to drop these elements if necessary. + /// + /// # Safety + /// + /// - `count` must be less than or equal to `self.len`. + unsafe fn dec_len(&mut self, count: usize) -> &mut [T] { + debug_assert!(count <= self.len()); + // INVARIANT: We relinquish ownership of the elements within the range `[self.len - count, + // self.len)`, hence the updated value of `set.len` represents the exact number of elements + // stored within `self`. + self.len -= count; + // SAFETY: The memory after `self.len()` is guaranteed to contain `count` initialized + // elements of type `T`. + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr().add(self.len), count) } } /// Returns a slice of the entire vector. @@ -259,10 +286,10 @@ where /// Returns a slice of `MaybeUninit<T>` for the remaining spare capacity of the vector. pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<T>] { // SAFETY: - // - `self.len` is smaller than `self.capacity` and hence, the resulting pointer is - // guaranteed to be part of the same allocated object. + // - `self.len` is smaller than `self.capacity` by the type invariant and hence, the + // resulting pointer is guaranteed to be part of the same allocated object. // - `self.len` can not overflow `isize`. - let ptr = unsafe { self.as_mut_ptr().add(self.len) } as *mut MaybeUninit<T>; + let ptr = unsafe { self.as_mut_ptr().add(self.len) }.cast::<MaybeUninit<T>>(); // SAFETY: The memory between `self.len` and `self.capacity` is guaranteed to be allocated // and valid, but uninitialized. @@ -284,24 +311,170 @@ where /// ``` pub fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError> { self.reserve(1, flags)?; + // SAFETY: The call to `reserve` was successful, so the capacity is at least one greater + // than the length. + unsafe { self.push_within_capacity_unchecked(v) }; + Ok(()) + } - // SAFETY: - // - `self.len` is smaller than `self.capacity` and hence, the resulting pointer is - // guaranteed to be part of the same allocated object. - // - `self.len` can not overflow `isize`. - let ptr = unsafe { self.as_mut_ptr().add(self.len) }; + /// Appends an element to the back of the [`Vec`] instance without reallocating. + /// + /// Fails if the vector does not have capacity for the new element. + /// + /// # Examples + /// + /// ``` + /// let mut v = KVec::with_capacity(10, GFP_KERNEL)?; + /// for i in 0..10 { + /// v.push_within_capacity(i)?; + /// } + /// + /// assert!(v.push_within_capacity(10).is_err()); + /// # Ok::<(), Error>(()) + /// ``` + pub fn push_within_capacity(&mut self, v: T) -> Result<(), PushError<T>> { + if self.len() < self.capacity() { + // SAFETY: The length is less than the capacity. + unsafe { self.push_within_capacity_unchecked(v) }; + Ok(()) + } else { + Err(PushError(v)) + } + } - // SAFETY: - // - `ptr` is properly aligned and valid for writes. - unsafe { core::ptr::write(ptr, v) }; + /// Appends an element to the back of the [`Vec`] instance without reallocating. + /// + /// # Safety + /// + /// The length must be less than the capacity. + unsafe fn push_within_capacity_unchecked(&mut self, v: T) { + let spare = self.spare_capacity_mut(); + + // SAFETY: By the safety requirements, `spare` is non-empty. + unsafe { spare.get_unchecked_mut(0) }.write(v); // SAFETY: We just initialised the first spare entry, so it is safe to increase the length - // by 1. We also know that the new length is <= capacity because of the previous call to - // `reserve` above. - unsafe { self.set_len(self.len() + 1) }; + // by 1. We also know that the new length is <= capacity because the caller guarantees that + // the length is less than the capacity at the beginning of this function. + unsafe { self.inc_len(1) }; + } + + /// Inserts an element at the given index in the [`Vec`] instance. + /// + /// Fails if the vector does not have capacity for the new element. Panics if the index is out + /// of bounds. + /// + /// # Examples + /// + /// ``` + /// use kernel::alloc::kvec::InsertError; + /// + /// let mut v = KVec::with_capacity(5, GFP_KERNEL)?; + /// for i in 0..5 { + /// v.insert_within_capacity(0, i)?; + /// } + /// + /// assert!(matches!(v.insert_within_capacity(0, 5), Err(InsertError::OutOfCapacity(_)))); + /// assert!(matches!(v.insert_within_capacity(1000, 5), Err(InsertError::IndexOutOfBounds(_)))); + /// assert_eq!(v, [4, 3, 2, 1, 0]); + /// # Ok::<(), Error>(()) + /// ``` + pub fn insert_within_capacity( + &mut self, + index: usize, + element: T, + ) -> Result<(), InsertError<T>> { + let len = self.len(); + if index > len { + return Err(InsertError::IndexOutOfBounds(element)); + } + + if len >= self.capacity() { + return Err(InsertError::OutOfCapacity(element)); + } + + // SAFETY: This is in bounds since `index <= len < capacity`. + let p = unsafe { self.as_mut_ptr().add(index) }; + // INVARIANT: This breaks the Vec invariants by making `index` contain an invalid element, + // but we restore the invariants below. + // SAFETY: Both the src and dst ranges end no later than one element after the length. + // Since the length is less than the capacity, both ranges are in bounds of the allocation. + unsafe { ptr::copy(p, p.add(1), len - index) }; + // INVARIANT: This restores the Vec invariants. + // SAFETY: The pointer is in-bounds of the allocation. + unsafe { ptr::write(p, element) }; + // SAFETY: Index `len` contains a valid element due to the above copy and write. + unsafe { self.inc_len(1) }; Ok(()) } + /// Removes the last element from a vector and returns it, or `None` if it is empty. + /// + /// # Examples + /// + /// ``` + /// let mut v = KVec::new(); + /// v.push(1, GFP_KERNEL)?; + /// v.push(2, GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 2]); + /// + /// assert_eq!(v.pop(), Some(2)); + /// assert_eq!(v.pop(), Some(1)); + /// assert_eq!(v.pop(), None); + /// # Ok::<(), Error>(()) + /// ``` + pub fn pop(&mut self) -> Option<T> { + if self.is_empty() { + return None; + } + + let removed: *mut T = { + // SAFETY: We just checked that the length is at least one. + let slice = unsafe { self.dec_len(1) }; + // SAFETY: The argument to `dec_len` was 1 so this returns a slice of length 1. + unsafe { slice.get_unchecked_mut(0) } + }; + + // SAFETY: The guarantees of `dec_len` allow us to take ownership of this value. + Some(unsafe { removed.read() }) + } + + /// Removes the element at the given index. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3]?; + /// assert_eq!(v.remove(1)?, 2); + /// assert_eq!(v, [1, 3]); + /// # Ok::<(), Error>(()) + /// ``` + pub fn remove(&mut self, i: usize) -> Result<T, RemoveError> { + let value = { + let value_ref = self.get(i).ok_or(RemoveError)?; + // INVARIANT: This breaks the invariants by invalidating the value at index `i`, but we + // restore the invariants below. + // SAFETY: The value at index `i` is valid, because otherwise we would have already + // failed with `RemoveError`. + unsafe { ptr::read(value_ref) } + }; + + // SAFETY: We checked that `i` is in-bounds. + let p = unsafe { self.as_mut_ptr().add(i) }; + + // INVARIANT: After this call, the invalid value is at the last slot, so the Vec invariants + // are restored after the below call to `dec_len(1)`. + // SAFETY: `p.add(1).add(self.len - i - 1)` is `i+1+len-i-1 == len` elements after the + // beginning of the vector, so this is in-bounds of the vector's allocation. + unsafe { ptr::copy(p.add(1), p, self.len - i - 1) }; + + // SAFETY: Since the check at the beginning of this call did not fail with `RemoveError`, + // the length is at least one. + unsafe { self.dec_len(1) }; + + Ok(value) + } + /// Creates a new [`Vec`] instance with at least the given capacity. /// /// # Examples @@ -395,6 +568,26 @@ where (ptr, len, capacity) } + /// Clears the vector, removing all values. + /// + /// Note that this method has no effect on the allocated capacity + /// of the vector. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3]?; + /// + /// v.clear(); + /// + /// assert!(v.is_empty()); + /// # Ok::<(), Error>(()) + /// ``` + #[inline] + pub fn clear(&mut self) { + self.truncate(0); + } + /// Ensures that the capacity exceeds the length by at least `additional` elements. /// /// # Examples @@ -452,6 +645,80 @@ where Ok(()) } + + /// Shortens the vector, setting the length to `len` and drops the removed values. + /// If `len` is greater than or equal to the current length, this does nothing. + /// + /// This has no effect on the capacity and will not allocate. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3]?; + /// v.truncate(1); + /// assert_eq!(v.len(), 1); + /// assert_eq!(&v, &[1]); + /// + /// # Ok::<(), Error>(()) + /// ``` + pub fn truncate(&mut self, len: usize) { + if let Some(count) = self.len().checked_sub(len) { + // SAFETY: `count` is `self.len() - len` so it is guaranteed to be less than or + // equal to `self.len()`. + let ptr: *mut [T] = unsafe { self.dec_len(count) }; + + // SAFETY: the contract of `dec_len` guarantees that the elements in `ptr` are + // valid elements whose ownership has been transferred to the caller. + unsafe { ptr::drop_in_place(ptr) }; + } + } + + /// Takes ownership of all items in this vector without consuming the allocation. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![0, 1, 2, 3]?; + /// + /// for (i, j) in v.drain_all().enumerate() { + /// assert_eq!(i, j); + /// } + /// + /// assert!(v.capacity() >= 4); + /// # Ok::<(), Error>(()) + /// ``` + pub fn drain_all(&mut self) -> DrainAll<'_, T> { + // SAFETY: This does not underflow the length. + let elems = unsafe { self.dec_len(self.len()) }; + // INVARIANT: The first `len` elements of the spare capacity are valid values, and as we + // just set the length to zero, we may transfer ownership to the `DrainAll` object. + DrainAll { + elements: elems.iter_mut(), + } + } + + /// Removes all elements that don't match the provided closure. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3, 4]?; + /// v.retain(|i| *i % 2 == 0); + /// assert_eq!(v, [2, 4]); + /// # Ok::<(), Error>(()) + /// ``` + pub fn retain(&mut self, mut f: impl FnMut(&mut T) -> bool) { + let mut num_kept = 0; + let mut next_to_check = 0; + while let Some(to_check) = self.get_mut(next_to_check) { + if f(to_check) { + self.swap(num_kept, next_to_check); + num_kept += 1; + } + next_to_check += 1; + } + self.truncate(num_kept); + } } impl<T: Clone, A: Allocator> Vec<T, A> { @@ -475,7 +742,7 @@ impl<T: Clone, A: Allocator> Vec<T, A> { // SAFETY: // - `self.len() + n < self.capacity()` due to the call to reserve above, // - the loop and the line above initialized the next `n` elements. - unsafe { self.set_len(self.len() + n) }; + unsafe { self.inc_len(n) }; Ok(()) } @@ -506,7 +773,7 @@ impl<T: Clone, A: Allocator> Vec<T, A> { // the length by the same number. // - `self.len() + other.len() <= self.capacity()` is guaranteed by the preceding `reserve` // call. - unsafe { self.set_len(self.len() + other.len()) }; + unsafe { self.inc_len(other.len()) }; Ok(()) } @@ -518,6 +785,33 @@ impl<T: Clone, A: Allocator> Vec<T, A> { Ok(v) } + + /// Resizes the [`Vec`] so that `len` is equal to `new_len`. + /// + /// If `new_len` is smaller than `len`, the `Vec` is [`Vec::truncate`]d. + /// If `new_len` is larger, each new slot is filled with clones of `value`. + /// + /// # Examples + /// + /// ``` + /// let mut v = kernel::kvec![1, 2, 3]?; + /// v.resize(1, 42, GFP_KERNEL)?; + /// assert_eq!(&v, &[1]); + /// + /// v.resize(3, 42, GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 42, 42]); + /// + /// # Ok::<(), Error>(()) + /// ``` + pub fn resize(&mut self, new_len: usize, value: T, flags: Flags) -> Result<(), AllocError> { + match new_len.checked_sub(self.len()) { + Some(n) => self.extend_with(n, value, flags), + None => { + self.truncate(new_len); + Ok(()) + } + } + } } impl<T, A> Drop for Vec<T, A> @@ -554,11 +848,11 @@ where // - `ptr` points to memory with at least a size of `size_of::<T>() * len`, // - all elements within `b` are initialized values of `T`, // - `len` does not exceed `isize::MAX`. - unsafe { Vec::from_raw_parts(ptr as _, len, len) } + unsafe { Vec::from_raw_parts(ptr.cast(), len, len) } } } -impl<T> Default for KVec<T> { +impl<T, A: Allocator> Default for Vec<T, A> { #[inline] fn default() -> Self { Self::new() @@ -597,6 +891,58 @@ where } } +/// # Examples +/// +/// ``` +/// # use core::borrow::Borrow; +/// struct Foo<B: Borrow<[u32]>>(B); +/// +/// // Owned array. +/// let owned_array = Foo([1, 2, 3]); +/// +/// // Owned vector. +/// let owned_vec = Foo(KVec::from_elem(0, 3, GFP_KERNEL)?); +/// +/// let arr = [1, 2, 3]; +/// // Borrowed slice from `arr`. +/// let borrowed_slice = Foo(&arr[..]); +/// # Ok::<(), Error>(()) +/// ``` +impl<T, A> Borrow<[T]> for Vec<T, A> +where + A: Allocator, +{ + fn borrow(&self) -> &[T] { + self.as_slice() + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::BorrowMut; +/// struct Foo<B: BorrowMut<[u32]>>(B); +/// +/// // Owned array. +/// let owned_array = Foo([1, 2, 3]); +/// +/// // Owned vector. +/// let owned_vec = Foo(KVec::from_elem(0, 3, GFP_KERNEL)?); +/// +/// let mut arr = [1, 2, 3]; +/// // Borrowed slice from `arr`. +/// let borrowed_slice = Foo(&mut arr[..]); +/// # Ok::<(), Error>(()) +/// ``` +impl<T, A> BorrowMut<[T]> for Vec<T, A> +where + A: Allocator, +{ + fn borrow_mut(&mut self) -> &mut [T] { + self.as_mut_slice() + } +} + impl<T: Eq, A> Eq for Vec<T, A> where A: Allocator {} impl<T, I: SliceIndex<[T]>, A> Index<I> for Vec<T, A> @@ -757,12 +1103,13 @@ where unsafe { ptr::copy(ptr, buf.as_ptr(), len) }; ptr = buf.as_ptr(); - // SAFETY: `len` is guaranteed to be smaller than `self.layout.len()`. + // SAFETY: `len` is guaranteed to be smaller than `self.layout.len()` by the type + // invariant. let layout = unsafe { ArrayLayout::<T>::new_unchecked(len) }; - // SAFETY: `buf` points to the start of the backing buffer and `len` is guaranteed to be - // smaller than `cap`. Depending on `alloc` this operation may shrink the buffer or leaves - // it as it is. + // SAFETY: `buf` points to the start of the backing buffer and `len` is guaranteed by + // the type invariant to be smaller than `cap`. Depending on `realloc` this operation + // may shrink the buffer or leave it as it is. ptr = match unsafe { A::realloc(Some(buf.cast()), layout.into(), old_layout.into(), flags) } { @@ -911,3 +1258,87 @@ where } } } + +/// An iterator that owns all items in a vector, but does not own its allocation. +/// +/// # Invariants +/// +/// Every `&mut T` returned by the iterator references a `T` that the iterator may take ownership +/// of. +pub struct DrainAll<'vec, T> { + elements: slice::IterMut<'vec, T>, +} + +impl<'vec, T> Iterator for DrainAll<'vec, T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + let elem: *mut T = self.elements.next()?; + // SAFETY: By the type invariants, we may take ownership of this value. + Some(unsafe { elem.read() }) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.elements.size_hint() + } +} + +impl<'vec, T> Drop for DrainAll<'vec, T> { + fn drop(&mut self) { + if core::mem::needs_drop::<T>() { + let iter = core::mem::take(&mut self.elements); + let ptr: *mut [T] = iter.into_slice(); + // SAFETY: By the type invariants, we own these values so we may destroy them. + unsafe { ptr::drop_in_place(ptr) }; + } + } +} + +#[macros::kunit_tests(rust_kvec_kunit)] +mod tests { + use super::*; + use crate::prelude::*; + + #[test] + fn test_kvec_retain() { + /// Verify correctness for one specific function. + #[expect(clippy::needless_range_loop)] + fn verify(c: &[bool]) { + let mut vec1: KVec<usize> = KVec::with_capacity(c.len(), GFP_KERNEL).unwrap(); + let mut vec2: KVec<usize> = KVec::with_capacity(c.len(), GFP_KERNEL).unwrap(); + + for i in 0..c.len() { + vec1.push_within_capacity(i).unwrap(); + if c[i] { + vec2.push_within_capacity(i).unwrap(); + } + } + + vec1.retain(|i| c[*i]); + + assert_eq!(vec1, vec2); + } + + /// Add one to a binary integer represented as a boolean array. + fn add(value: &mut [bool]) { + let mut carry = true; + for v in value { + let new_v = carry != *v; + carry = carry && *v; + *v = new_v; + } + } + + // This boolean array represents a function from index to boolean. We check that `retain` + // behaves correctly for all possible boolean arrays of every possible length less than + // ten. + let mut func = KVec::with_capacity(10, GFP_KERNEL).unwrap(); + for len in 0..10 { + for _ in 0u32..1u32 << len { + verify(&func); + add(&mut func); + } + func.push_within_capacity(false).unwrap(); + } + } +} diff --git a/rust/kernel/alloc/kvec/errors.rs b/rust/kernel/alloc/kvec/errors.rs new file mode 100644 index 000000000000..348b8d27e102 --- /dev/null +++ b/rust/kernel/alloc/kvec/errors.rs @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Errors for the [`Vec`] type. + +use core::fmt::{self, Debug, Formatter}; +use kernel::prelude::*; + +/// Error type for [`Vec::push_within_capacity`]. +pub struct PushError<T>(pub T); + +impl<T> Debug for PushError<T> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Not enough capacity") + } +} + +impl<T> From<PushError<T>> for Error { + fn from(_: PushError<T>) -> Error { + // Returning ENOMEM isn't appropriate because the system is not out of memory. The vector + // is just full and we are refusing to resize it. + EINVAL + } +} + +/// Error type for [`Vec::remove`]. +pub struct RemoveError; + +impl Debug for RemoveError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Index out of bounds") + } +} + +impl From<RemoveError> for Error { + fn from(_: RemoveError) -> Error { + EINVAL + } +} + +/// Error type for [`Vec::insert_within_capacity`]. +pub enum InsertError<T> { + /// The value could not be inserted because the index is out of bounds. + IndexOutOfBounds(T), + /// The value could not be inserted because the vector is out of capacity. + OutOfCapacity(T), +} + +impl<T> Debug for InsertError<T> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + InsertError::IndexOutOfBounds(_) => write!(f, "Index out of bounds"), + InsertError::OutOfCapacity(_) => write!(f, "Not enough capacity"), + } + } +} + +impl<T> From<InsertError<T>> for Error { + fn from(_: InsertError<T>) -> Error { + EINVAL + } +} diff --git a/rust/kernel/alloc/layout.rs b/rust/kernel/alloc/layout.rs index 4b3cd7fdc816..93ed514f7cc7 100644 --- a/rust/kernel/alloc/layout.rs +++ b/rust/kernel/alloc/layout.rs @@ -43,6 +43,25 @@ impl<T> ArrayLayout<T> { /// # Errors /// /// When `len * size_of::<T>()` overflows or when `len * size_of::<T>() > isize::MAX`. + /// + /// # Examples + /// + /// ``` + /// # use kernel::alloc::layout::{ArrayLayout, LayoutError}; + /// let layout = ArrayLayout::<i32>::new(15)?; + /// assert_eq!(layout.len(), 15); + /// + /// // Errors because `len * size_of::<T>()` overflows. + /// let layout = ArrayLayout::<i32>::new(isize::MAX as usize); + /// assert!(layout.is_err()); + /// + /// // Errors because `len * size_of::<i32>() > isize::MAX`, + /// // even though `len < isize::MAX`. + /// let layout = ArrayLayout::<i32>::new(isize::MAX as usize / 2); + /// assert!(layout.is_err()); + /// + /// # Ok::<(), Error>(()) + /// ``` pub const fn new(len: usize) -> Result<Self, LayoutError> { match len.checked_mul(core::mem::size_of::<T>()) { Some(size) if size <= ISIZE_MAX => { diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs new file mode 100644 index 000000000000..4749fb6bffef --- /dev/null +++ b/rust/kernel/auxiliary.rs @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for the auxiliary bus. +//! +//! C header: [`include/linux/auxiliary_bus.h`](srctree/include/linux/auxiliary_bus.h) + +use crate::{ + bindings, container_of, device, + device_id::{RawDeviceId, RawDeviceIdIndex}, + driver, + error::{from_result, to_result, Result}, + prelude::*, + types::Opaque, + ThisModule, +}; +use core::{ + marker::PhantomData, + ptr::{addr_of_mut, NonNull}, +}; + +/// An adapter for the registration of auxiliary drivers. +pub struct Adapter<T: Driver>(T); + +// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// a preceding call to `register` has been successful. +unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { + type RegType = bindings::auxiliary_driver; + + unsafe fn register( + adrv: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + // SAFETY: It's safe to set the fields of `struct auxiliary_driver` on initialization. + unsafe { + (*adrv.get()).name = name.as_char_ptr(); + (*adrv.get()).probe = Some(Self::probe_callback); + (*adrv.get()).remove = Some(Self::remove_callback); + (*adrv.get()).id_table = T::ID_TABLE.as_ptr(); + } + + // SAFETY: `adrv` is guaranteed to be a valid `RegType`. + to_result(unsafe { + bindings::__auxiliary_driver_register(adrv.get(), module.0, name.as_char_ptr()) + }) + } + + unsafe fn unregister(adrv: &Opaque<Self::RegType>) { + // SAFETY: `adrv` is guaranteed to be a valid `RegType`. + unsafe { bindings::auxiliary_driver_unregister(adrv.get()) } + } +} + +impl<T: Driver + 'static> Adapter<T> { + extern "C" fn probe_callback( + adev: *mut bindings::auxiliary_device, + id: *const bindings::auxiliary_device_id, + ) -> kernel::ffi::c_int { + // SAFETY: The auxiliary bus only ever calls the probe callback with a valid pointer to a + // `struct auxiliary_device`. + // + // INVARIANT: `adev` is valid for the duration of `probe_callback()`. + let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() }; + + // SAFETY: `DeviceId` is a `#[repr(transparent)`] wrapper of `struct auxiliary_device_id` + // and does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*id.cast::<DeviceId>() }; + let info = T::ID_TABLE.info(id.index()); + + from_result(|| { + let data = T::probe(adev, info)?; + + adev.as_ref().set_drvdata(data); + Ok(0) + }) + } + + extern "C" fn remove_callback(adev: *mut bindings::auxiliary_device) { + // SAFETY: The auxiliary bus only ever calls the probe callback with a valid pointer to a + // `struct auxiliary_device`. + // + // INVARIANT: `adev` is valid for the duration of `probe_callback()`. + let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() }; + + // SAFETY: `remove_callback` is only ever called after a successful call to + // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called + // and stored a `Pin<KBox<T>>`. + drop(unsafe { adev.as_ref().drvdata_obtain::<Pin<KBox<T>>>() }); + } +} + +/// Declares a kernel module that exposes a single auxiliary driver. +#[macro_export] +macro_rules! module_auxiliary_driver { + ($($f:tt)*) => { + $crate::module_driver!(<T>, $crate::auxiliary::Adapter<T>, { $($f)* }); + }; +} + +/// Abstraction for `bindings::auxiliary_device_id`. +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::auxiliary_device_id); + +impl DeviceId { + /// Create a new [`DeviceId`] from name. + pub const fn new(modname: &'static CStr, name: &'static CStr) -> Self { + let name = name.as_bytes_with_nul(); + let modname = modname.as_bytes_with_nul(); + + // TODO: Replace with `bindings::auxiliary_device_id::default()` once stabilized for + // `const`. + // + // SAFETY: FFI type is valid to be zero-initialized. + let mut id: bindings::auxiliary_device_id = unsafe { core::mem::zeroed() }; + + let mut i = 0; + while i < modname.len() { + id.name[i] = modname[i]; + i += 1; + } + + // Reuse the space of the NULL terminator. + id.name[i - 1] = b'.'; + + let mut j = 0; + while j < name.len() { + id.name[i] = name[j]; + i += 1; + j += 1; + } + + Self(id) + } +} + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `auxiliary_device_id` and does not add +// additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::auxiliary_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. +unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = + core::mem::offset_of!(bindings::auxiliary_device_id, driver_data); + + fn index(&self) -> usize { + self.0.driver_data + } +} + +/// IdTable type for auxiliary drivers. +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// Create a auxiliary `IdTable` with its alias for modpost. +#[macro_export] +macro_rules! auxiliary_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::auxiliary::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("auxiliary", $module_table_name, $table_name); + }; +} + +/// The auxiliary driver trait. +/// +/// Drivers must implement this trait in order to get an auxiliary driver registered. +pub trait Driver { + /// The type holding information about each device id supported by the driver. + /// + /// TODO: Use associated_type_defaults once stabilized: + /// + /// type IdInfo: 'static = (); + type IdInfo: 'static; + + /// The table of device ids supported by the driver. + const ID_TABLE: IdTable<Self::IdInfo>; + + /// Auxiliary driver probe. + /// + /// Called when an auxiliary device is matches a corresponding driver. + fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>; +} + +/// The auxiliary device representation. +/// +/// This structure represents the Rust abstraction for a C `struct auxiliary_device`. The +/// implementation abstracts the usage of an already existing C `struct auxiliary_device` within +/// Rust code that we get passed from the C side. +/// +/// # Invariants +/// +/// A [`Device`] instance represents a valid `struct auxiliary_device` created by the C portion of +/// the kernel. +#[repr(transparent)] +pub struct Device<Ctx: device::DeviceContext = device::Normal>( + Opaque<bindings::auxiliary_device>, + PhantomData<Ctx>, +); + +impl<Ctx: device::DeviceContext> Device<Ctx> { + fn as_raw(&self) -> *mut bindings::auxiliary_device { + self.0.get() + } + + /// Returns the auxiliary device' id. + pub fn id(&self) -> u32 { + // SAFETY: By the type invariant `self.as_raw()` is a valid pointer to a + // `struct auxiliary_device`. + unsafe { (*self.as_raw()).id } + } + + /// Returns a reference to the parent [`device::Device`], if any. + pub fn parent(&self) -> Option<&device::Device> { + let ptr: *const Self = self; + // CAST: `Device<Ctx: DeviceContext>` types are transparent to each other. + let ptr: *const Device = ptr.cast(); + // SAFETY: `ptr` was derived from `&self`. + let this = unsafe { &*ptr }; + + this.as_ref().parent() + } +} + +impl Device { + extern "C" fn release(dev: *mut bindings::device) { + // SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device` + // embedded in `struct auxiliary_device`. + let adev = unsafe { container_of!(dev, bindings::auxiliary_device, dev) }; + + // SAFETY: `adev` points to the memory that has been allocated in `Registration::new`, via + // `KBox::new(Opaque::<bindings::auxiliary_device>::zeroed(), GFP_KERNEL)`. + let _ = unsafe { KBox::<Opaque<bindings::auxiliary_device>>::from_raw(adev.cast()) }; + } +} + +// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic +// argument. +kernel::impl_device_context_deref!(unsafe { Device }); +kernel::impl_device_context_into_aref!(Device); + +// SAFETY: Instances of `Device` are always reference-counted. +unsafe impl crate::types::AlwaysRefCounted for Device { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::get_device(self.as_ref().as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // CAST: `Self` a transparent wrapper of `bindings::auxiliary_device`. + let adev: *mut bindings::auxiliary_device = obj.cast().as_ptr(); + + // SAFETY: By the type invariant of `Self`, `adev` is a pointer to a valid + // `struct auxiliary_device`. + let dev = unsafe { addr_of_mut!((*adev).dev) }; + + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::put_device(dev) } + } +} + +impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> { + fn as_ref(&self) -> &device::Device<Ctx> { + // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid + // `struct auxiliary_device`. + let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) }; + + // SAFETY: `dev` points to a valid `struct device`. + unsafe { device::Device::from_raw(dev) } + } +} + +// SAFETY: A `Device` is always reference-counted and can be released from any thread. +unsafe impl Send for Device {} + +// SAFETY: `Device` can be shared among threads because all methods of `Device` +// (i.e. `Device<Normal>) are thread safe. +unsafe impl Sync for Device {} + +/// The registration of an auxiliary device. +/// +/// This type represents the registration of a [`struct auxiliary_device`]. When an instance of this +/// type is dropped, its respective auxiliary device will be unregistered from the system. +/// +/// # Invariants +/// +/// `self.0` always holds a valid pointer to an initialized and registered +/// [`struct auxiliary_device`]. +pub struct Registration(NonNull<bindings::auxiliary_device>); + +impl Registration { + /// Create and register a new auxiliary device. + pub fn new(parent: &device::Device, name: &CStr, id: u32, modname: &CStr) -> Result<Self> { + let boxed = KBox::new(Opaque::<bindings::auxiliary_device>::zeroed(), GFP_KERNEL)?; + let adev = boxed.get(); + + // SAFETY: It's safe to set the fields of `struct auxiliary_device` on initialization. + unsafe { + (*adev).dev.parent = parent.as_raw(); + (*adev).dev.release = Some(Device::release); + (*adev).name = name.as_char_ptr(); + (*adev).id = id; + } + + // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, + // which has not been initialized yet. + unsafe { bindings::auxiliary_device_init(adev) }; + + // Now that `adev` is initialized, leak the `Box`; the corresponding memory will be freed + // by `Device::release` when the last reference to the `struct auxiliary_device` is dropped. + let _ = KBox::into_raw(boxed); + + // SAFETY: + // - `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, which has + // been initialialized, + // - `modname.as_char_ptr()` is a NULL terminated string. + let ret = unsafe { bindings::__auxiliary_device_add(adev, modname.as_char_ptr()) }; + if ret != 0 { + // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, + // which has been initialialized. + unsafe { bindings::auxiliary_device_uninit(adev) }; + + return Err(Error::from_errno(ret)); + } + + // SAFETY: `adev` is guaranteed to be non-null, since the `KBox` was allocated successfully. + // + // INVARIANT: The device will remain registered until `auxiliary_device_delete()` is called, + // which happens in `Self::drop()`. + Ok(Self(unsafe { NonNull::new_unchecked(adev) })) + } +} + +impl Drop for Registration { + fn drop(&mut self) { + // SAFETY: By the type invariant of `Self`, `self.0.as_ptr()` is a valid registered + // `struct auxiliary_device`. + unsafe { bindings::auxiliary_device_delete(self.0.as_ptr()) }; + + // This drops the reference we acquired through `auxiliary_device_init()`. + // + // SAFETY: By the type invariant of `Self`, `self.0.as_ptr()` is a valid registered + // `struct auxiliary_device`. + unsafe { bindings::auxiliary_device_uninit(self.0.as_ptr()) }; + } +} + +// SAFETY: A `Registration` of a `struct auxiliary_device` can be released from any thread. +unsafe impl Send for Registration {} + +// SAFETY: `Registration` does not expose any methods or fields that need synchronization. +unsafe impl Sync for Registration {} diff --git a/rust/kernel/bits.rs b/rust/kernel/bits.rs new file mode 100644 index 000000000000..553d50265883 --- /dev/null +++ b/rust/kernel/bits.rs @@ -0,0 +1,203 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Bit manipulation macros. +//! +//! C header: [`include/linux/bits.h`](srctree/include/linux/bits.h) + +use crate::prelude::*; +use core::ops::RangeInclusive; +use macros::paste; + +macro_rules! impl_bit_fn { + ( + $ty:ty + ) => { + paste! { + /// Computes `1 << n` if `n` is in bounds, i.e.: if `n` is smaller than + /// the maximum number of bits supported by the type. + /// + /// Returns [`None`] otherwise. + #[inline] + pub fn [<checked_bit_ $ty>](n: u32) -> Option<$ty> { + (1 as $ty).checked_shl(n) + } + + /// Computes `1 << n` by performing a compile-time assertion that `n` is + /// in bounds. + /// + /// This version is the default and should be used if `n` is known at + /// compile time. + #[inline] + pub const fn [<bit_ $ty>](n: u32) -> $ty { + build_assert!(n < <$ty>::BITS); + (1 as $ty) << n + } + } + }; +} + +impl_bit_fn!(u64); +impl_bit_fn!(u32); +impl_bit_fn!(u16); +impl_bit_fn!(u8); + +macro_rules! impl_genmask_fn { + ( + $ty:ty, + $(#[$genmask_checked_ex:meta])*, + $(#[$genmask_ex:meta])* + ) => { + paste! { + /// Creates a contiguous bitmask for the given range by validating + /// the range at runtime. + /// + /// Returns [`None`] if the range is invalid, i.e.: if the start is + /// greater than the end or if the range is outside of the + /// representable range for the type. + $(#[$genmask_checked_ex])* + #[inline] + pub fn [<genmask_checked_ $ty>](range: RangeInclusive<u32>) -> Option<$ty> { + let start = *range.start(); + let end = *range.end(); + + if start > end { + return None; + } + + let high = [<checked_bit_ $ty>](end)?; + let low = [<checked_bit_ $ty>](start)?; + Some((high | (high - 1)) & !(low - 1)) + } + + /// Creates a compile-time contiguous bitmask for the given range by + /// performing a compile-time assertion that the range is valid. + /// + /// This version is the default and should be used if the range is known + /// at compile time. + $(#[$genmask_ex])* + #[inline] + pub const fn [<genmask_ $ty>](range: RangeInclusive<u32>) -> $ty { + let start = *range.start(); + let end = *range.end(); + + build_assert!(start <= end); + + let high = [<bit_ $ty>](end); + let low = [<bit_ $ty>](start); + (high | (high - 1)) & !(low - 1) + } + } + }; +} + +impl_genmask_fn!( + u64, + /// # Examples + /// + /// ``` + /// # #![expect(clippy::reversed_empty_ranges)] + /// # use kernel::bits::genmask_checked_u64; + /// assert_eq!(genmask_checked_u64(0..=0), Some(0b1)); + /// assert_eq!(genmask_checked_u64(0..=63), Some(u64::MAX)); + /// assert_eq!(genmask_checked_u64(21..=39), Some(0x0000_00ff_ffe0_0000)); + /// + /// // `80` is out of the supported bit range. + /// assert_eq!(genmask_checked_u64(21..=80), None); + /// + /// // Invalid range where the start is bigger than the end. + /// assert_eq!(genmask_checked_u64(15..=8), None); + /// ``` + , + /// # Examples + /// + /// ``` + /// # use kernel::bits::genmask_u64; + /// assert_eq!(genmask_u64(21..=39), 0x0000_00ff_ffe0_0000); + /// assert_eq!(genmask_u64(0..=0), 0b1); + /// assert_eq!(genmask_u64(0..=63), u64::MAX); + /// ``` +); + +impl_genmask_fn!( + u32, + /// # Examples + /// + /// ``` + /// # #![expect(clippy::reversed_empty_ranges)] + /// # use kernel::bits::genmask_checked_u32; + /// assert_eq!(genmask_checked_u32(0..=0), Some(0b1)); + /// assert_eq!(genmask_checked_u32(0..=31), Some(u32::MAX)); + /// assert_eq!(genmask_checked_u32(21..=31), Some(0xffe0_0000)); + /// + /// // `40` is out of the supported bit range. + /// assert_eq!(genmask_checked_u32(21..=40), None); + /// + /// // Invalid range where the start is bigger than the end. + /// assert_eq!(genmask_checked_u32(15..=8), None); + /// ``` + , + /// # Examples + /// + /// ``` + /// # use kernel::bits::genmask_u32; + /// assert_eq!(genmask_u32(21..=31), 0xffe0_0000); + /// assert_eq!(genmask_u32(0..=0), 0b1); + /// assert_eq!(genmask_u32(0..=31), u32::MAX); + /// ``` +); + +impl_genmask_fn!( + u16, + /// # Examples + /// + /// ``` + /// # #![expect(clippy::reversed_empty_ranges)] + /// # use kernel::bits::genmask_checked_u16; + /// assert_eq!(genmask_checked_u16(0..=0), Some(0b1)); + /// assert_eq!(genmask_checked_u16(0..=15), Some(u16::MAX)); + /// assert_eq!(genmask_checked_u16(6..=15), Some(0xffc0)); + /// + /// // `20` is out of the supported bit range. + /// assert_eq!(genmask_checked_u16(6..=20), None); + /// + /// // Invalid range where the start is bigger than the end. + /// assert_eq!(genmask_checked_u16(10..=5), None); + /// ``` + , + /// # Examples + /// + /// ``` + /// # use kernel::bits::genmask_u16; + /// assert_eq!(genmask_u16(6..=15), 0xffc0); + /// assert_eq!(genmask_u16(0..=0), 0b1); + /// assert_eq!(genmask_u16(0..=15), u16::MAX); + /// ``` +); + +impl_genmask_fn!( + u8, + /// # Examples + /// + /// ``` + /// # #![expect(clippy::reversed_empty_ranges)] + /// # use kernel::bits::genmask_checked_u8; + /// assert_eq!(genmask_checked_u8(0..=0), Some(0b1)); + /// assert_eq!(genmask_checked_u8(0..=7), Some(u8::MAX)); + /// assert_eq!(genmask_checked_u8(6..=7), Some(0xc0)); + /// + /// // `10` is out of the supported bit range. + /// assert_eq!(genmask_checked_u8(6..=10), None); + /// + /// // Invalid range where the start is bigger than the end. + /// assert_eq!(genmask_checked_u8(5..=2), None); + /// ``` + , + /// # Examples + /// + /// ``` + /// # use kernel::bits::genmask_u8; + /// assert_eq!(genmask_u8(6..=7), 0xc0); + /// assert_eq!(genmask_u8(0..=0), 0b1); + /// assert_eq!(genmask_u8(0..=7), u8::MAX); + /// ``` +); diff --git a/rust/kernel/block/mq.rs b/rust/kernel/block/mq.rs index fb0f393c1cea..831445d37181 100644 --- a/rust/kernel/block/mq.rs +++ b/rust/kernel/block/mq.rs @@ -53,7 +53,7 @@ //! [`GenDiskBuilder`]: gen_disk::GenDiskBuilder //! [`GenDiskBuilder::build`]: gen_disk::GenDiskBuilder::build //! -//! # Example +//! # Examples //! //! ```rust //! use kernel::{ diff --git a/rust/kernel/block/mq/gen_disk.rs b/rust/kernel/block/mq/gen_disk.rs index 798c4ae0bded..e1af0fa302a3 100644 --- a/rust/kernel/block/mq/gen_disk.rs +++ b/rust/kernel/block/mq/gen_disk.rs @@ -3,7 +3,7 @@ //! Generic disk abstraction. //! //! C header: [`include/linux/blkdev.h`](srctree/include/linux/blkdev.h) -//! C header: [`include/linux/blk_mq.h`](srctree/include/linux/blk_mq.h) +//! C header: [`include/linux/blk-mq.h`](srctree/include/linux/blk-mq.h) use crate::block::mq::{raw_writer::RawWriter, Operations, TagSet}; use crate::{bindings, error::from_err_ptr, error::Result, sync::Arc}; @@ -129,7 +129,7 @@ impl GenDiskBuilder { get_unique_id: None, // TODO: Set to THIS_MODULE. Waiting for const_refs_to_static feature to // be merged (unstable in rustc 1.78 which is staged for linux 6.10) - // https://github.com/rust-lang/rust/issues/119618 + // <https://github.com/rust-lang/rust/issues/119618> owner: core::ptr::null_mut(), pr_ops: core::ptr::null_mut(), free_disk: None, @@ -174,9 +174,9 @@ impl GenDiskBuilder { /// /// # Invariants /// -/// - `gendisk` must always point to an initialized and valid `struct gendisk`. -/// - `gendisk` was added to the VFS through a call to -/// `bindings::device_add_disk`. +/// - `gendisk` must always point to an initialized and valid `struct gendisk`. +/// - `gendisk` was added to the VFS through a call to +/// `bindings::device_add_disk`. pub struct GenDisk<T: Operations> { _tagset: Arc<TagSet<T>>, gendisk: *mut bindings::gendisk, diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs index c8646d0d9866..c2b98f507bcb 100644 --- a/rust/kernel/block/mq/operations.rs +++ b/rust/kernel/block/mq/operations.rs @@ -9,6 +9,7 @@ use crate::{ block::mq::request::RequestDataWrapper, block::mq::Request, error::{from_result, Result}, + prelude::*, types::ARef, }; use core::{marker::PhantomData, sync::atomic::AtomicU64, sync::atomic::Ordering}; @@ -35,7 +36,7 @@ pub trait Operations: Sized { /// Called by the kernel to poll the device for completed requests. Only /// used for poll queues. fn poll() -> bool { - crate::build_error(crate::error::VTABLE_DEFAULT_ERROR) + build_error!(crate::error::VTABLE_DEFAULT_ERROR) } } @@ -100,7 +101,7 @@ impl<T: Operations> OperationsVTable<T> { if let Err(e) = ret { e.to_blk_status() } else { - bindings::BLK_STS_OK as _ + bindings::BLK_STS_OK as bindings::blk_status_t } } diff --git a/rust/kernel/block/mq/request.rs b/rust/kernel/block/mq/request.rs index 7943f43b9575..fefd394f064a 100644 --- a/rust/kernel/block/mq/request.rs +++ b/rust/kernel/block/mq/request.rs @@ -12,7 +12,7 @@ use crate::{ }; use core::{ marker::PhantomData, - ptr::{addr_of_mut, NonNull}, + ptr::NonNull, sync::atomic::{AtomicU64, Ordering}, }; @@ -69,7 +69,7 @@ impl<T: Operations> Request<T> { // INVARIANT: By the safety requirements of this function, invariants are upheld. // SAFETY: By the safety requirement of this function, we own a // reference count that we can pass to `ARef`. - unsafe { ARef::from_raw(NonNull::new_unchecked(ptr as *const Self as *mut Self)) } + unsafe { ARef::from_raw(NonNull::new_unchecked(ptr.cast())) } } /// Notify the block layer that a request is going to be processed now. @@ -125,7 +125,12 @@ impl<T: Operations> Request<T> { // success of the call to `try_set_end` guarantees that there are no // `ARef`s pointing to this request. Therefore it is safe to hand it // back to the block layer. - unsafe { bindings::blk_mq_end_request(request_ptr, bindings::BLK_STS_OK as _) }; + unsafe { + bindings::blk_mq_end_request( + request_ptr, + bindings::BLK_STS_OK as bindings::blk_status_t, + ) + }; Ok(()) } @@ -155,7 +160,7 @@ impl<T: Operations> Request<T> { // the private data associated with this request is initialized and // valid. The existence of `&self` guarantees that the private data is // valid as a shared reference. - unsafe { Self::wrapper_ptr(self as *const Self as *mut Self).as_ref() } + unsafe { Self::wrapper_ptr(core::ptr::from_ref(self).cast_mut()).as_ref() } } } @@ -187,7 +192,7 @@ impl RequestDataWrapper { pub(crate) unsafe fn refcount_ptr(this: *mut Self) -> *mut AtomicU64 { // SAFETY: Because of the safety requirements of this function, the // field projection is safe. - unsafe { addr_of_mut!((*this).refcount) } + unsafe { &raw mut (*this).refcount } } } diff --git a/rust/kernel/block/mq/tag_set.rs b/rust/kernel/block/mq/tag_set.rs index d7f175a05d99..c3cf56d52bee 100644 --- a/rust/kernel/block/mq/tag_set.rs +++ b/rust/kernel/block/mq/tag_set.rs @@ -9,13 +9,12 @@ use core::pin::Pin; use crate::{ bindings, block::mq::{operations::OperationsVTable, request::RequestDataWrapper, Operations}, - error, - prelude::PinInit, - try_pin_init, + error::{self, Result}, + prelude::try_pin_init, types::Opaque, }; use core::{convert::TryInto, marker::PhantomData}; -use macros::{pin_data, pinned_drop}; +use pin_init::{pin_data, pinned_drop, PinInit}; /// A wrapper for the C `struct blk_mq_tag_set`. /// @@ -42,7 +41,7 @@ impl<T: Operations> TagSet<T> { // SAFETY: `blk_mq_tag_set` only contains integers and pointers, which // all are allowed to be 0. let tag_set: bindings::blk_mq_tag_set = unsafe { core::mem::zeroed() }; - let tag_set = core::mem::size_of::<RequestDataWrapper>() + let tag_set: Result<_> = core::mem::size_of::<RequestDataWrapper>() .try_into() .map(|cmd_size| { bindings::blk_mq_tag_set { @@ -52,17 +51,19 @@ impl<T: Operations> TagSet<T> { numa_node: bindings::NUMA_NO_NODE, queue_depth: num_tags, cmd_size, - flags: bindings::BLK_MQ_F_SHOULD_MERGE, + flags: 0, driver_data: core::ptr::null_mut::<crate::ffi::c_void>(), nr_maps: num_maps, ..tag_set } - }); + }) + .map(Opaque::new) + .map_err(|e| e.into()); try_pin_init!(TagSet { - inner <- PinInit::<_, error::Error>::pin_chain(Opaque::new(tag_set?), |tag_set| { + inner <- tag_set.pin_chain(|tag_set| { // SAFETY: we do not move out of `tag_set`. - let tag_set = unsafe { Pin::get_unchecked_mut(tag_set) }; + let tag_set: &mut Opaque<_> = unsafe { Pin::get_unchecked_mut(tag_set) }; // SAFETY: `tag_set` is a reference to an initialized `blk_mq_tag_set`. error::to_result( unsafe { bindings::blk_mq_alloc_tag_set(tag_set.get())}) }), diff --git a/rust/kernel/bug.rs b/rust/kernel/bug.rs new file mode 100644 index 000000000000..36aef43e5ebe --- /dev/null +++ b/rust/kernel/bug.rs @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024, 2025 FUJITA Tomonori <fujita.tomonori@gmail.com> + +//! Support for BUG and WARN functionality. +//! +//! C header: [`include/asm-generic/bug.h`](srctree/include/asm-generic/bug.h) + +#[macro_export] +#[doc(hidden)] +#[cfg(all(CONFIG_BUG, not(CONFIG_UML), not(CONFIG_LOONGARCH), not(CONFIG_ARM)))] +#[cfg(CONFIG_DEBUG_BUGVERBOSE)] +macro_rules! warn_flags { + ($flags:expr) => { + const FLAGS: u32 = $crate::bindings::BUGFLAG_WARNING | $flags; + const _FILE: &[u8] = file!().as_bytes(); + // Plus one for null-terminator. + static FILE: [u8; _FILE.len() + 1] = { + let mut bytes = [0; _FILE.len() + 1]; + let mut i = 0; + while i < _FILE.len() { + bytes[i] = _FILE[i]; + i += 1; + } + bytes + }; + + // SAFETY: + // - `file`, `line`, `flags`, and `size` are all compile-time constants or + // symbols, preventing any invalid memory access. + // - The asm block has no side effects and does not modify any registers + // or memory. It is purely for embedding metadata into the ELF section. + unsafe { + $crate::asm!( + concat!( + "/* {size} */", + include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_warn_asm.rs")), + include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_reachable_asm.rs"))); + file = sym FILE, + line = const line!(), + flags = const FLAGS, + size = const ::core::mem::size_of::<$crate::bindings::bug_entry>(), + ); + } + } +} + +#[macro_export] +#[doc(hidden)] +#[cfg(all(CONFIG_BUG, not(CONFIG_UML), not(CONFIG_LOONGARCH), not(CONFIG_ARM)))] +#[cfg(not(CONFIG_DEBUG_BUGVERBOSE))] +macro_rules! warn_flags { + ($flags:expr) => { + const FLAGS: u32 = $crate::bindings::BUGFLAG_WARNING | $flags; + + // SAFETY: + // - `flags` and `size` are all compile-time constants, preventing + // any invalid memory access. + // - The asm block has no side effects and does not modify any registers + // or memory. It is purely for embedding metadata into the ELF section. + unsafe { + $crate::asm!( + concat!( + "/* {size} */", + include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_warn_asm.rs")), + include!(concat!(env!("OBJTREE"), "/rust/kernel/generated_arch_reachable_asm.rs"))); + flags = const FLAGS, + size = const ::core::mem::size_of::<$crate::bindings::bug_entry>(), + ); + } + } +} + +#[macro_export] +#[doc(hidden)] +#[cfg(all(CONFIG_BUG, CONFIG_UML))] +macro_rules! warn_flags { + ($flags:expr) => { + // SAFETY: It is always safe to call `warn_slowpath_fmt()` + // with a valid null-terminated string. + unsafe { + $crate::bindings::warn_slowpath_fmt( + $crate::c_str!(::core::file!()).as_char_ptr(), + line!() as $crate::ffi::c_int, + $flags as $crate::ffi::c_uint, + ::core::ptr::null(), + ); + } + }; +} + +#[macro_export] +#[doc(hidden)] +#[cfg(all(CONFIG_BUG, any(CONFIG_LOONGARCH, CONFIG_ARM)))] +macro_rules! warn_flags { + ($flags:expr) => { + // SAFETY: It is always safe to call `WARN_ON()`. + unsafe { $crate::bindings::WARN_ON(true) } + }; +} + +#[macro_export] +#[doc(hidden)] +#[cfg(not(CONFIG_BUG))] +macro_rules! warn_flags { + ($flags:expr) => {}; +} + +#[doc(hidden)] +pub const fn bugflag_taint(value: u32) -> u32 { + value << 8 +} + +/// Report a warning if `cond` is true and return the condition's evaluation result. +#[macro_export] +macro_rules! warn_on { + ($cond:expr) => {{ + let cond = $cond; + if cond { + const WARN_ON_FLAGS: u32 = $crate::bug::bugflag_taint($crate::bindings::TAINT_WARN); + + $crate::warn_flags!(WARN_ON_FLAGS); + } + cond + }}; +} diff --git a/rust/kernel/build_assert.rs b/rust/kernel/build_assert.rs index 9e37120bc69c..6331b15d7c4d 100644 --- a/rust/kernel/build_assert.rs +++ b/rust/kernel/build_assert.rs @@ -2,6 +2,9 @@ //! Build-time assert. +#[doc(hidden)] +pub use build_error::build_error; + /// Fails the build if the code path calling `build_error!` can possibly be executed. /// /// If the macro is executed in const context, `build_error!` will panic. @@ -11,7 +14,6 @@ /// # Examples /// /// ``` -/// # use kernel::build_error; /// #[inline] /// fn foo(a: usize) -> usize { /// a.checked_add(1).unwrap_or_else(|| build_error!("overflow")) @@ -23,10 +25,10 @@ #[macro_export] macro_rules! build_error { () => {{ - $crate::build_error("") + $crate::build_assert::build_error("") }}; ($msg:expr) => {{ - $crate::build_error($msg) + $crate::build_assert::build_error($msg) }}; } @@ -73,12 +75,12 @@ macro_rules! build_error { macro_rules! build_assert { ($cond:expr $(,)?) => {{ if !$cond { - $crate::build_error(concat!("assertion failed: ", stringify!($cond))); + $crate::build_assert::build_error(concat!("assertion failed: ", stringify!($cond))); } }}; ($cond:expr, $msg:expr) => {{ if !$cond { - $crate::build_error($msg); + $crate::build_assert::build_error($msg); } }}; } diff --git a/rust/kernel/clk.rs b/rust/kernel/clk.rs new file mode 100644 index 000000000000..1e6c8c42fb3a --- /dev/null +++ b/rust/kernel/clk.rs @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Clock abstractions. +//! +//! C header: [`include/linux/clk.h`](srctree/include/linux/clk.h) +//! +//! Reference: <https://docs.kernel.org/driver-api/clk.html> + +use crate::ffi::c_ulong; + +/// The frequency unit. +/// +/// Represents a frequency in hertz, wrapping a [`c_ulong`] value. +/// +/// # Examples +/// +/// ``` +/// use kernel::clk::Hertz; +/// +/// let hz = 1_000_000_000; +/// let rate = Hertz(hz); +/// +/// assert_eq!(rate.as_hz(), hz); +/// assert_eq!(rate, Hertz(hz)); +/// assert_eq!(rate, Hertz::from_khz(hz / 1_000)); +/// assert_eq!(rate, Hertz::from_mhz(hz / 1_000_000)); +/// assert_eq!(rate, Hertz::from_ghz(hz / 1_000_000_000)); +/// ``` +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct Hertz(pub c_ulong); + +impl Hertz { + const KHZ_TO_HZ: c_ulong = 1_000; + const MHZ_TO_HZ: c_ulong = 1_000_000; + const GHZ_TO_HZ: c_ulong = 1_000_000_000; + + /// Create a new instance from kilohertz (kHz) + pub const fn from_khz(khz: c_ulong) -> Self { + Self(khz * Self::KHZ_TO_HZ) + } + + /// Create a new instance from megahertz (MHz) + pub const fn from_mhz(mhz: c_ulong) -> Self { + Self(mhz * Self::MHZ_TO_HZ) + } + + /// Create a new instance from gigahertz (GHz) + pub const fn from_ghz(ghz: c_ulong) -> Self { + Self(ghz * Self::GHZ_TO_HZ) + } + + /// Get the frequency in hertz + pub const fn as_hz(&self) -> c_ulong { + self.0 + } + + /// Get the frequency in kilohertz + pub const fn as_khz(&self) -> c_ulong { + self.0 / Self::KHZ_TO_HZ + } + + /// Get the frequency in megahertz + pub const fn as_mhz(&self) -> c_ulong { + self.0 / Self::MHZ_TO_HZ + } + + /// Get the frequency in gigahertz + pub const fn as_ghz(&self) -> c_ulong { + self.0 / Self::GHZ_TO_HZ + } +} + +impl From<Hertz> for c_ulong { + fn from(freq: Hertz) -> Self { + freq.0 + } +} + +#[cfg(CONFIG_COMMON_CLK)] +mod common_clk { + use super::Hertz; + use crate::{ + device::Device, + error::{from_err_ptr, to_result, Result}, + prelude::*, + }; + + use core::{ops::Deref, ptr}; + + /// A reference-counted clock. + /// + /// Rust abstraction for the C [`struct clk`]. + /// + /// # Invariants + /// + /// A [`Clk`] instance holds either a pointer to a valid [`struct clk`] created by the C + /// portion of the kernel or a NULL pointer. + /// + /// Instances of this type are reference-counted. Calling [`Clk::get`] ensures that the + /// allocation remains valid for the lifetime of the [`Clk`]. + /// + /// # Examples + /// + /// The following example demonstrates how to obtain and configure a clock for a device. + /// + /// ``` + /// use kernel::c_str; + /// use kernel::clk::{Clk, Hertz}; + /// use kernel::device::Device; + /// use kernel::error::Result; + /// + /// fn configure_clk(dev: &Device) -> Result { + /// let clk = Clk::get(dev, Some(c_str!("apb_clk")))?; + /// + /// clk.prepare_enable()?; + /// + /// let expected_rate = Hertz::from_ghz(1); + /// + /// if clk.rate() != expected_rate { + /// clk.set_rate(expected_rate)?; + /// } + /// + /// clk.disable_unprepare(); + /// Ok(()) + /// } + /// ``` + /// + /// [`struct clk`]: https://docs.kernel.org/driver-api/clk.html + #[repr(transparent)] + pub struct Clk(*mut bindings::clk); + + impl Clk { + /// Gets [`Clk`] corresponding to a [`Device`] and a connection id. + /// + /// Equivalent to the kernel's [`clk_get`] API. + /// + /// [`clk_get`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_get + pub fn get(dev: &Device, name: Option<&CStr>) -> Result<Self> { + let con_id = name.map_or(ptr::null(), |n| n.as_ptr()); + + // SAFETY: It is safe to call [`clk_get`] for a valid device pointer. + // + // INVARIANT: The reference-count is decremented when [`Clk`] goes out of scope. + Ok(Self(from_err_ptr(unsafe { + bindings::clk_get(dev.as_raw(), con_id) + })?)) + } + + /// Obtain the raw [`struct clk`] pointer. + #[inline] + pub fn as_raw(&self) -> *mut bindings::clk { + self.0 + } + + /// Enable the clock. + /// + /// Equivalent to the kernel's [`clk_enable`] API. + /// + /// [`clk_enable`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_enable + #[inline] + pub fn enable(&self) -> Result { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_enable`]. + to_result(unsafe { bindings::clk_enable(self.as_raw()) }) + } + + /// Disable the clock. + /// + /// Equivalent to the kernel's [`clk_disable`] API. + /// + /// [`clk_disable`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_disable + #[inline] + pub fn disable(&self) { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_disable`]. + unsafe { bindings::clk_disable(self.as_raw()) }; + } + + /// Prepare the clock. + /// + /// Equivalent to the kernel's [`clk_prepare`] API. + /// + /// [`clk_prepare`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_prepare + #[inline] + pub fn prepare(&self) -> Result { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_prepare`]. + to_result(unsafe { bindings::clk_prepare(self.as_raw()) }) + } + + /// Unprepare the clock. + /// + /// Equivalent to the kernel's [`clk_unprepare`] API. + /// + /// [`clk_unprepare`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_unprepare + #[inline] + pub fn unprepare(&self) { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_unprepare`]. + unsafe { bindings::clk_unprepare(self.as_raw()) }; + } + + /// Prepare and enable the clock. + /// + /// Equivalent to calling [`Clk::prepare`] followed by [`Clk::enable`]. + #[inline] + pub fn prepare_enable(&self) -> Result { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_prepare_enable`]. + to_result(unsafe { bindings::clk_prepare_enable(self.as_raw()) }) + } + + /// Disable and unprepare the clock. + /// + /// Equivalent to calling [`Clk::disable`] followed by [`Clk::unprepare`]. + #[inline] + pub fn disable_unprepare(&self) { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_disable_unprepare`]. + unsafe { bindings::clk_disable_unprepare(self.as_raw()) }; + } + + /// Get clock's rate. + /// + /// Equivalent to the kernel's [`clk_get_rate`] API. + /// + /// [`clk_get_rate`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_get_rate + #[inline] + pub fn rate(&self) -> Hertz { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_get_rate`]. + Hertz(unsafe { bindings::clk_get_rate(self.as_raw()) }) + } + + /// Set clock's rate. + /// + /// Equivalent to the kernel's [`clk_set_rate`] API. + /// + /// [`clk_set_rate`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_set_rate + #[inline] + pub fn set_rate(&self, rate: Hertz) -> Result { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for + // [`clk_set_rate`]. + to_result(unsafe { bindings::clk_set_rate(self.as_raw(), rate.as_hz()) }) + } + } + + impl Drop for Clk { + fn drop(&mut self) { + // SAFETY: By the type invariants, self.as_raw() is a valid argument for [`clk_put`]. + unsafe { bindings::clk_put(self.as_raw()) }; + } + } + + /// A reference-counted optional clock. + /// + /// A lightweight wrapper around an optional [`Clk`]. An [`OptionalClk`] represents a [`Clk`] + /// that a driver can function without but may improve performance or enable additional + /// features when available. + /// + /// # Invariants + /// + /// An [`OptionalClk`] instance encapsulates a [`Clk`] with either a valid [`struct clk`] or + /// `NULL` pointer. + /// + /// Instances of this type are reference-counted. Calling [`OptionalClk::get`] ensures that the + /// allocation remains valid for the lifetime of the [`OptionalClk`]. + /// + /// # Examples + /// + /// The following example demonstrates how to obtain and configure an optional clock for a + /// device. The code functions correctly whether or not the clock is available. + /// + /// ``` + /// use kernel::c_str; + /// use kernel::clk::{OptionalClk, Hertz}; + /// use kernel::device::Device; + /// use kernel::error::Result; + /// + /// fn configure_clk(dev: &Device) -> Result { + /// let clk = OptionalClk::get(dev, Some(c_str!("apb_clk")))?; + /// + /// clk.prepare_enable()?; + /// + /// let expected_rate = Hertz::from_ghz(1); + /// + /// if clk.rate() != expected_rate { + /// clk.set_rate(expected_rate)?; + /// } + /// + /// clk.disable_unprepare(); + /// Ok(()) + /// } + /// ``` + /// + /// [`struct clk`]: https://docs.kernel.org/driver-api/clk.html + pub struct OptionalClk(Clk); + + impl OptionalClk { + /// Gets [`OptionalClk`] corresponding to a [`Device`] and a connection id. + /// + /// Equivalent to the kernel's [`clk_get_optional`] API. + /// + /// [`clk_get_optional`]: + /// https://docs.kernel.org/core-api/kernel-api.html#c.clk_get_optional + pub fn get(dev: &Device, name: Option<&CStr>) -> Result<Self> { + let con_id = name.map_or(ptr::null(), |n| n.as_ptr()); + + // SAFETY: It is safe to call [`clk_get_optional`] for a valid device pointer. + // + // INVARIANT: The reference-count is decremented when [`OptionalClk`] goes out of + // scope. + Ok(Self(Clk(from_err_ptr(unsafe { + bindings::clk_get_optional(dev.as_raw(), con_id) + })?))) + } + } + + // Make [`OptionalClk`] behave like [`Clk`]. + impl Deref for OptionalClk { + type Target = Clk; + + fn deref(&self) -> &Clk { + &self.0 + } + } +} + +#[cfg(CONFIG_COMMON_CLK)] +pub use common_clk::*; diff --git a/rust/kernel/configfs.rs b/rust/kernel/configfs.rs new file mode 100644 index 000000000000..2736b798cdc6 --- /dev/null +++ b/rust/kernel/configfs.rs @@ -0,0 +1,1041 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! configfs interface: Userspace-driven Kernel Object Configuration +//! +//! configfs is an in-memory pseudo file system for configuration of kernel +//! modules. Please see the [C documentation] for details and intended use of +//! configfs. +//! +//! This module does not support the following configfs features: +//! +//! - Items. All group children are groups. +//! - Symlink support. +//! - `disconnect_notify` hook. +//! - Default groups. +//! +//! See the [`rust_configfs.rs`] sample for a full example use of this module. +//! +//! C header: [`include/linux/configfs.h`](srctree/include/linux/configfs.h) +//! +//! # Examples +//! +//! ```ignore +//! use kernel::alloc::flags; +//! use kernel::c_str; +//! use kernel::configfs_attrs; +//! use kernel::configfs; +//! use kernel::new_mutex; +//! use kernel::page::PAGE_SIZE; +//! use kernel::sync::Mutex; +//! use kernel::ThisModule; +//! +//! #[pin_data] +//! struct RustConfigfs { +//! #[pin] +//! config: configfs::Subsystem<Configuration>, +//! } +//! +//! impl kernel::InPlaceModule for RustConfigfs { +//! fn init(_module: &'static ThisModule) -> impl PinInit<Self, Error> { +//! pr_info!("Rust configfs sample (init)\n"); +//! +//! let item_type = configfs_attrs! { +//! container: configfs::Subsystem<Configuration>, +//! data: Configuration, +//! attributes: [ +//! message: 0, +//! bar: 1, +//! ], +//! }; +//! +//! try_pin_init!(Self { +//! config <- configfs::Subsystem::new( +//! c_str!("rust_configfs"), item_type, Configuration::new() +//! ), +//! }) +//! } +//! } +//! +//! #[pin_data] +//! struct Configuration { +//! message: &'static CStr, +//! #[pin] +//! bar: Mutex<(KBox<[u8; PAGE_SIZE]>, usize)>, +//! } +//! +//! impl Configuration { +//! fn new() -> impl PinInit<Self, Error> { +//! try_pin_init!(Self { +//! message: c_str!("Hello World\n"), +//! bar <- new_mutex!((KBox::new([0; PAGE_SIZE], flags::GFP_KERNEL)?, 0)), +//! }) +//! } +//! } +//! +//! #[vtable] +//! impl configfs::AttributeOperations<0> for Configuration { +//! type Data = Configuration; +//! +//! fn show(container: &Configuration, page: &mut [u8; PAGE_SIZE]) -> Result<usize> { +//! pr_info!("Show message\n"); +//! let data = container.message; +//! page[0..data.len()].copy_from_slice(data); +//! Ok(data.len()) +//! } +//! } +//! +//! #[vtable] +//! impl configfs::AttributeOperations<1> for Configuration { +//! type Data = Configuration; +//! +//! fn show(container: &Configuration, page: &mut [u8; PAGE_SIZE]) -> Result<usize> { +//! pr_info!("Show bar\n"); +//! let guard = container.bar.lock(); +//! let data = guard.0.as_slice(); +//! let len = guard.1; +//! page[0..len].copy_from_slice(&data[0..len]); +//! Ok(len) +//! } +//! +//! fn store(container: &Configuration, page: &[u8]) -> Result { +//! pr_info!("Store bar\n"); +//! let mut guard = container.bar.lock(); +//! guard.0[0..page.len()].copy_from_slice(page); +//! guard.1 = page.len(); +//! Ok(()) +//! } +//! } +//! ``` +//! +//! [C documentation]: srctree/Documentation/filesystems/configfs.rst +//! [`rust_configfs.rs`]: srctree/samples/rust/rust_configfs.rs + +use crate::alloc::flags; +use crate::container_of; +use crate::page::PAGE_SIZE; +use crate::prelude::*; +use crate::str::CString; +use crate::sync::Arc; +use crate::sync::ArcBorrow; +use crate::types::Opaque; +use core::cell::UnsafeCell; +use core::marker::PhantomData; + +/// A configfs subsystem. +/// +/// This is the top level entrypoint for a configfs hierarchy. To register +/// with configfs, embed a field of this type into your kernel module struct. +#[pin_data(PinnedDrop)] +pub struct Subsystem<Data> { + #[pin] + subsystem: Opaque<bindings::configfs_subsystem>, + #[pin] + data: Data, +} + +// SAFETY: We do not provide any operations on `Subsystem`. +unsafe impl<Data> Sync for Subsystem<Data> {} + +// SAFETY: Ownership of `Subsystem` can safely be transferred to other threads. +unsafe impl<Data> Send for Subsystem<Data> {} + +impl<Data> Subsystem<Data> { + /// Create an initializer for a [`Subsystem`]. + /// + /// The subsystem will appear in configfs as a directory name given by + /// `name`. The attributes available in directory are specified by + /// `item_type`. + pub fn new( + name: &'static CStr, + item_type: &'static ItemType<Subsystem<Data>, Data>, + data: impl PinInit<Data, Error>, + ) -> impl PinInit<Self, Error> { + try_pin_init!(Self { + subsystem <- pin_init::init_zeroed().chain( + |place: &mut Opaque<bindings::configfs_subsystem>| { + // SAFETY: We initialized the required fields of `place.group` above. + unsafe { + bindings::config_group_init_type_name( + &mut (*place.get()).su_group, + name.as_ptr(), + item_type.as_ptr(), + ) + }; + + // SAFETY: `place.su_mutex` is valid for use as a mutex. + unsafe { + bindings::__mutex_init( + &mut (*place.get()).su_mutex, + kernel::optional_name!().as_char_ptr(), + kernel::static_lock_class!().as_ptr(), + ) + } + Ok(()) + } + ), + data <- data, + }) + .pin_chain(|this| { + crate::error::to_result( + // SAFETY: We initialized `this.subsystem` according to C API contract above. + unsafe { bindings::configfs_register_subsystem(this.subsystem.get()) }, + ) + }) + } +} + +#[pinned_drop] +impl<Data> PinnedDrop for Subsystem<Data> { + fn drop(self: Pin<&mut Self>) { + // SAFETY: We registered `self.subsystem` in the initializer returned by `Self::new`. + unsafe { bindings::configfs_unregister_subsystem(self.subsystem.get()) }; + // SAFETY: We initialized the mutex in `Subsystem::new`. + unsafe { bindings::mutex_destroy(&raw mut (*self.subsystem.get()).su_mutex) }; + } +} + +/// Trait that allows offset calculations for structs that embed a +/// `bindings::config_group`. +/// +/// Users of the configfs API should not need to implement this trait. +/// +/// # Safety +/// +/// - Implementers of this trait must embed a `bindings::config_group`. +/// - Methods must be implemented according to method documentation. +pub unsafe trait HasGroup<Data> { + /// Return the address of the `bindings::config_group` embedded in [`Self`]. + /// + /// # Safety + /// + /// - `this` must be a valid allocation of at least the size of [`Self`]. + unsafe fn group(this: *const Self) -> *const bindings::config_group; + + /// Return the address of the [`Self`] that `group` is embedded in. + /// + /// # Safety + /// + /// - `group` must point to the `bindings::config_group` that is embedded in + /// [`Self`]. + unsafe fn container_of(group: *const bindings::config_group) -> *const Self; +} + +// SAFETY: `Subsystem<Data>` embeds a field of type `bindings::config_group` +// within the `subsystem` field. +unsafe impl<Data> HasGroup<Data> for Subsystem<Data> { + unsafe fn group(this: *const Self) -> *const bindings::config_group { + // SAFETY: By impl and function safety requirement this projection is in bounds. + unsafe { &raw const (*(*this).subsystem.get()).su_group } + } + + unsafe fn container_of(group: *const bindings::config_group) -> *const Self { + // SAFETY: By impl and function safety requirement this projection is in bounds. + let c_subsys_ptr = unsafe { container_of!(group, bindings::configfs_subsystem, su_group) }; + let opaque_ptr = c_subsys_ptr.cast::<Opaque<bindings::configfs_subsystem>>(); + // SAFETY: By impl and function safety requirement, `opaque_ptr` and the + // pointer it returns, are within the same allocation. + unsafe { container_of!(opaque_ptr, Subsystem<Data>, subsystem) } + } +} + +/// A configfs group. +/// +/// To add a subgroup to configfs, pass this type as `ctype` to +/// [`crate::configfs_attrs`] when creating a group in [`GroupOperations::make_group`]. +#[pin_data] +pub struct Group<Data> { + #[pin] + group: Opaque<bindings::config_group>, + #[pin] + data: Data, +} + +impl<Data> Group<Data> { + /// Create an initializer for a new group. + /// + /// When instantiated, the group will appear as a directory with the name + /// given by `name` and it will contain attributes specified by `item_type`. + pub fn new( + name: CString, + item_type: &'static ItemType<Group<Data>, Data>, + data: impl PinInit<Data, Error>, + ) -> impl PinInit<Self, Error> { + try_pin_init!(Self { + group <- pin_init::init_zeroed().chain(|v: &mut Opaque<bindings::config_group>| { + let place = v.get(); + let name = name.as_bytes_with_nul().as_ptr(); + // SAFETY: It is safe to initialize a group once it has been zeroed. + unsafe { + bindings::config_group_init_type_name(place, name.cast(), item_type.as_ptr()) + }; + Ok(()) + }), + data <- data, + }) + } +} + +// SAFETY: `Group<Data>` embeds a field of type `bindings::config_group` +// within the `group` field. +unsafe impl<Data> HasGroup<Data> for Group<Data> { + unsafe fn group(this: *const Self) -> *const bindings::config_group { + Opaque::cast_into( + // SAFETY: By impl and function safety requirements this field + // projection is within bounds of the allocation. + unsafe { &raw const (*this).group }, + ) + } + + unsafe fn container_of(group: *const bindings::config_group) -> *const Self { + let opaque_ptr = group.cast::<Opaque<bindings::config_group>>(); + // SAFETY: By impl and function safety requirement, `opaque_ptr` and + // pointer it returns will be in the same allocation. + unsafe { container_of!(opaque_ptr, Self, group) } + } +} + +/// # Safety +/// +/// `this` must be a valid pointer. +/// +/// If `this` does not represent the root group of a configfs subsystem, +/// `this` must be a pointer to a `bindings::config_group` embedded in a +/// `Group<Parent>`. +/// +/// Otherwise, `this` must be a pointer to a `bindings::config_group` that +/// is embedded in a `bindings::configfs_subsystem` that is embedded in a +/// `Subsystem<Parent>`. +unsafe fn get_group_data<'a, Parent>(this: *mut bindings::config_group) -> &'a Parent { + // SAFETY: `this` is a valid pointer. + let is_root = unsafe { (*this).cg_subsys.is_null() }; + + if !is_root { + // SAFETY: By C API contact,`this` was returned from a call to + // `make_group`. The pointer is known to be embedded within a + // `Group<Parent>`. + unsafe { &(*Group::<Parent>::container_of(this)).data } + } else { + // SAFETY: By C API contract, `this` is a pointer to the + // `bindings::config_group` field within a `Subsystem<Parent>`. + unsafe { &(*Subsystem::container_of(this)).data } + } +} + +struct GroupOperationsVTable<Parent, Child>(PhantomData<(Parent, Child)>); + +impl<Parent, Child> GroupOperationsVTable<Parent, Child> +where + Parent: GroupOperations<Child = Child>, + Child: 'static, +{ + /// # Safety + /// + /// `this` must be a valid pointer. + /// + /// If `this` does not represent the root group of a configfs subsystem, + /// `this` must be a pointer to a `bindings::config_group` embedded in a + /// `Group<Parent>`. + /// + /// Otherwise, `this` must be a pointer to a `bindings::config_group` that + /// is embedded in a `bindings::configfs_subsystem` that is embedded in a + /// `Subsystem<Parent>`. + /// + /// `name` must point to a null terminated string. + unsafe extern "C" fn make_group( + this: *mut bindings::config_group, + name: *const kernel::ffi::c_char, + ) -> *mut bindings::config_group { + // SAFETY: By function safety requirements of this function, this call + // is safe. + let parent_data = unsafe { get_group_data(this) }; + + let group_init = match Parent::make_group( + parent_data, + // SAFETY: By function safety requirements, name points to a null + // terminated string. + unsafe { CStr::from_char_ptr(name) }, + ) { + Ok(init) => init, + Err(e) => return e.to_ptr(), + }; + + let child_group = <Arc<Group<Child>> as InPlaceInit<Group<Child>>>::try_pin_init( + group_init, + flags::GFP_KERNEL, + ); + + match child_group { + Ok(child_group) => { + let child_group_ptr = child_group.into_raw(); + // SAFETY: We allocated the pointee of `child_ptr` above as a + // `Group<Child>`. + unsafe { Group::<Child>::group(child_group_ptr) }.cast_mut() + } + Err(e) => e.to_ptr(), + } + } + + /// # Safety + /// + /// If `this` does not represent the root group of a configfs subsystem, + /// `this` must be a pointer to a `bindings::config_group` embedded in a + /// `Group<Parent>`. + /// + /// Otherwise, `this` must be a pointer to a `bindings::config_group` that + /// is embedded in a `bindings::configfs_subsystem` that is embedded in a + /// `Subsystem<Parent>`. + /// + /// `item` must point to a `bindings::config_item` within a + /// `bindings::config_group` within a `Group<Child>`. + unsafe extern "C" fn drop_item( + this: *mut bindings::config_group, + item: *mut bindings::config_item, + ) { + // SAFETY: By function safety requirements of this function, this call + // is safe. + let parent_data = unsafe { get_group_data(this) }; + + // SAFETY: By function safety requirements, `item` is embedded in a + // `config_group`. + let c_child_group_ptr = unsafe { container_of!(item, bindings::config_group, cg_item) }; + // SAFETY: By function safety requirements, `c_child_group_ptr` is + // embedded within a `Group<Child>`. + let r_child_group_ptr = unsafe { Group::<Child>::container_of(c_child_group_ptr) }; + + if Parent::HAS_DROP_ITEM { + // SAFETY: We called `into_raw` to produce `r_child_group_ptr` in + // `make_group`. + let arc: Arc<Group<Child>> = unsafe { Arc::from_raw(r_child_group_ptr.cast_mut()) }; + + Parent::drop_item(parent_data, arc.as_arc_borrow()); + arc.into_raw(); + } + + // SAFETY: By C API contract, we are required to drop a refcount on + // `item`. + unsafe { bindings::config_item_put(item) }; + } + + const VTABLE: bindings::configfs_group_operations = bindings::configfs_group_operations { + make_item: None, + make_group: Some(Self::make_group), + disconnect_notify: None, + drop_item: Some(Self::drop_item), + is_visible: None, + is_bin_visible: None, + }; + + const fn vtable_ptr() -> *const bindings::configfs_group_operations { + &Self::VTABLE + } +} + +struct ItemOperationsVTable<Container, Data>(PhantomData<(Container, Data)>); + +impl<Data> ItemOperationsVTable<Group<Data>, Data> +where + Data: 'static, +{ + /// # Safety + /// + /// `this` must be a pointer to a `bindings::config_group` embedded in a + /// `Group<Parent>`. + /// + /// This function will destroy the pointee of `this`. The pointee of `this` + /// must not be accessed after the function returns. + unsafe extern "C" fn release(this: *mut bindings::config_item) { + // SAFETY: By function safety requirements, `this` is embedded in a + // `config_group`. + let c_group_ptr = unsafe { kernel::container_of!(this, bindings::config_group, cg_item) }; + // SAFETY: By function safety requirements, `c_group_ptr` is + // embedded within a `Group<Data>`. + let r_group_ptr = unsafe { Group::<Data>::container_of(c_group_ptr) }; + + // SAFETY: We called `into_raw` on `r_group_ptr` in + // `make_group`. + let pin_self: Arc<Group<Data>> = unsafe { Arc::from_raw(r_group_ptr.cast_mut()) }; + drop(pin_self); + } + + const VTABLE: bindings::configfs_item_operations = bindings::configfs_item_operations { + release: Some(Self::release), + allow_link: None, + drop_link: None, + }; + + const fn vtable_ptr() -> *const bindings::configfs_item_operations { + &Self::VTABLE + } +} + +impl<Data> ItemOperationsVTable<Subsystem<Data>, Data> { + const VTABLE: bindings::configfs_item_operations = bindings::configfs_item_operations { + release: None, + allow_link: None, + drop_link: None, + }; + + const fn vtable_ptr() -> *const bindings::configfs_item_operations { + &Self::VTABLE + } +} + +/// Operations implemented by configfs groups that can create subgroups. +/// +/// Implement this trait on structs that embed a [`Subsystem`] or a [`Group`]. +#[vtable] +pub trait GroupOperations { + /// The child data object type. + /// + /// This group will create subgroups (subdirectories) backed by this kind of + /// object. + type Child: 'static; + + /// Creates a new subgroup. + /// + /// The kernel will call this method in response to `mkdir(2)` in the + /// directory representing `this`. + /// + /// To accept the request to create a group, implementations should + /// return an initializer of a `Group<Self::Child>`. To prevent creation, + /// return a suitable error. + fn make_group(&self, name: &CStr) -> Result<impl PinInit<Group<Self::Child>, Error>>; + + /// Prepares the group for removal from configfs. + /// + /// The kernel will call this method before the directory representing `_child` is removed from + /// configfs. + /// + /// Implementations can use this method to do house keeping before configfs drops its + /// reference to `Child`. + /// + /// NOTE: "drop" in the name of this function is not related to the Rust drop term. Rather, the + /// name is inherited from the callback name in the underlying C code. + fn drop_item(&self, _child: ArcBorrow<'_, Group<Self::Child>>) { + kernel::build_error!(kernel::error::VTABLE_DEFAULT_ERROR) + } +} + +/// A configfs attribute. +/// +/// An attribute appears as a file in configfs, inside a folder that represent +/// the group that the attribute belongs to. +#[repr(transparent)] +pub struct Attribute<const ID: u64, O, Data> { + attribute: Opaque<bindings::configfs_attribute>, + _p: PhantomData<(O, Data)>, +} + +// SAFETY: We do not provide any operations on `Attribute`. +unsafe impl<const ID: u64, O, Data> Sync for Attribute<ID, O, Data> {} + +// SAFETY: Ownership of `Attribute` can safely be transferred to other threads. +unsafe impl<const ID: u64, O, Data> Send for Attribute<ID, O, Data> {} + +impl<const ID: u64, O, Data> Attribute<ID, O, Data> +where + O: AttributeOperations<ID, Data = Data>, +{ + /// # Safety + /// + /// `item` must be embedded in a `bindings::config_group`. + /// + /// If `item` does not represent the root group of a configfs subsystem, + /// the group must be embedded in a `Group<Data>`. + /// + /// Otherwise, the group must be a embedded in a + /// `bindings::configfs_subsystem` that is embedded in a `Subsystem<Data>`. + /// + /// `page` must point to a writable buffer of size at least [`PAGE_SIZE`]. + unsafe extern "C" fn show( + item: *mut bindings::config_item, + page: *mut kernel::ffi::c_char, + ) -> isize { + let c_group: *mut bindings::config_group = + // SAFETY: By function safety requirements, `item` is embedded in a + // `config_group`. + unsafe { container_of!(item, bindings::config_group, cg_item) }; + + // SAFETY: The function safety requirements for this function satisfy + // the conditions for this call. + let data: &Data = unsafe { get_group_data(c_group) }; + + // SAFETY: By function safety requirements, `page` is writable for `PAGE_SIZE`. + let ret = O::show(data, unsafe { &mut *(page.cast::<[u8; PAGE_SIZE]>()) }); + + match ret { + Ok(size) => size as isize, + Err(err) => err.to_errno() as isize, + } + } + + /// # Safety + /// + /// `item` must be embedded in a `bindings::config_group`. + /// + /// If `item` does not represent the root group of a configfs subsystem, + /// the group must be embedded in a `Group<Data>`. + /// + /// Otherwise, the group must be a embedded in a + /// `bindings::configfs_subsystem` that is embedded in a `Subsystem<Data>`. + /// + /// `page` must point to a readable buffer of size at least `size`. + unsafe extern "C" fn store( + item: *mut bindings::config_item, + page: *const kernel::ffi::c_char, + size: usize, + ) -> isize { + let c_group: *mut bindings::config_group = + // SAFETY: By function safety requirements, `item` is embedded in a + // `config_group`. + unsafe { container_of!(item, bindings::config_group, cg_item) }; + + // SAFETY: The function safety requirements for this function satisfy + // the conditions for this call. + let data: &Data = unsafe { get_group_data(c_group) }; + + let ret = O::store( + data, + // SAFETY: By function safety requirements, `page` is readable + // for at least `size`. + unsafe { core::slice::from_raw_parts(page.cast(), size) }, + ); + + match ret { + Ok(()) => size as isize, + Err(err) => err.to_errno() as isize, + } + } + + /// Create a new attribute. + /// + /// The attribute will appear as a file with name given by `name`. + pub const fn new(name: &'static CStr) -> Self { + Self { + attribute: Opaque::new(bindings::configfs_attribute { + ca_name: name.as_char_ptr(), + ca_owner: core::ptr::null_mut(), + ca_mode: 0o660, + show: Some(Self::show), + store: if O::HAS_STORE { + Some(Self::store) + } else { + None + }, + }), + _p: PhantomData, + } + } +} + +/// Operations supported by an attribute. +/// +/// Implement this trait on type and pass that type as generic parameter when +/// creating an [`Attribute`]. The type carrying the implementation serve no +/// purpose other than specifying the attribute operations. +/// +/// This trait must be implemented on the `Data` type of for types that +/// implement `HasGroup<Data>`. The trait must be implemented once for each +/// attribute of the group. The constant type parameter `ID` maps the +/// implementation to a specific `Attribute`. `ID` must be passed when declaring +/// attributes via the [`kernel::configfs_attrs`] macro, to tie +/// `AttributeOperations` implementations to concrete named attributes. +#[vtable] +pub trait AttributeOperations<const ID: u64 = 0> { + /// The type of the object that contains the field that is backing the + /// attribute for this operation. + type Data; + + /// Renders the value of an attribute. + /// + /// This function is called by the kernel to read the value of an attribute. + /// + /// Implementations should write the rendering of the attribute to `page` + /// and return the number of bytes written. + fn show(data: &Self::Data, page: &mut [u8; PAGE_SIZE]) -> Result<usize>; + + /// Stores the value of an attribute. + /// + /// This function is called by the kernel to update the value of an attribute. + /// + /// Implementations should parse the value from `page` and update internal + /// state to reflect the parsed value. + fn store(_data: &Self::Data, _page: &[u8]) -> Result { + kernel::build_error!(kernel::error::VTABLE_DEFAULT_ERROR) + } +} + +/// A list of attributes. +/// +/// This type is used to construct a new [`ItemType`]. It represents a list of +/// [`Attribute`] that will appear in the directory representing a [`Group`]. +/// Users should not directly instantiate this type, rather they should use the +/// [`kernel::configfs_attrs`] macro to declare a static set of attributes for a +/// group. +/// +/// # Note +/// +/// Instances of this type are constructed statically at compile by the +/// [`kernel::configfs_attrs`] macro. +#[repr(transparent)] +pub struct AttributeList<const N: usize, Data>( + /// Null terminated Array of pointers to [`Attribute`]. The type is [`c_void`] + /// to conform to the C API. + UnsafeCell<[*mut kernel::ffi::c_void; N]>, + PhantomData<Data>, +); + +// SAFETY: Ownership of `AttributeList` can safely be transferred to other threads. +unsafe impl<const N: usize, Data> Send for AttributeList<N, Data> {} + +// SAFETY: We do not provide any operations on `AttributeList` that need synchronization. +unsafe impl<const N: usize, Data> Sync for AttributeList<N, Data> {} + +impl<const N: usize, Data> AttributeList<N, Data> { + /// # Safety + /// + /// This function must only be called by the [`kernel::configfs_attrs`] + /// macro. + #[doc(hidden)] + pub const unsafe fn new() -> Self { + Self(UnsafeCell::new([core::ptr::null_mut(); N]), PhantomData) + } + + /// # Safety + /// + /// The caller must ensure that there are no other concurrent accesses to + /// `self`. That is, the caller has exclusive access to `self.` + #[doc(hidden)] + pub const unsafe fn add<const I: usize, const ID: u64, O>( + &'static self, + attribute: &'static Attribute<ID, O, Data>, + ) where + O: AttributeOperations<ID, Data = Data>, + { + // We need a space at the end of our list for a null terminator. + const { assert!(I < N - 1, "Invalid attribute index") }; + + // SAFETY: By function safety requirements, we have exclusive access to + // `self` and the reference created below will be exclusive. + unsafe { (&mut *self.0.get())[I] = core::ptr::from_ref(attribute).cast_mut().cast() }; + } +} + +/// A representation of the attributes that will appear in a [`Group`] or +/// [`Subsystem`]. +/// +/// Users should not directly instantiate objects of this type. Rather, they +/// should use the [`kernel::configfs_attrs`] macro to statically declare the +/// shape of a [`Group`] or [`Subsystem`]. +#[pin_data] +pub struct ItemType<Container, Data> { + #[pin] + item_type: Opaque<bindings::config_item_type>, + _p: PhantomData<(Container, Data)>, +} + +// SAFETY: We do not provide any operations on `ItemType` that need synchronization. +unsafe impl<Container, Data> Sync for ItemType<Container, Data> {} + +// SAFETY: Ownership of `ItemType` can safely be transferred to other threads. +unsafe impl<Container, Data> Send for ItemType<Container, Data> {} + +macro_rules! impl_item_type { + ($tpe:ty) => { + impl<Data> ItemType<$tpe, Data> { + #[doc(hidden)] + pub const fn new_with_child_ctor<const N: usize, Child>( + owner: &'static ThisModule, + attributes: &'static AttributeList<N, Data>, + ) -> Self + where + Data: GroupOperations<Child = Child>, + Child: 'static, + { + Self { + item_type: Opaque::new(bindings::config_item_type { + ct_owner: owner.as_ptr(), + ct_group_ops: GroupOperationsVTable::<Data, Child>::vtable_ptr().cast_mut(), + ct_item_ops: ItemOperationsVTable::<$tpe, Data>::vtable_ptr().cast_mut(), + ct_attrs: core::ptr::from_ref(attributes).cast_mut().cast(), + ct_bin_attrs: core::ptr::null_mut(), + }), + _p: PhantomData, + } + } + + #[doc(hidden)] + pub const fn new<const N: usize>( + owner: &'static ThisModule, + attributes: &'static AttributeList<N, Data>, + ) -> Self { + Self { + item_type: Opaque::new(bindings::config_item_type { + ct_owner: owner.as_ptr(), + ct_group_ops: core::ptr::null_mut(), + ct_item_ops: ItemOperationsVTable::<$tpe, Data>::vtable_ptr().cast_mut(), + ct_attrs: core::ptr::from_ref(attributes).cast_mut().cast(), + ct_bin_attrs: core::ptr::null_mut(), + }), + _p: PhantomData, + } + } + } + }; +} + +impl_item_type!(Subsystem<Data>); +impl_item_type!(Group<Data>); + +impl<Container, Data> ItemType<Container, Data> { + fn as_ptr(&self) -> *const bindings::config_item_type { + self.item_type.get() + } +} + +/// Define a list of configfs attributes statically. +/// +/// Invoking the macro in the following manner: +/// +/// ```ignore +/// let item_type = configfs_attrs! { +/// container: configfs::Subsystem<Configuration>, +/// data: Configuration, +/// child: Child, +/// attributes: [ +/// message: 0, +/// bar: 1, +/// ], +/// }; +/// ``` +/// +/// Expands the following output: +/// +/// ```ignore +/// let item_type = { +/// static CONFIGURATION_MESSAGE_ATTR: kernel::configfs::Attribute< +/// 0, +/// Configuration, +/// Configuration, +/// > = unsafe { +/// kernel::configfs::Attribute::new({ +/// const S: &str = "message\u{0}"; +/// const C: &kernel::str::CStr = match kernel::str::CStr::from_bytes_with_nul( +/// S.as_bytes() +/// ) { +/// Ok(v) => v, +/// Err(_) => { +/// core::panicking::panic_fmt(core::const_format_args!( +/// "string contains interior NUL" +/// )); +/// } +/// }; +/// C +/// }) +/// }; +/// +/// static CONFIGURATION_BAR_ATTR: kernel::configfs::Attribute< +/// 1, +/// Configuration, +/// Configuration +/// > = unsafe { +/// kernel::configfs::Attribute::new({ +/// const S: &str = "bar\u{0}"; +/// const C: &kernel::str::CStr = match kernel::str::CStr::from_bytes_with_nul( +/// S.as_bytes() +/// ) { +/// Ok(v) => v, +/// Err(_) => { +/// core::panicking::panic_fmt(core::const_format_args!( +/// "string contains interior NUL" +/// )); +/// } +/// }; +/// C +/// }) +/// }; +/// +/// const N: usize = (1usize + (1usize + 0usize)) + 1usize; +/// +/// static CONFIGURATION_ATTRS: kernel::configfs::AttributeList<N, Configuration> = +/// unsafe { kernel::configfs::AttributeList::new() }; +/// +/// { +/// const N: usize = 0usize; +/// unsafe { CONFIGURATION_ATTRS.add::<N, 0, _>(&CONFIGURATION_MESSAGE_ATTR) }; +/// } +/// +/// { +/// const N: usize = (1usize + 0usize); +/// unsafe { CONFIGURATION_ATTRS.add::<N, 1, _>(&CONFIGURATION_BAR_ATTR) }; +/// } +/// +/// static CONFIGURATION_TPE: +/// kernel::configfs::ItemType<configfs::Subsystem<Configuration> ,Configuration> +/// = kernel::configfs::ItemType::< +/// configfs::Subsystem<Configuration>, +/// Configuration +/// >::new_with_child_ctor::<N,Child>( +/// &THIS_MODULE, +/// &CONFIGURATION_ATTRS +/// ); +/// +/// &CONFIGURATION_TPE +/// } +/// ``` +#[macro_export] +macro_rules! configfs_attrs { + ( + container: $container:ty, + data: $data:ty, + attributes: [ + $($name:ident: $attr:literal),* $(,)? + ] $(,)? + ) => { + $crate::configfs_attrs!( + count: + @container($container), + @data($data), + @child(), + @no_child(x), + @attrs($($name $attr)*), + @eat($($name $attr,)*), + @assign(), + @cnt(0usize), + ) + }; + ( + container: $container:ty, + data: $data:ty, + child: $child:ty, + attributes: [ + $($name:ident: $attr:literal),* $(,)? + ] $(,)? + ) => { + $crate::configfs_attrs!( + count: + @container($container), + @data($data), + @child($child), + @no_child(), + @attrs($($name $attr)*), + @eat($($name $attr,)*), + @assign(), + @cnt(0usize), + ) + }; + (count: + @container($container:ty), + @data($data:ty), + @child($($child:ty)?), + @no_child($($no_child:ident)?), + @attrs($($aname:ident $aattr:literal)*), + @eat($name:ident $attr:literal, $($rname:ident $rattr:literal,)*), + @assign($($assign:block)*), + @cnt($cnt:expr), + ) => { + $crate::configfs_attrs!( + count: + @container($container), + @data($data), + @child($($child)?), + @no_child($($no_child)?), + @attrs($($aname $aattr)*), + @eat($($rname $rattr,)*), + @assign($($assign)* { + const N: usize = $cnt; + // The following macro text expands to a call to `Attribute::add`. + + // SAFETY: By design of this macro, the name of the variable we + // invoke the `add` method on below, is not visible outside of + // the macro expansion. The macro does not operate concurrently + // on this variable, and thus we have exclusive access to the + // variable. + unsafe { + $crate::macros::paste!( + [< $data:upper _ATTRS >] + .add::<N, $attr, _>(&[< $data:upper _ $name:upper _ATTR >]) + ) + }; + }), + @cnt(1usize + $cnt), + ) + }; + (count: + @container($container:ty), + @data($data:ty), + @child($($child:ty)?), + @no_child($($no_child:ident)?), + @attrs($($aname:ident $aattr:literal)*), + @eat(), + @assign($($assign:block)*), + @cnt($cnt:expr), + ) => + { + $crate::configfs_attrs!( + final: + @container($container), + @data($data), + @child($($child)?), + @no_child($($no_child)?), + @attrs($($aname $aattr)*), + @assign($($assign)*), + @cnt($cnt), + ) + }; + (final: + @container($container:ty), + @data($data:ty), + @child($($child:ty)?), + @no_child($($no_child:ident)?), + @attrs($($name:ident $attr:literal)*), + @assign($($assign:block)*), + @cnt($cnt:expr), + ) => + { + $crate::macros::paste!{ + { + $( + // SAFETY: We are expanding `configfs_attrs`. + static [< $data:upper _ $name:upper _ATTR >]: + $crate::configfs::Attribute<$attr, $data, $data> = + unsafe { + $crate::configfs::Attribute::new(c_str!(::core::stringify!($name))) + }; + )* + + + // We need space for a null terminator. + const N: usize = $cnt + 1usize; + + // SAFETY: We are expanding `configfs_attrs`. + static [< $data:upper _ATTRS >]: + $crate::configfs::AttributeList<N, $data> = + unsafe { $crate::configfs::AttributeList::new() }; + + $($assign)* + + $( + const [<$no_child:upper>]: bool = true; + + static [< $data:upper _TPE >] : $crate::configfs::ItemType<$container, $data> = + $crate::configfs::ItemType::<$container, $data>::new::<N>( + &THIS_MODULE, &[<$ data:upper _ATTRS >] + ); + )? + + $( + static [< $data:upper _TPE >]: + $crate::configfs::ItemType<$container, $data> = + $crate::configfs::ItemType::<$container, $data>:: + new_with_child_ctor::<N, $child>( + &THIS_MODULE, &[<$ data:upper _ATTRS >] + ); + )? + + & [< $data:upper _TPE >] + } + } + }; + +} diff --git a/rust/kernel/cpu.rs b/rust/kernel/cpu.rs new file mode 100644 index 000000000000..5de730c8d817 --- /dev/null +++ b/rust/kernel/cpu.rs @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic CPU definitions. +//! +//! C header: [`include/linux/cpu.h`](srctree/include/linux/cpu.h) + +use crate::{bindings, device::Device, error::Result, prelude::ENODEV}; + +/// Returns the maximum number of possible CPUs in the current system configuration. +#[inline] +pub fn nr_cpu_ids() -> u32 { + #[cfg(any(NR_CPUS_1, CONFIG_FORCE_NR_CPUS))] + { + bindings::NR_CPUS + } + + #[cfg(not(any(NR_CPUS_1, CONFIG_FORCE_NR_CPUS)))] + // SAFETY: `nr_cpu_ids` is a valid global provided by the kernel. + unsafe { + bindings::nr_cpu_ids + } +} + +/// The CPU ID. +/// +/// Represents a CPU identifier as a wrapper around an [`u32`]. +/// +/// # Invariants +/// +/// The CPU ID lies within the range `[0, nr_cpu_ids())`. +/// +/// # Examples +/// +/// ``` +/// use kernel::cpu::CpuId; +/// +/// let cpu = 0; +/// +/// // SAFETY: 0 is always a valid CPU number. +/// let id = unsafe { CpuId::from_u32_unchecked(cpu) }; +/// +/// assert_eq!(id.as_u32(), cpu); +/// assert!(CpuId::from_i32(0).is_some()); +/// assert!(CpuId::from_i32(-1).is_none()); +/// ``` +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct CpuId(u32); + +impl CpuId { + /// Creates a new [`CpuId`] from the given `id` without checking bounds. + /// + /// # Safety + /// + /// The caller must ensure that `id` is a valid CPU ID (i.e., `0 <= id < nr_cpu_ids()`). + #[inline] + pub unsafe fn from_i32_unchecked(id: i32) -> Self { + debug_assert!(id >= 0); + debug_assert!((id as u32) < nr_cpu_ids()); + + // INVARIANT: The function safety guarantees `id` is a valid CPU id. + Self(id as u32) + } + + /// Creates a new [`CpuId`] from the given `id`, checking that it is valid. + pub fn from_i32(id: i32) -> Option<Self> { + if id < 0 || id as u32 >= nr_cpu_ids() { + None + } else { + // INVARIANT: `id` has just been checked as a valid CPU ID. + Some(Self(id as u32)) + } + } + + /// Creates a new [`CpuId`] from the given `id` without checking bounds. + /// + /// # Safety + /// + /// The caller must ensure that `id` is a valid CPU ID (i.e., `0 <= id < nr_cpu_ids()`). + #[inline] + pub unsafe fn from_u32_unchecked(id: u32) -> Self { + debug_assert!(id < nr_cpu_ids()); + + // Ensure the `id` fits in an [`i32`] as it's also representable that way. + debug_assert!(id <= i32::MAX as u32); + + // INVARIANT: The function safety guarantees `id` is a valid CPU id. + Self(id) + } + + /// Creates a new [`CpuId`] from the given `id`, checking that it is valid. + pub fn from_u32(id: u32) -> Option<Self> { + if id >= nr_cpu_ids() { + None + } else { + // INVARIANT: `id` has just been checked as a valid CPU ID. + Some(Self(id)) + } + } + + /// Returns CPU number. + #[inline] + pub fn as_u32(&self) -> u32 { + self.0 + } + + /// Returns the ID of the CPU the code is currently running on. + /// + /// The returned value is considered unstable because it may change + /// unexpectedly due to preemption or CPU migration. It should only be + /// used when the context ensures that the task remains on the same CPU + /// or the users could use a stale (yet valid) CPU ID. + pub fn current() -> Self { + // SAFETY: raw_smp_processor_id() always returns a valid CPU ID. + unsafe { Self::from_u32_unchecked(bindings::raw_smp_processor_id()) } + } +} + +impl From<CpuId> for u32 { + fn from(id: CpuId) -> Self { + id.as_u32() + } +} + +impl From<CpuId> for i32 { + fn from(id: CpuId) -> Self { + id.as_u32() as i32 + } +} + +/// Creates a new instance of CPU's device. +/// +/// # Safety +/// +/// Reference counting is not implemented for the CPU device in the C code. When a CPU is +/// hot-unplugged, the corresponding CPU device is unregistered, but its associated memory +/// is not freed. +/// +/// Callers must ensure that the CPU device is not used after it has been unregistered. +/// This can be achieved, for example, by registering a CPU hotplug notifier and removing +/// any references to the CPU device within the notifier's callback. +pub unsafe fn from_cpu(cpu: CpuId) -> Result<&'static Device> { + // SAFETY: It is safe to call `get_cpu_device()` for any CPU. + let ptr = unsafe { bindings::get_cpu_device(u32::from(cpu)) }; + if ptr.is_null() { + return Err(ENODEV); + } + + // SAFETY: The pointer returned by `get_cpu_device()`, if not `NULL`, is a valid pointer to + // a `struct device` and is never freed by the C code. + Ok(unsafe { Device::from_raw(ptr) }) +} diff --git a/rust/kernel/cpufreq.rs b/rust/kernel/cpufreq.rs new file mode 100644 index 000000000000..afc15e72a7c3 --- /dev/null +++ b/rust/kernel/cpufreq.rs @@ -0,0 +1,1399 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! CPU frequency scaling. +//! +//! This module provides rust abstractions for interacting with the cpufreq subsystem. +//! +//! C header: [`include/linux/cpufreq.h`](srctree/include/linux/cpufreq.h) +//! +//! Reference: <https://docs.kernel.org/admin-guide/pm/cpufreq.html> + +use crate::{ + clk::Hertz, + cpu::CpuId, + cpumask, + device::{Bound, Device}, + devres, + error::{code::*, from_err_ptr, from_result, to_result, Result, VTABLE_DEFAULT_ERROR}, + ffi::{c_char, c_ulong}, + prelude::*, + types::ForeignOwnable, + types::Opaque, +}; + +#[cfg(CONFIG_COMMON_CLK)] +use crate::clk::Clk; + +use core::{ + cell::UnsafeCell, + marker::PhantomData, + mem::MaybeUninit, + ops::{Deref, DerefMut}, + pin::Pin, + ptr, +}; + +use macros::vtable; + +/// Maximum length of CPU frequency driver's name. +const CPUFREQ_NAME_LEN: usize = bindings::CPUFREQ_NAME_LEN as usize; + +/// Default transition latency value in nanoseconds. +pub const ETERNAL_LATENCY_NS: u32 = bindings::CPUFREQ_ETERNAL as u32; + +/// CPU frequency driver flags. +pub mod flags { + /// Driver needs to update internal limits even if frequency remains unchanged. + pub const NEED_UPDATE_LIMITS: u16 = 1 << 0; + + /// Platform where constants like `loops_per_jiffy` are unaffected by frequency changes. + pub const CONST_LOOPS: u16 = 1 << 1; + + /// Register driver as a thermal cooling device automatically. + pub const IS_COOLING_DEV: u16 = 1 << 2; + + /// Supports multiple clock domains with per-policy governors in `cpu/cpuN/cpufreq/`. + pub const HAVE_GOVERNOR_PER_POLICY: u16 = 1 << 3; + + /// Allows post-change notifications outside of the `target()` routine. + pub const ASYNC_NOTIFICATION: u16 = 1 << 4; + + /// Ensure CPU starts at a valid frequency from the driver's freq-table. + pub const NEED_INITIAL_FREQ_CHECK: u16 = 1 << 5; + + /// Disallow governors with `dynamic_switching` capability. + pub const NO_AUTO_DYNAMIC_SWITCHING: u16 = 1 << 6; +} + +/// Relations from the C code. +const CPUFREQ_RELATION_L: u32 = 0; +const CPUFREQ_RELATION_H: u32 = 1; +const CPUFREQ_RELATION_C: u32 = 2; + +/// Can be used with any of the above values. +const CPUFREQ_RELATION_E: u32 = 1 << 2; + +/// CPU frequency selection relations. +/// +/// CPU frequency selection relations, each optionally marked as "efficient". +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum Relation { + /// Select the lowest frequency at or above target. + Low(bool), + /// Select the highest frequency below or at target. + High(bool), + /// Select the closest frequency to the target. + Close(bool), +} + +impl Relation { + // Construct from a C-compatible `u32` value. + fn new(val: u32) -> Result<Self> { + let efficient = val & CPUFREQ_RELATION_E != 0; + + Ok(match val & !CPUFREQ_RELATION_E { + CPUFREQ_RELATION_L => Self::Low(efficient), + CPUFREQ_RELATION_H => Self::High(efficient), + CPUFREQ_RELATION_C => Self::Close(efficient), + _ => return Err(EINVAL), + }) + } +} + +impl From<Relation> for u32 { + // Convert to a C-compatible `u32` value. + fn from(rel: Relation) -> Self { + let (mut val, efficient) = match rel { + Relation::Low(e) => (CPUFREQ_RELATION_L, e), + Relation::High(e) => (CPUFREQ_RELATION_H, e), + Relation::Close(e) => (CPUFREQ_RELATION_C, e), + }; + + if efficient { + val |= CPUFREQ_RELATION_E; + } + + val + } +} + +/// Policy data. +/// +/// Rust abstraction for the C `struct cpufreq_policy_data`. +/// +/// # Invariants +/// +/// A [`PolicyData`] instance always corresponds to a valid C `struct cpufreq_policy_data`. +/// +/// The callers must ensure that the `struct cpufreq_policy_data` is valid for access and remains +/// valid for the lifetime of the returned reference. +#[repr(transparent)] +pub struct PolicyData(Opaque<bindings::cpufreq_policy_data>); + +impl PolicyData { + /// Creates a mutable reference to an existing `struct cpufreq_policy_data` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime + /// of the returned reference. + #[inline] + pub unsafe fn from_raw_mut<'a>(ptr: *mut bindings::cpufreq_policy_data) -> &'a mut Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the + // lifetime of the returned reference. + unsafe { &mut *ptr.cast() } + } + + /// Returns a raw pointer to the underlying C `cpufreq_policy_data`. + #[inline] + pub fn as_raw(&self) -> *mut bindings::cpufreq_policy_data { + let this: *const Self = self; + this.cast_mut().cast() + } + + /// Wrapper for `cpufreq_generic_frequency_table_verify`. + #[inline] + pub fn generic_verify(&self) -> Result { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + to_result(unsafe { bindings::cpufreq_generic_frequency_table_verify(self.as_raw()) }) + } +} + +/// The frequency table index. +/// +/// Represents index with a frequency table. +/// +/// # Invariants +/// +/// The index must correspond to a valid entry in the [`Table`] it is used for. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct TableIndex(usize); + +impl TableIndex { + /// Creates an instance of [`TableIndex`]. + /// + /// # Safety + /// + /// The caller must ensure that `index` correspond to a valid entry in the [`Table`] it is used + /// for. + pub unsafe fn new(index: usize) -> Self { + // INVARIANT: The caller ensures that `index` correspond to a valid entry in the [`Table`]. + Self(index) + } +} + +impl From<TableIndex> for usize { + #[inline] + fn from(index: TableIndex) -> Self { + index.0 + } +} + +/// CPU frequency table. +/// +/// Rust abstraction for the C `struct cpufreq_frequency_table`. +/// +/// # Invariants +/// +/// A [`Table`] instance always corresponds to a valid C `struct cpufreq_frequency_table`. +/// +/// The callers must ensure that the `struct cpufreq_frequency_table` is valid for access and +/// remains valid for the lifetime of the returned reference. +/// +/// # Examples +/// +/// The following example demonstrates how to read a frequency value from [`Table`]. +/// +/// ``` +/// use kernel::cpufreq::{Policy, TableIndex}; +/// +/// fn show_freq(policy: &Policy) -> Result { +/// let table = policy.freq_table()?; +/// +/// // SAFETY: Index is a valid entry in the table. +/// let index = unsafe { TableIndex::new(0) }; +/// +/// pr_info!("The frequency at index 0 is: {:?}\n", table.freq(index)?); +/// pr_info!("The flags at index 0 is: {}\n", table.flags(index)); +/// pr_info!("The data at index 0 is: {}\n", table.data(index)); +/// Ok(()) +/// } +/// ``` +#[repr(transparent)] +pub struct Table(Opaque<bindings::cpufreq_frequency_table>); + +impl Table { + /// Creates a reference to an existing C `struct cpufreq_frequency_table` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime + /// of the returned reference. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *const bindings::cpufreq_frequency_table) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the + // lifetime of the returned reference. + unsafe { &*ptr.cast() } + } + + /// Returns the raw mutable pointer to the C `struct cpufreq_frequency_table`. + #[inline] + pub fn as_raw(&self) -> *mut bindings::cpufreq_frequency_table { + let this: *const Self = self; + this.cast_mut().cast() + } + + /// Returns frequency at `index` in the [`Table`]. + #[inline] + pub fn freq(&self, index: TableIndex) -> Result<Hertz> { + // SAFETY: By the type invariant, the pointer stored in `self` is valid and `index` is + // guaranteed to be valid by its safety requirements. + Ok(Hertz::from_khz(unsafe { + (*self.as_raw().add(index.into())).frequency.try_into()? + })) + } + + /// Returns flags at `index` in the [`Table`]. + #[inline] + pub fn flags(&self, index: TableIndex) -> u32 { + // SAFETY: By the type invariant, the pointer stored in `self` is valid and `index` is + // guaranteed to be valid by its safety requirements. + unsafe { (*self.as_raw().add(index.into())).flags } + } + + /// Returns data at `index` in the [`Table`]. + #[inline] + pub fn data(&self, index: TableIndex) -> u32 { + // SAFETY: By the type invariant, the pointer stored in `self` is valid and `index` is + // guaranteed to be valid by its safety requirements. + unsafe { (*self.as_raw().add(index.into())).driver_data } + } +} + +/// CPU frequency table owned and pinned in memory, created from a [`TableBuilder`]. +pub struct TableBox { + entries: Pin<KVec<bindings::cpufreq_frequency_table>>, +} + +impl TableBox { + /// Constructs a new [`TableBox`] from a [`KVec`] of entries. + /// + /// # Errors + /// + /// Returns `EINVAL` if the entries list is empty. + #[inline] + fn new(entries: KVec<bindings::cpufreq_frequency_table>) -> Result<Self> { + if entries.is_empty() { + return Err(EINVAL); + } + + Ok(Self { + // Pin the entries to memory, since we are passing its pointer to the C code. + entries: Pin::new(entries), + }) + } + + /// Returns a raw pointer to the underlying C `cpufreq_frequency_table`. + #[inline] + fn as_raw(&self) -> *const bindings::cpufreq_frequency_table { + // The pointer is valid until the table gets dropped. + self.entries.as_ptr() + } +} + +impl Deref for TableBox { + type Target = Table; + + fn deref(&self) -> &Self::Target { + // SAFETY: The caller owns TableBox, it is safe to deref. + unsafe { Self::Target::from_raw(self.as_raw()) } + } +} + +/// CPU frequency table builder. +/// +/// This is used by the CPU frequency drivers to build a frequency table dynamically. +/// +/// # Examples +/// +/// The following example demonstrates how to create a CPU frequency table. +/// +/// ``` +/// use kernel::cpufreq::{TableBuilder, TableIndex}; +/// use kernel::clk::Hertz; +/// +/// let mut builder = TableBuilder::new(); +/// +/// // Adds few entries to the table. +/// builder.add(Hertz::from_mhz(700), 0, 1).unwrap(); +/// builder.add(Hertz::from_mhz(800), 2, 3).unwrap(); +/// builder.add(Hertz::from_mhz(900), 4, 5).unwrap(); +/// builder.add(Hertz::from_ghz(1), 6, 7).unwrap(); +/// +/// let table = builder.to_table().unwrap(); +/// +/// // SAFETY: Index values correspond to valid entries in the table. +/// let (index0, index2) = unsafe { (TableIndex::new(0), TableIndex::new(2)) }; +/// +/// assert_eq!(table.freq(index0), Ok(Hertz::from_mhz(700))); +/// assert_eq!(table.flags(index0), 0); +/// assert_eq!(table.data(index0), 1); +/// +/// assert_eq!(table.freq(index2), Ok(Hertz::from_mhz(900))); +/// assert_eq!(table.flags(index2), 4); +/// assert_eq!(table.data(index2), 5); +/// ``` +#[derive(Default)] +#[repr(transparent)] +pub struct TableBuilder { + entries: KVec<bindings::cpufreq_frequency_table>, +} + +impl TableBuilder { + /// Creates a new instance of [`TableBuilder`]. + #[inline] + pub fn new() -> Self { + Self { + entries: KVec::new(), + } + } + + /// Adds a new entry to the table. + pub fn add(&mut self, freq: Hertz, flags: u32, driver_data: u32) -> Result { + // Adds the new entry at the end of the vector. + Ok(self.entries.push( + bindings::cpufreq_frequency_table { + flags, + driver_data, + frequency: freq.as_khz() as u32, + }, + GFP_KERNEL, + )?) + } + + /// Consumes the [`TableBuilder`] and returns [`TableBox`]. + pub fn to_table(mut self) -> Result<TableBox> { + // Add last entry to the table. + self.add(Hertz(c_ulong::MAX), 0, 0)?; + + TableBox::new(self.entries) + } +} + +/// CPU frequency policy. +/// +/// Rust abstraction for the C `struct cpufreq_policy`. +/// +/// # Invariants +/// +/// A [`Policy`] instance always corresponds to a valid C `struct cpufreq_policy`. +/// +/// The callers must ensure that the `struct cpufreq_policy` is valid for access and remains valid +/// for the lifetime of the returned reference. +/// +/// # Examples +/// +/// The following example demonstrates how to create a CPU frequency table. +/// +/// ``` +/// use kernel::cpufreq::{ETERNAL_LATENCY_NS, Policy}; +/// +/// fn update_policy(policy: &mut Policy) { +/// policy +/// .set_dvfs_possible_from_any_cpu(true) +/// .set_fast_switch_possible(true) +/// .set_transition_latency_ns(ETERNAL_LATENCY_NS); +/// +/// pr_info!("The policy details are: {:?}\n", (policy.cpu(), policy.cur())); +/// } +/// ``` +#[repr(transparent)] +pub struct Policy(Opaque<bindings::cpufreq_policy>); + +impl Policy { + /// Creates a reference to an existing `struct cpufreq_policy` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime + /// of the returned reference. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *const bindings::cpufreq_policy) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the + // lifetime of the returned reference. + unsafe { &*ptr.cast() } + } + + /// Creates a mutable reference to an existing `struct cpufreq_policy` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime + /// of the returned reference. + #[inline] + pub unsafe fn from_raw_mut<'a>(ptr: *mut bindings::cpufreq_policy) -> &'a mut Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the + // lifetime of the returned reference. + unsafe { &mut *ptr.cast() } + } + + /// Returns a raw mutable pointer to the C `struct cpufreq_policy`. + #[inline] + fn as_raw(&self) -> *mut bindings::cpufreq_policy { + let this: *const Self = self; + this.cast_mut().cast() + } + + #[inline] + fn as_ref(&self) -> &bindings::cpufreq_policy { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + unsafe { &*self.as_raw() } + } + + #[inline] + fn as_mut_ref(&mut self) -> &mut bindings::cpufreq_policy { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + unsafe { &mut *self.as_raw() } + } + + /// Returns the primary CPU for the [`Policy`]. + #[inline] + pub fn cpu(&self) -> CpuId { + // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. + unsafe { CpuId::from_u32_unchecked(self.as_ref().cpu) } + } + + /// Returns the minimum frequency for the [`Policy`]. + #[inline] + pub fn min(&self) -> Hertz { + Hertz::from_khz(self.as_ref().min as usize) + } + + /// Set the minimum frequency for the [`Policy`]. + #[inline] + pub fn set_min(&mut self, min: Hertz) -> &mut Self { + self.as_mut_ref().min = min.as_khz() as u32; + self + } + + /// Returns the maximum frequency for the [`Policy`]. + #[inline] + pub fn max(&self) -> Hertz { + Hertz::from_khz(self.as_ref().max as usize) + } + + /// Set the maximum frequency for the [`Policy`]. + #[inline] + pub fn set_max(&mut self, max: Hertz) -> &mut Self { + self.as_mut_ref().max = max.as_khz() as u32; + self + } + + /// Returns the current frequency for the [`Policy`]. + #[inline] + pub fn cur(&self) -> Hertz { + Hertz::from_khz(self.as_ref().cur as usize) + } + + /// Returns the suspend frequency for the [`Policy`]. + #[inline] + pub fn suspend_freq(&self) -> Hertz { + Hertz::from_khz(self.as_ref().suspend_freq as usize) + } + + /// Sets the suspend frequency for the [`Policy`]. + #[inline] + pub fn set_suspend_freq(&mut self, freq: Hertz) -> &mut Self { + self.as_mut_ref().suspend_freq = freq.as_khz() as u32; + self + } + + /// Provides a wrapper to the generic suspend routine. + #[inline] + pub fn generic_suspend(&mut self) -> Result { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + to_result(unsafe { bindings::cpufreq_generic_suspend(self.as_mut_ref()) }) + } + + /// Provides a wrapper to the generic get routine. + #[inline] + pub fn generic_get(&self) -> Result<u32> { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + Ok(unsafe { bindings::cpufreq_generic_get(u32::from(self.cpu())) }) + } + + /// Provides a wrapper to the register with energy model using the OPP core. + #[cfg(CONFIG_PM_OPP)] + #[inline] + pub fn register_em_opp(&mut self) { + // SAFETY: By the type invariant, the pointer stored in `self` is valid. + unsafe { bindings::cpufreq_register_em_with_opp(self.as_mut_ref()) }; + } + + /// Gets [`cpumask::Cpumask`] for a cpufreq [`Policy`]. + #[inline] + pub fn cpus(&mut self) -> &mut cpumask::Cpumask { + // SAFETY: The pointer to `cpus` is valid for writing and remains valid for the lifetime of + // the returned reference. + unsafe { cpumask::CpumaskVar::as_mut_ref(&mut self.as_mut_ref().cpus) } + } + + /// Sets clock for the [`Policy`]. + /// + /// # Safety + /// + /// The caller must guarantee that the returned [`Clk`] is not dropped while it is getting used + /// by the C code. + #[cfg(CONFIG_COMMON_CLK)] + pub unsafe fn set_clk(&mut self, dev: &Device, name: Option<&CStr>) -> Result<Clk> { + let clk = Clk::get(dev, name)?; + self.as_mut_ref().clk = clk.as_raw(); + Ok(clk) + } + + /// Allows / disallows frequency switching code to run on any CPU. + #[inline] + pub fn set_dvfs_possible_from_any_cpu(&mut self, val: bool) -> &mut Self { + self.as_mut_ref().dvfs_possible_from_any_cpu = val; + self + } + + /// Returns if fast switching of frequencies is possible or not. + #[inline] + pub fn fast_switch_possible(&self) -> bool { + self.as_ref().fast_switch_possible + } + + /// Enables / disables fast frequency switching. + #[inline] + pub fn set_fast_switch_possible(&mut self, val: bool) -> &mut Self { + self.as_mut_ref().fast_switch_possible = val; + self + } + + /// Sets transition latency (in nanoseconds) for the [`Policy`]. + #[inline] + pub fn set_transition_latency_ns(&mut self, latency_ns: u32) -> &mut Self { + self.as_mut_ref().cpuinfo.transition_latency = latency_ns; + self + } + + /// Sets cpuinfo `min_freq`. + #[inline] + pub fn set_cpuinfo_min_freq(&mut self, min_freq: Hertz) -> &mut Self { + self.as_mut_ref().cpuinfo.min_freq = min_freq.as_khz() as u32; + self + } + + /// Sets cpuinfo `max_freq`. + #[inline] + pub fn set_cpuinfo_max_freq(&mut self, max_freq: Hertz) -> &mut Self { + self.as_mut_ref().cpuinfo.max_freq = max_freq.as_khz() as u32; + self + } + + /// Set `transition_delay_us`, i.e. the minimum time between successive frequency change + /// requests. + #[inline] + pub fn set_transition_delay_us(&mut self, transition_delay_us: u32) -> &mut Self { + self.as_mut_ref().transition_delay_us = transition_delay_us; + self + } + + /// Returns reference to the CPU frequency [`Table`] for the [`Policy`]. + pub fn freq_table(&self) -> Result<&Table> { + if self.as_ref().freq_table.is_null() { + return Err(EINVAL); + } + + // SAFETY: The `freq_table` is guaranteed to be valid for reading and remains valid for the + // lifetime of the returned reference. + Ok(unsafe { Table::from_raw(self.as_ref().freq_table) }) + } + + /// Sets the CPU frequency [`Table`] for the [`Policy`]. + /// + /// # Safety + /// + /// The caller must guarantee that the [`Table`] is not dropped while it is getting used by the + /// C code. + #[inline] + pub unsafe fn set_freq_table(&mut self, table: &Table) -> &mut Self { + self.as_mut_ref().freq_table = table.as_raw(); + self + } + + /// Returns the [`Policy`]'s private data. + pub fn data<T: ForeignOwnable>(&mut self) -> Option<<T>::Borrowed<'_>> { + if self.as_ref().driver_data.is_null() { + None + } else { + // SAFETY: The data is earlier set from [`set_data`]. + Some(unsafe { T::borrow(self.as_ref().driver_data.cast()) }) + } + } + + /// Sets the private data of the [`Policy`] using a foreign-ownable wrapper. + /// + /// # Errors + /// + /// Returns `EBUSY` if private data is already set. + fn set_data<T: ForeignOwnable>(&mut self, data: T) -> Result { + if self.as_ref().driver_data.is_null() { + // Transfer the ownership of the data to the foreign interface. + self.as_mut_ref().driver_data = <T as ForeignOwnable>::into_foreign(data).cast(); + Ok(()) + } else { + Err(EBUSY) + } + } + + /// Clears and returns ownership of the private data. + fn clear_data<T: ForeignOwnable>(&mut self) -> Option<T> { + if self.as_ref().driver_data.is_null() { + None + } else { + let data = Some( + // SAFETY: The data is earlier set by us from [`set_data`]. It is safe to take + // back the ownership of the data from the foreign interface. + unsafe { <T as ForeignOwnable>::from_foreign(self.as_ref().driver_data.cast()) }, + ); + self.as_mut_ref().driver_data = ptr::null_mut(); + data + } + } +} + +/// CPU frequency policy created from a CPU number. +/// +/// This struct represents the CPU frequency policy obtained for a specific CPU, providing safe +/// access to the underlying `cpufreq_policy` and ensuring proper cleanup when the `PolicyCpu` is +/// dropped. +struct PolicyCpu<'a>(&'a mut Policy); + +impl<'a> PolicyCpu<'a> { + fn from_cpu(cpu: CpuId) -> Result<Self> { + // SAFETY: It is safe to call `cpufreq_cpu_get` for any valid CPU. + let ptr = from_err_ptr(unsafe { bindings::cpufreq_cpu_get(u32::from(cpu)) })?; + + Ok(Self( + // SAFETY: The `ptr` is guaranteed to be valid and remains valid for the lifetime of + // the returned reference. + unsafe { Policy::from_raw_mut(ptr) }, + )) + } +} + +impl<'a> Deref for PolicyCpu<'a> { + type Target = Policy; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a> DerefMut for PolicyCpu<'a> { + fn deref_mut(&mut self) -> &mut Policy { + self.0 + } +} + +impl<'a> Drop for PolicyCpu<'a> { + fn drop(&mut self) { + // SAFETY: The underlying pointer is guaranteed to be valid for the lifetime of `self`. + unsafe { bindings::cpufreq_cpu_put(self.0.as_raw()) }; + } +} + +/// CPU frequency driver. +/// +/// Implement this trait to provide a CPU frequency driver and its callbacks. +/// +/// Reference: <https://docs.kernel.org/cpu-freq/cpu-drivers.html> +#[vtable] +pub trait Driver { + /// Driver's name. + const NAME: &'static CStr; + + /// Driver's flags. + const FLAGS: u16; + + /// Boost support. + const BOOST_ENABLED: bool; + + /// Policy specific data. + /// + /// Require that `PData` implements `ForeignOwnable`. We guarantee to never move the underlying + /// wrapped data structure. + type PData: ForeignOwnable; + + /// Driver's `init` callback. + fn init(policy: &mut Policy) -> Result<Self::PData>; + + /// Driver's `exit` callback. + fn exit(_policy: &mut Policy, _data: Option<Self::PData>) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `online` callback. + fn online(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `offline` callback. + fn offline(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `suspend` callback. + fn suspend(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `resume` callback. + fn resume(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `ready` callback. + fn ready(_policy: &mut Policy) { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `verify` callback. + fn verify(data: &mut PolicyData) -> Result; + + /// Driver's `setpolicy` callback. + fn setpolicy(_policy: &mut Policy) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `target` callback. + fn target(_policy: &mut Policy, _target_freq: u32, _relation: Relation) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `target_index` callback. + fn target_index(_policy: &mut Policy, _index: TableIndex) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `fast_switch` callback. + fn fast_switch(_policy: &mut Policy, _target_freq: u32) -> u32 { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `adjust_perf` callback. + fn adjust_perf(_policy: &mut Policy, _min_perf: usize, _target_perf: usize, _capacity: usize) { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `get_intermediate` callback. + fn get_intermediate(_policy: &mut Policy, _index: TableIndex) -> u32 { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `target_intermediate` callback. + fn target_intermediate(_policy: &mut Policy, _index: TableIndex) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `get` callback. + fn get(_policy: &mut Policy) -> Result<u32> { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `update_limits` callback. + fn update_limits(_policy: &mut Policy) { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `bios_limit` callback. + fn bios_limit(_policy: &mut Policy, _limit: &mut u32) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `set_boost` callback. + fn set_boost(_policy: &mut Policy, _state: i32) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Driver's `register_em` callback. + fn register_em(_policy: &mut Policy) { + build_error!(VTABLE_DEFAULT_ERROR) + } +} + +/// CPU frequency driver Registration. +/// +/// # Examples +/// +/// The following example demonstrates how to register a cpufreq driver. +/// +/// ``` +/// use kernel::{ +/// cpufreq, +/// c_str, +/// device::{Core, Device}, +/// macros::vtable, +/// of, platform, +/// sync::Arc, +/// }; +/// struct SampleDevice; +/// +/// #[derive(Default)] +/// struct SampleDriver; +/// +/// #[vtable] +/// impl cpufreq::Driver for SampleDriver { +/// const NAME: &'static CStr = c_str!("cpufreq-sample"); +/// const FLAGS: u16 = cpufreq::flags::NEED_INITIAL_FREQ_CHECK | cpufreq::flags::IS_COOLING_DEV; +/// const BOOST_ENABLED: bool = true; +/// +/// type PData = Arc<SampleDevice>; +/// +/// fn init(policy: &mut cpufreq::Policy) -> Result<Self::PData> { +/// // Initialize here +/// Ok(Arc::new(SampleDevice, GFP_KERNEL)?) +/// } +/// +/// fn exit(_policy: &mut cpufreq::Policy, _data: Option<Self::PData>) -> Result { +/// Ok(()) +/// } +/// +/// fn suspend(policy: &mut cpufreq::Policy) -> Result { +/// policy.generic_suspend() +/// } +/// +/// fn verify(data: &mut cpufreq::PolicyData) -> Result { +/// data.generic_verify() +/// } +/// +/// fn target_index(policy: &mut cpufreq::Policy, index: cpufreq::TableIndex) -> Result { +/// // Update CPU frequency +/// Ok(()) +/// } +/// +/// fn get(policy: &mut cpufreq::Policy) -> Result<u32> { +/// policy.generic_get() +/// } +/// } +/// +/// impl platform::Driver for SampleDriver { +/// type IdInfo = (); +/// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; +/// +/// fn probe( +/// pdev: &platform::Device<Core>, +/// _id_info: Option<&Self::IdInfo>, +/// ) -> Result<Pin<KBox<Self>>> { +/// cpufreq::Registration::<SampleDriver>::new_foreign_owned(pdev.as_ref())?; +/// Ok(KBox::new(Self {}, GFP_KERNEL)?.into()) +/// } +/// } +/// ``` +#[repr(transparent)] +pub struct Registration<T: Driver>(KBox<UnsafeCell<bindings::cpufreq_driver>>, PhantomData<T>); + +/// SAFETY: `Registration` doesn't offer any methods or access to fields when shared between threads +/// or CPUs, so it is safe to share it. +unsafe impl<T: Driver> Sync for Registration<T> {} + +#[allow(clippy::non_send_fields_in_send_ty)] +/// SAFETY: Registration with and unregistration from the cpufreq subsystem can happen from any +/// thread. +unsafe impl<T: Driver> Send for Registration<T> {} + +impl<T: Driver> Registration<T> { + const VTABLE: bindings::cpufreq_driver = bindings::cpufreq_driver { + name: Self::copy_name(T::NAME), + boost_enabled: T::BOOST_ENABLED, + flags: T::FLAGS, + + // Initialize mandatory callbacks. + init: Some(Self::init_callback), + verify: Some(Self::verify_callback), + + // Initialize optional callbacks based on the traits of `T`. + setpolicy: if T::HAS_SETPOLICY { + Some(Self::setpolicy_callback) + } else { + None + }, + target: if T::HAS_TARGET { + Some(Self::target_callback) + } else { + None + }, + target_index: if T::HAS_TARGET_INDEX { + Some(Self::target_index_callback) + } else { + None + }, + fast_switch: if T::HAS_FAST_SWITCH { + Some(Self::fast_switch_callback) + } else { + None + }, + adjust_perf: if T::HAS_ADJUST_PERF { + Some(Self::adjust_perf_callback) + } else { + None + }, + get_intermediate: if T::HAS_GET_INTERMEDIATE { + Some(Self::get_intermediate_callback) + } else { + None + }, + target_intermediate: if T::HAS_TARGET_INTERMEDIATE { + Some(Self::target_intermediate_callback) + } else { + None + }, + get: if T::HAS_GET { + Some(Self::get_callback) + } else { + None + }, + update_limits: if T::HAS_UPDATE_LIMITS { + Some(Self::update_limits_callback) + } else { + None + }, + bios_limit: if T::HAS_BIOS_LIMIT { + Some(Self::bios_limit_callback) + } else { + None + }, + online: if T::HAS_ONLINE { + Some(Self::online_callback) + } else { + None + }, + offline: if T::HAS_OFFLINE { + Some(Self::offline_callback) + } else { + None + }, + exit: if T::HAS_EXIT { + Some(Self::exit_callback) + } else { + None + }, + suspend: if T::HAS_SUSPEND { + Some(Self::suspend_callback) + } else { + None + }, + resume: if T::HAS_RESUME { + Some(Self::resume_callback) + } else { + None + }, + ready: if T::HAS_READY { + Some(Self::ready_callback) + } else { + None + }, + set_boost: if T::HAS_SET_BOOST { + Some(Self::set_boost_callback) + } else { + None + }, + register_em: if T::HAS_REGISTER_EM { + Some(Self::register_em_callback) + } else { + None + }, + // SAFETY: All zeros is a valid value for `bindings::cpufreq_driver`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }; + + const fn copy_name(name: &'static CStr) -> [c_char; CPUFREQ_NAME_LEN] { + let src = name.as_bytes_with_nul(); + let mut dst = [0; CPUFREQ_NAME_LEN]; + + build_assert!(src.len() <= CPUFREQ_NAME_LEN); + + let mut i = 0; + while i < src.len() { + dst[i] = src[i]; + i += 1; + } + + dst + } + + /// Registers a CPU frequency driver with the cpufreq core. + pub fn new() -> Result<Self> { + // We can't use `&Self::VTABLE` directly because the cpufreq core modifies some fields in + // the C `struct cpufreq_driver`, which requires a mutable reference. + let mut drv = KBox::new(UnsafeCell::new(Self::VTABLE), GFP_KERNEL)?; + + // SAFETY: `drv` is guaranteed to be valid for the lifetime of `Registration`. + to_result(unsafe { bindings::cpufreq_register_driver(drv.get_mut()) })?; + + Ok(Self(drv, PhantomData)) + } + + /// Same as [`Registration::new`], but does not return a [`Registration`] instance. + /// + /// Instead the [`Registration`] is owned by [`devres::register`] and will be dropped, once the + /// device is detached. + pub fn new_foreign_owned(dev: &Device<Bound>) -> Result + where + T: 'static, + { + devres::register(dev, Self::new()?, GFP_KERNEL) + } +} + +/// CPU frequency driver callbacks. +impl<T: Driver> Registration<T> { + /// Driver's `init` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn init_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + let data = T::init(policy)?; + policy.set_data(data)?; + Ok(0) + }) + } + + /// Driver's `exit` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn exit_callback(ptr: *mut bindings::cpufreq_policy) { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + let data = policy.clear_data(); + let _ = T::exit(policy, data); + } + + /// Driver's `online` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn online_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::online(policy).map(|()| 0) + }) + } + + /// Driver's `offline` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn offline_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::offline(policy).map(|()| 0) + }) + } + + /// Driver's `suspend` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn suspend_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::suspend(policy).map(|()| 0) + }) + } + + /// Driver's `resume` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn resume_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::resume(policy).map(|()| 0) + }) + } + + /// Driver's `ready` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn ready_callback(ptr: *mut bindings::cpufreq_policy) { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::ready(policy); + } + + /// Driver's `verify` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn verify_callback(ptr: *mut bindings::cpufreq_policy_data) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let data = unsafe { PolicyData::from_raw_mut(ptr) }; + T::verify(data).map(|()| 0) + }) + } + + /// Driver's `setpolicy` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn setpolicy_callback(ptr: *mut bindings::cpufreq_policy) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::setpolicy(policy).map(|()| 0) + }) + } + + /// Driver's `target` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn target_callback( + ptr: *mut bindings::cpufreq_policy, + target_freq: c_uint, + relation: c_uint, + ) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::target(policy, target_freq, Relation::new(relation)?).map(|()| 0) + }) + } + + /// Driver's `target_index` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn target_index_callback( + ptr: *mut bindings::cpufreq_policy, + index: c_uint, + ) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + // SAFETY: The C code guarantees that `index` corresponds to a valid entry in the + // frequency table. + let index = unsafe { TableIndex::new(index as usize) }; + + T::target_index(policy, index).map(|()| 0) + }) + } + + /// Driver's `fast_switch` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn fast_switch_callback( + ptr: *mut bindings::cpufreq_policy, + target_freq: c_uint, + ) -> c_uint { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::fast_switch(policy, target_freq) + } + + /// Driver's `adjust_perf` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + unsafe extern "C" fn adjust_perf_callback( + cpu: c_uint, + min_perf: c_ulong, + target_perf: c_ulong, + capacity: c_ulong, + ) { + // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. + let cpu_id = unsafe { CpuId::from_u32_unchecked(cpu) }; + + if let Ok(mut policy) = PolicyCpu::from_cpu(cpu_id) { + T::adjust_perf(&mut policy, min_perf, target_perf, capacity); + } + } + + /// Driver's `get_intermediate` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn get_intermediate_callback( + ptr: *mut bindings::cpufreq_policy, + index: c_uint, + ) -> c_uint { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + // SAFETY: The C code guarantees that `index` corresponds to a valid entry in the + // frequency table. + let index = unsafe { TableIndex::new(index as usize) }; + + T::get_intermediate(policy, index) + } + + /// Driver's `target_intermediate` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn target_intermediate_callback( + ptr: *mut bindings::cpufreq_policy, + index: c_uint, + ) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + + // SAFETY: The C code guarantees that `index` corresponds to a valid entry in the + // frequency table. + let index = unsafe { TableIndex::new(index as usize) }; + + T::target_intermediate(policy, index).map(|()| 0) + }) + } + + /// Driver's `get` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + unsafe extern "C" fn get_callback(cpu: c_uint) -> c_uint { + // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. + let cpu_id = unsafe { CpuId::from_u32_unchecked(cpu) }; + + PolicyCpu::from_cpu(cpu_id).map_or(0, |mut policy| T::get(&mut policy).map_or(0, |f| f)) + } + + /// Driver's `update_limit` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn update_limits_callback(ptr: *mut bindings::cpufreq_policy) { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::update_limits(policy); + } + + /// Driver's `bios_limit` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn bios_limit_callback(cpu: c_int, limit: *mut c_uint) -> c_int { + // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. + let cpu_id = unsafe { CpuId::from_i32_unchecked(cpu) }; + + from_result(|| { + let mut policy = PolicyCpu::from_cpu(cpu_id)?; + + // SAFETY: `limit` is guaranteed by the C code to be valid. + T::bios_limit(&mut policy, &mut (unsafe { *limit })).map(|()| 0) + }) + } + + /// Driver's `set_boost` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn set_boost_callback( + ptr: *mut bindings::cpufreq_policy, + state: c_int, + ) -> c_int { + from_result(|| { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::set_boost(policy, state).map(|()| 0) + }) + } + + /// Driver's `register_em` callback. + /// + /// # Safety + /// + /// - This function may only be called from the cpufreq C infrastructure. + /// - The pointer arguments must be valid pointers. + unsafe extern "C" fn register_em_callback(ptr: *mut bindings::cpufreq_policy) { + // SAFETY: The `ptr` is guaranteed to be valid by the contract with the C code for the + // lifetime of `policy`. + let policy = unsafe { Policy::from_raw_mut(ptr) }; + T::register_em(policy); + } +} + +impl<T: Driver> Drop for Registration<T> { + /// Unregisters with the cpufreq core. + fn drop(&mut self) { + // SAFETY: `self.0` is guaranteed to be valid for the lifetime of `Registration`. + unsafe { bindings::cpufreq_unregister_driver(self.0.get_mut()) }; + } +} diff --git a/rust/kernel/cpumask.rs b/rust/kernel/cpumask.rs new file mode 100644 index 000000000000..05e1c882404e --- /dev/null +++ b/rust/kernel/cpumask.rs @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! CPU Mask abstractions. +//! +//! C header: [`include/linux/cpumask.h`](srctree/include/linux/cpumask.h) + +use crate::{ + alloc::{AllocError, Flags}, + cpu::CpuId, + prelude::*, + types::Opaque, +}; + +#[cfg(CONFIG_CPUMASK_OFFSTACK)] +use core::ptr::{self, NonNull}; + +use core::ops::{Deref, DerefMut}; + +/// A CPU Mask. +/// +/// Rust abstraction for the C `struct cpumask`. +/// +/// # Invariants +/// +/// A [`Cpumask`] instance always corresponds to a valid C `struct cpumask`. +/// +/// The callers must ensure that the `struct cpumask` is valid for access and +/// remains valid for the lifetime of the returned reference. +/// +/// # Examples +/// +/// The following example demonstrates how to update a [`Cpumask`]. +/// +/// ``` +/// use kernel::bindings; +/// use kernel::cpu::CpuId; +/// use kernel::cpumask::Cpumask; +/// +/// fn set_clear_cpu(ptr: *mut bindings::cpumask, set_cpu: CpuId, clear_cpu: CpuId) { +/// // SAFETY: The `ptr` is valid for writing and remains valid for the lifetime of the +/// // returned reference. +/// let mask = unsafe { Cpumask::as_mut_ref(ptr) }; +/// +/// mask.set(set_cpu); +/// mask.clear(clear_cpu); +/// } +/// ``` +#[repr(transparent)] +pub struct Cpumask(Opaque<bindings::cpumask>); + +impl Cpumask { + /// Creates a mutable reference to an existing `struct cpumask` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime + /// of the returned reference. + pub unsafe fn as_mut_ref<'a>(ptr: *mut bindings::cpumask) -> &'a mut Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the + // lifetime of the returned reference. + unsafe { &mut *ptr.cast() } + } + + /// Creates a reference to an existing `struct cpumask` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime + /// of the returned reference. + pub unsafe fn as_ref<'a>(ptr: *const bindings::cpumask) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the + // lifetime of the returned reference. + unsafe { &*ptr.cast() } + } + + /// Obtain the raw `struct cpumask` pointer. + pub fn as_raw(&self) -> *mut bindings::cpumask { + let this: *const Self = self; + this.cast_mut().cast() + } + + /// Set `cpu` in the cpumask. + /// + /// ATTENTION: Contrary to C, this Rust `set()` method is non-atomic. + /// This mismatches kernel naming convention and corresponds to the C + /// function `__cpumask_set_cpu()`. + #[inline] + pub fn set(&mut self, cpu: CpuId) { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `__cpumask_set_cpu`. + unsafe { bindings::__cpumask_set_cpu(u32::from(cpu), self.as_raw()) }; + } + + /// Clear `cpu` in the cpumask. + /// + /// ATTENTION: Contrary to C, this Rust `clear()` method is non-atomic. + /// This mismatches kernel naming convention and corresponds to the C + /// function `__cpumask_clear_cpu()`. + #[inline] + pub fn clear(&mut self, cpu: CpuId) { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to + // `__cpumask_clear_cpu`. + unsafe { bindings::__cpumask_clear_cpu(i32::from(cpu), self.as_raw()) }; + } + + /// Test `cpu` in the cpumask. + /// + /// Equivalent to the kernel's `cpumask_test_cpu` API. + #[inline] + pub fn test(&self, cpu: CpuId) -> bool { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_test_cpu`. + unsafe { bindings::cpumask_test_cpu(i32::from(cpu), self.as_raw()) } + } + + /// Set all CPUs in the cpumask. + /// + /// Equivalent to the kernel's `cpumask_setall` API. + #[inline] + pub fn setall(&mut self) { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_setall`. + unsafe { bindings::cpumask_setall(self.as_raw()) }; + } + + /// Checks if cpumask is empty. + /// + /// Equivalent to the kernel's `cpumask_empty` API. + #[inline] + pub fn empty(&self) -> bool { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_empty`. + unsafe { bindings::cpumask_empty(self.as_raw()) } + } + + /// Checks if cpumask is full. + /// + /// Equivalent to the kernel's `cpumask_full` API. + #[inline] + pub fn full(&self) -> bool { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_full`. + unsafe { bindings::cpumask_full(self.as_raw()) } + } + + /// Get weight of the cpumask. + /// + /// Equivalent to the kernel's `cpumask_weight` API. + #[inline] + pub fn weight(&self) -> u32 { + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `cpumask_weight`. + unsafe { bindings::cpumask_weight(self.as_raw()) } + } + + /// Copy cpumask. + /// + /// Equivalent to the kernel's `cpumask_copy` API. + #[inline] + pub fn copy(&self, dstp: &mut Self) { + // SAFETY: By the type invariant, `Self::as_raw` is a valid argument to `cpumask_copy`. + unsafe { bindings::cpumask_copy(dstp.as_raw(), self.as_raw()) }; + } +} + +/// A CPU Mask pointer. +/// +/// Rust abstraction for the C `struct cpumask_var_t`. +/// +/// # Invariants +/// +/// A [`CpumaskVar`] instance always corresponds to a valid C `struct cpumask_var_t`. +/// +/// The callers must ensure that the `struct cpumask_var_t` is valid for access and remains valid +/// for the lifetime of [`CpumaskVar`]. +/// +/// # Examples +/// +/// The following example demonstrates how to create and update a [`CpumaskVar`]. +/// +/// ``` +/// use kernel::cpu::CpuId; +/// use kernel::cpumask::CpumaskVar; +/// +/// let mut mask = CpumaskVar::new_zero(GFP_KERNEL).unwrap(); +/// +/// assert!(mask.empty()); +/// let mut count = 0; +/// +/// let cpu2 = CpuId::from_u32(2); +/// if let Some(cpu) = cpu2 { +/// mask.set(cpu); +/// assert!(mask.test(cpu)); +/// count += 1; +/// } +/// +/// let cpu3 = CpuId::from_u32(3); +/// if let Some(cpu) = cpu3 { +/// mask.set(cpu); +/// assert!(mask.test(cpu)); +/// count += 1; +/// } +/// +/// assert_eq!(mask.weight(), count); +/// +/// let mask2 = CpumaskVar::try_clone(&mask).unwrap(); +/// +/// if let Some(cpu) = cpu2 { +/// assert!(mask2.test(cpu)); +/// } +/// +/// if let Some(cpu) = cpu3 { +/// assert!(mask2.test(cpu)); +/// } +/// assert_eq!(mask2.weight(), count); +/// ``` +#[repr(transparent)] +pub struct CpumaskVar { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + ptr: NonNull<Cpumask>, + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + mask: Cpumask, +} + +impl CpumaskVar { + /// Creates a zero-initialized instance of the [`CpumaskVar`]. + pub fn new_zero(_flags: Flags) -> Result<Self, AllocError> { + Ok(Self { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + ptr: { + let mut ptr: *mut bindings::cpumask = ptr::null_mut(); + + // SAFETY: It is safe to call this method as the reference to `ptr` is valid. + // + // INVARIANT: The associated memory is freed when the `CpumaskVar` goes out of + // scope. + unsafe { bindings::zalloc_cpumask_var(&mut ptr, _flags.as_raw()) }; + NonNull::new(ptr.cast()).ok_or(AllocError)? + }, + + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + mask: Cpumask(Opaque::zeroed()), + }) + } + + /// Creates an instance of the [`CpumaskVar`]. + /// + /// # Safety + /// + /// The caller must ensure that the returned [`CpumaskVar`] is properly initialized before + /// getting used. + pub unsafe fn new(_flags: Flags) -> Result<Self, AllocError> { + Ok(Self { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + ptr: { + let mut ptr: *mut bindings::cpumask = ptr::null_mut(); + + // SAFETY: It is safe to call this method as the reference to `ptr` is valid. + // + // INVARIANT: The associated memory is freed when the `CpumaskVar` goes out of + // scope. + unsafe { bindings::alloc_cpumask_var(&mut ptr, _flags.as_raw()) }; + NonNull::new(ptr.cast()).ok_or(AllocError)? + }, + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + mask: Cpumask(Opaque::uninit()), + }) + } + + /// Creates a mutable reference to an existing `struct cpumask_var_t` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime + /// of the returned reference. + pub unsafe fn as_mut_ref<'a>(ptr: *mut bindings::cpumask_var_t) -> &'a mut Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the + // lifetime of the returned reference. + unsafe { &mut *ptr.cast() } + } + + /// Creates a reference to an existing `struct cpumask_var_t` pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime + /// of the returned reference. + pub unsafe fn as_ref<'a>(ptr: *const bindings::cpumask_var_t) -> &'a Self { + // SAFETY: Guaranteed by the safety requirements of the function. + // + // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the + // lifetime of the returned reference. + unsafe { &*ptr.cast() } + } + + /// Clones cpumask. + pub fn try_clone(cpumask: &Cpumask) -> Result<Self> { + // SAFETY: The returned cpumask_var is initialized right after this call. + let mut cpumask_var = unsafe { Self::new(GFP_KERNEL) }?; + + cpumask.copy(&mut cpumask_var); + Ok(cpumask_var) + } +} + +// Make [`CpumaskVar`] behave like a pointer to [`Cpumask`]. +impl Deref for CpumaskVar { + type Target = Cpumask; + + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + fn deref(&self) -> &Self::Target { + // SAFETY: The caller owns CpumaskVar, so it is safe to deref the cpumask. + unsafe { &*self.ptr.as_ptr() } + } + + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + fn deref(&self) -> &Self::Target { + &self.mask + } +} + +impl DerefMut for CpumaskVar { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + fn deref_mut(&mut self) -> &mut Cpumask { + // SAFETY: The caller owns CpumaskVar, so it is safe to deref the cpumask. + unsafe { self.ptr.as_mut() } + } + + #[cfg(not(CONFIG_CPUMASK_OFFSTACK))] + fn deref_mut(&mut self) -> &mut Cpumask { + &mut self.mask + } +} + +impl Drop for CpumaskVar { + fn drop(&mut self) { + #[cfg(CONFIG_CPUMASK_OFFSTACK)] + // SAFETY: By the type invariant, `self.as_raw` is a valid argument to `free_cpumask_var`. + unsafe { + bindings::free_cpumask_var(self.as_raw()) + }; + } +} diff --git a/rust/kernel/cred.rs b/rust/kernel/cred.rs index 81d67789b16f..2599f01e8b28 100644 --- a/rust/kernel/cred.rs +++ b/rust/kernel/cred.rs @@ -47,6 +47,7 @@ impl Credential { /// /// The caller must ensure that `ptr` is valid and remains valid for the lifetime of the /// returned [`Credential`] reference. + #[inline] pub unsafe fn from_ptr<'a>(ptr: *const bindings::cred) -> &'a Credential { // SAFETY: The safety requirements guarantee the validity of the dereference, while the // `Credential` type being transparent makes the cast ok. @@ -54,6 +55,7 @@ impl Credential { } /// Get the id for this security context. + #[inline] pub fn get_secid(&self) -> u32 { let mut secid = 0; // SAFETY: The invariants of this type ensures that the pointer is valid. @@ -62,6 +64,7 @@ impl Credential { } /// Returns the effective UID of the given credential. + #[inline] pub fn euid(&self) -> Kuid { // SAFETY: By the type invariant, we know that `self.0` is valid. Furthermore, the `euid` // field of a credential is never changed after initialization, so there is no potential @@ -72,11 +75,13 @@ impl Credential { // SAFETY: The type invariants guarantee that `Credential` is always ref-counted. unsafe impl AlwaysRefCounted for Credential { + #[inline] fn inc_ref(&self) { // SAFETY: The existence of a shared reference means that the refcount is nonzero. unsafe { bindings::get_cred(self.0.get()) }; } + #[inline] unsafe fn dec_ref(obj: core::ptr::NonNull<Credential>) { // SAFETY: The safety requirements guarantee that the refcount is nonzero. The cast is okay // because `Credential` has the same representation as `struct cred`. diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs index d5e6a19ff6b7..a1db49eb159a 100644 --- a/rust/kernel/device.rs +++ b/rust/kernel/device.rs @@ -6,30 +6,141 @@ use crate::{ bindings, - types::{ARef, Opaque}, + types::{ARef, ForeignOwnable, Opaque}, }; -use core::{fmt, ptr}; +use core::{fmt, marker::PhantomData, ptr}; #[cfg(CONFIG_PRINTK)] use crate::c_str; -/// A reference-counted device. +pub mod property; + +/// The core representation of a device in the kernel's driver model. +/// +/// This structure represents the Rust abstraction for a C `struct device`. A [`Device`] can either +/// exist as temporary reference (see also [`Device::from_raw`]), which is only valid within a +/// certain scope or as [`ARef<Device>`], owning a dedicated reference count. +/// +/// # Device Types +/// +/// A [`Device`] can represent either a bus device or a class device. +/// +/// ## Bus Devices +/// +/// A bus device is a [`Device`] that is associated with a physical or virtual bus. Examples of +/// buses include PCI, USB, I2C, and SPI. Devices attached to a bus are registered with a specific +/// bus type, which facilitates matching devices with appropriate drivers based on IDs or other +/// identifying information. Bus devices are visible in sysfs under `/sys/bus/<bus-name>/devices/`. +/// +/// ## Class Devices +/// +/// A class device is a [`Device`] that is associated with a logical category of functionality +/// rather than a physical bus. Examples of classes include block devices, network interfaces, sound +/// cards, and input devices. Class devices are grouped under a common class and exposed to +/// userspace via entries in `/sys/class/<class-name>/`. /// -/// This structure represents the Rust abstraction for a C `struct device`. This implementation -/// abstracts the usage of an already existing C `struct device` within Rust code that we get -/// passed from the C side. +/// # Device Context /// -/// An instance of this abstraction can be obtained temporarily or permanent. +/// [`Device`] references are generic over a [`DeviceContext`], which represents the type state of +/// a [`Device`]. /// -/// A temporary one is bound to the lifetime of the C `struct device` pointer used for creation. -/// A permanent instance is always reference-counted and hence not restricted by any lifetime -/// boundaries. +/// As the name indicates, this type state represents the context of the scope the [`Device`] +/// reference is valid in. For instance, the [`Bound`] context guarantees that the [`Device`] is +/// bound to a driver for the entire duration of the existence of a [`Device<Bound>`] reference. /// -/// For subsystems it is recommended to create a permanent instance to wrap into a subsystem -/// specific device structure (e.g. `pci::Device`). This is useful for passing it to drivers in -/// `T::probe()`, such that a driver can store the `ARef<Device>` (equivalent to storing a -/// `struct device` pointer in a C driver) for arbitrary purposes, e.g. allocating DMA coherent -/// memory. +/// Other [`DeviceContext`] types besides [`Bound`] are [`Normal`], [`Core`] and [`CoreInternal`]. +/// +/// Unless selected otherwise [`Device`] defaults to the [`Normal`] [`DeviceContext`], which by +/// itself has no additional requirements. +/// +/// It is always up to the caller of [`Device::from_raw`] to select the correct [`DeviceContext`] +/// type for the corresponding scope the [`Device`] reference is created in. +/// +/// All [`DeviceContext`] types other than [`Normal`] are intended to be used with +/// [bus devices](#bus-devices) only. +/// +/// # Implementing Bus Devices +/// +/// This section provides a guideline to implement bus specific devices, such as [`pci::Device`] or +/// [`platform::Device`]. +/// +/// A bus specific device should be defined as follows. +/// +/// ```ignore +/// #[repr(transparent)] +/// pub struct Device<Ctx: device::DeviceContext = device::Normal>( +/// Opaque<bindings::bus_device_type>, +/// PhantomData<Ctx>, +/// ); +/// ``` +/// +/// Since devices are reference counted, [`AlwaysRefCounted`] should be implemented for `Device` +/// (i.e. `Device<Normal>`). Note that [`AlwaysRefCounted`] must not be implemented for any other +/// [`DeviceContext`], since all other device context types are only valid within a certain scope. +/// +/// In order to be able to implement the [`DeviceContext`] dereference hierarchy, bus device +/// implementations should call the [`impl_device_context_deref`] macro as shown below. +/// +/// ```ignore +/// // SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s +/// // generic argument. +/// kernel::impl_device_context_deref!(unsafe { Device }); +/// ``` +/// +/// In order to convert from a any [`Device<Ctx>`] to [`ARef<Device>`], bus devices can implement +/// the following macro call. +/// +/// ```ignore +/// kernel::impl_device_context_into_aref!(Device); +/// ``` +/// +/// Bus devices should also implement the following [`AsRef`] implementation, such that users can +/// easily derive a generic [`Device`] reference. +/// +/// ```ignore +/// impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> { +/// fn as_ref(&self) -> &device::Device<Ctx> { +/// ... +/// } +/// } +/// ``` +/// +/// # Implementing Class Devices +/// +/// Class device implementations require less infrastructure and depend slightly more on the +/// specific subsystem. +/// +/// An example implementation for a class device could look like this. +/// +/// ```ignore +/// #[repr(C)] +/// pub struct Device<T: class::Driver> { +/// dev: Opaque<bindings::class_device_type>, +/// data: T::Data, +/// } +/// ``` +/// +/// This class device uses the sub-classing pattern to embed the driver's private data within the +/// allocation of the class device. For this to be possible the class device is generic over the +/// class specific `Driver` trait implementation. +/// +/// Just like any device, class devices are reference counted and should hence implement +/// [`AlwaysRefCounted`] for `Device`. +/// +/// Class devices should also implement the following [`AsRef`] implementation, such that users can +/// easily derive a generic [`Device`] reference. +/// +/// ```ignore +/// impl<T: class::Driver> AsRef<device::Device> for Device<T> { +/// fn as_ref(&self) -> &device::Device { +/// ... +/// } +/// } +/// ``` +/// +/// An example for a class device implementation is +#[cfg_attr(CONFIG_DRM = "y", doc = "[`drm::Device`](kernel::drm::Device).")] +#[cfg_attr(not(CONFIG_DRM = "y"), doc = "`drm::Device`.")] /// /// # Invariants /// @@ -40,8 +151,13 @@ use crate::c_str; /// /// `bindings::device::release` is valid to be called from any thread, hence `ARef<Device>` can be /// dropped from any thread. +/// +/// [`AlwaysRefCounted`]: kernel::types::AlwaysRefCounted +/// [`impl_device_context_deref`]: kernel::impl_device_context_deref +/// [`pci::Device`]: kernel::pci::Device +/// [`platform::Device`]: kernel::platform::Device #[repr(transparent)] -pub struct Device(Opaque<bindings::device>); +pub struct Device<Ctx: DeviceContext = Normal>(Opaque<bindings::device>, PhantomData<Ctx>); impl Device { /// Creates a new reference-counted abstraction instance of an existing `struct device` pointer. @@ -56,14 +172,101 @@ impl Device { /// While not officially documented, this should be the case for any `struct device`. pub unsafe fn get_device(ptr: *mut bindings::device) -> ARef<Self> { // SAFETY: By the safety requirements ptr is valid - unsafe { Self::as_ref(ptr) }.into() + unsafe { Self::from_raw(ptr) }.into() + } + + /// Convert a [`&Device`](Device) into a [`&Device<Bound>`](Device<Bound>). + /// + /// # Safety + /// + /// The caller is responsible to ensure that the returned [`&Device<Bound>`](Device<Bound>) + /// only lives as long as it can be guaranteed that the [`Device`] is actually bound. + pub unsafe fn as_bound(&self) -> &Device<Bound> { + let ptr = core::ptr::from_ref(self); + + // CAST: By the safety requirements the caller is responsible to guarantee that the + // returned reference only lives as long as the device is actually bound. + let ptr = ptr.cast(); + + // SAFETY: + // - `ptr` comes from `from_ref(self)` above, hence it's guaranteed to be valid. + // - Any valid `Device` pointer is also a valid pointer for `Device<Bound>`. + unsafe { &*ptr } + } +} + +impl Device<CoreInternal> { + /// Store a pointer to the bound driver's private data. + pub fn set_drvdata(&self, data: impl ForeignOwnable) { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + unsafe { bindings::dev_set_drvdata(self.as_raw(), data.into_foreign().cast()) } } + /// Take ownership of the private data stored in this [`Device`]. + /// + /// # Safety + /// + /// - Must only be called once after a preceding call to [`Device::set_drvdata`]. + /// - The type `T` must match the type of the `ForeignOwnable` previously stored by + /// [`Device::set_drvdata`]. + pub unsafe fn drvdata_obtain<T: ForeignOwnable>(&self) -> T { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; + + // SAFETY: + // - By the safety requirements of this function, `ptr` comes from a previous call to + // `into_foreign()`. + // - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` + // in `into_foreign()`. + unsafe { T::from_foreign(ptr.cast()) } + } + + /// Borrow the driver's private data bound to this [`Device`]. + /// + /// # Safety + /// + /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before + /// [`Device::drvdata_obtain`]. + /// - The type `T` must match the type of the `ForeignOwnable` previously stored by + /// [`Device::set_drvdata`]. + pub unsafe fn drvdata_borrow<T: ForeignOwnable>(&self) -> T::Borrowed<'_> { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; + + // SAFETY: + // - By the safety requirements of this function, `ptr` comes from a previous call to + // `into_foreign()`. + // - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` + // in `into_foreign()`. + unsafe { T::borrow(ptr.cast()) } + } +} + +impl<Ctx: DeviceContext> Device<Ctx> { /// Obtain the raw `struct device *`. pub(crate) fn as_raw(&self) -> *mut bindings::device { self.0.get() } + /// Returns a reference to the parent device, if any. + #[cfg_attr(not(CONFIG_AUXILIARY_BUS), expect(dead_code))] + pub(crate) fn parent(&self) -> Option<&Self> { + // SAFETY: + // - By the type invariant `self.as_raw()` is always valid. + // - The parent device is only ever set at device creation. + let parent = unsafe { (*self.as_raw()).parent }; + + if parent.is_null() { + None + } else { + // SAFETY: + // - Since `parent` is not NULL, it must be a valid pointer to a `struct device`. + // - `parent` is valid for the lifetime of `self`, since a `struct device` holds a + // reference count of its parent. + Some(unsafe { Self::from_raw(parent) }) + } + } + /// Convert a raw C `struct device` pointer to a `&'a Device`. /// /// # Safety @@ -72,7 +275,7 @@ impl Device { /// i.e. it must be ensured that the reference count of the C `struct device` `ptr` points to /// can't drop to zero, for the duration of this function call and the entire duration when the /// returned reference exists. - pub unsafe fn as_ref<'a>(ptr: *mut bindings::device) -> &'a Self { + pub unsafe fn from_raw<'a>(ptr: *mut bindings::device) -> &'a Self { // SAFETY: Guaranteed by the safety requirements of the function. unsafe { &*ptr.cast() } } @@ -173,15 +376,35 @@ impl Device { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_dev_printk( - klevel as *const _ as *const crate::ffi::c_char, + klevel.as_ptr().cast::<crate::ffi::c_char>(), self.as_raw(), c_str!("%pA").as_char_ptr(), - &msg as *const _ as *const crate::ffi::c_void, + core::ptr::from_ref(&msg).cast::<crate::ffi::c_void>(), ) }; } + + /// Obtain the [`FwNode`](property::FwNode) corresponding to this [`Device`]. + pub fn fwnode(&self) -> Option<&property::FwNode> { + // SAFETY: `self` is valid. + let fwnode_handle = unsafe { bindings::__dev_fwnode(self.as_raw()) }; + if fwnode_handle.is_null() { + return None; + } + // SAFETY: `fwnode_handle` is valid. Its lifetime is tied to `&self`. We + // return a reference instead of an `ARef<FwNode>` because `dev_fwnode()` + // doesn't increment the refcount. It is safe to cast from a + // `struct fwnode_handle*` to a `*const FwNode` because `FwNode` is + // defined as a `#[repr(transparent)]` wrapper around `fwnode_handle`. + Some(unsafe { &*fwnode_handle.cast() }) + } } +// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic +// argument. +kernel::impl_device_context_deref!(unsafe { Device }); +kernel::impl_device_context_into_aref!(Device); + // SAFETY: Instances of `Device` are always reference-counted. unsafe impl crate::types::AlwaysRefCounted for Device { fn inc_ref(&self) { @@ -202,12 +425,178 @@ unsafe impl Send for Device {} // synchronization in `struct device`. unsafe impl Sync for Device {} +/// Marker trait for the context or scope of a bus specific device. +/// +/// [`DeviceContext`] is a marker trait for types representing the context of a bus specific +/// [`Device`]. +/// +/// The specific device context types are: [`CoreInternal`], [`Core`], [`Bound`] and [`Normal`]. +/// +/// [`DeviceContext`] types are hierarchical, which means that there is a strict hierarchy that +/// defines which [`DeviceContext`] type can be derived from another. For instance, any +/// [`Device<Core>`] can dereference to a [`Device<Bound>`]. +/// +/// The following enumeration illustrates the dereference hierarchy of [`DeviceContext`] types. +/// +/// - [`CoreInternal`] => [`Core`] => [`Bound`] => [`Normal`] +/// +/// Bus devices can automatically implement the dereference hierarchy by using +/// [`impl_device_context_deref`]. +/// +/// Note that the guarantee for a [`Device`] reference to have a certain [`DeviceContext`] comes +/// from the specific scope the [`Device`] reference is valid in. +/// +/// [`impl_device_context_deref`]: kernel::impl_device_context_deref +pub trait DeviceContext: private::Sealed {} + +/// The [`Normal`] context is the default [`DeviceContext`] of any [`Device`]. +/// +/// The normal context does not indicate any specific context. Any `Device<Ctx>` is also a valid +/// [`Device<Normal>`]. It is the only [`DeviceContext`] for which it is valid to implement +/// [`AlwaysRefCounted`] for. +/// +/// [`AlwaysRefCounted`]: kernel::types::AlwaysRefCounted +pub struct Normal; + +/// The [`Core`] context is the context of a bus specific device when it appears as argument of +/// any bus specific callback, such as `probe()`. +/// +/// The core context indicates that the [`Device<Core>`] reference's scope is limited to the bus +/// callback it appears in. It is intended to be used for synchronization purposes. Bus device +/// implementations can implement methods for [`Device<Core>`], such that they can only be called +/// from bus callbacks. +pub struct Core; + +/// Semantically the same as [`Core`], but reserved for internal usage of the corresponding bus +/// abstraction. +/// +/// The internal core context is intended to be used in exactly the same way as the [`Core`] +/// context, with the difference that this [`DeviceContext`] is internal to the corresponding bus +/// abstraction. +/// +/// This context mainly exists to share generic [`Device`] infrastructure that should only be called +/// from bus callbacks with bus abstractions, but without making them accessible for drivers. +pub struct CoreInternal; + +/// The [`Bound`] context is the [`DeviceContext`] of a bus specific device when it is guaranteed to +/// be bound to a driver. +/// +/// The bound context indicates that for the entire duration of the lifetime of a [`Device<Bound>`] +/// reference, the [`Device`] is guaranteed to be bound to a driver. +/// +/// Some APIs, such as [`dma::CoherentAllocation`] or [`Devres`] rely on the [`Device`] to be bound, +/// which can be proven with the [`Bound`] device context. +/// +/// Any abstraction that can guarantee a scope where the corresponding bus device is bound, should +/// provide a [`Device<Bound>`] reference to its users for this scope. This allows users to benefit +/// from optimizations for accessing device resources, see also [`Devres::access`]. +/// +/// [`Devres`]: kernel::devres::Devres +/// [`Devres::access`]: kernel::devres::Devres::access +/// [`dma::CoherentAllocation`]: kernel::dma::CoherentAllocation +pub struct Bound; + +mod private { + pub trait Sealed {} + + impl Sealed for super::Bound {} + impl Sealed for super::Core {} + impl Sealed for super::CoreInternal {} + impl Sealed for super::Normal {} +} + +impl DeviceContext for Bound {} +impl DeviceContext for Core {} +impl DeviceContext for CoreInternal {} +impl DeviceContext for Normal {} + +/// # Safety +/// +/// The type given as `$device` must be a transparent wrapper of a type that doesn't depend on the +/// generic argument of `$device`. +#[doc(hidden)] +#[macro_export] +macro_rules! __impl_device_context_deref { + (unsafe { $device:ident, $src:ty => $dst:ty }) => { + impl ::core::ops::Deref for $device<$src> { + type Target = $device<$dst>; + + fn deref(&self) -> &Self::Target { + let ptr: *const Self = self; + + // CAST: `$device<$src>` and `$device<$dst>` transparently wrap the same type by the + // safety requirement of the macro. + let ptr = ptr.cast::<Self::Target>(); + + // SAFETY: `ptr` was derived from `&self`. + unsafe { &*ptr } + } + } + }; +} + +/// Implement [`core::ops::Deref`] traits for allowed [`DeviceContext`] conversions of a (bus +/// specific) device. +/// +/// # Safety +/// +/// The type given as `$device` must be a transparent wrapper of a type that doesn't depend on the +/// generic argument of `$device`. +#[macro_export] +macro_rules! impl_device_context_deref { + (unsafe { $device:ident }) => { + // SAFETY: This macro has the exact same safety requirement as + // `__impl_device_context_deref!`. + ::kernel::__impl_device_context_deref!(unsafe { + $device, + $crate::device::CoreInternal => $crate::device::Core + }); + + // SAFETY: This macro has the exact same safety requirement as + // `__impl_device_context_deref!`. + ::kernel::__impl_device_context_deref!(unsafe { + $device, + $crate::device::Core => $crate::device::Bound + }); + + // SAFETY: This macro has the exact same safety requirement as + // `__impl_device_context_deref!`. + ::kernel::__impl_device_context_deref!(unsafe { + $device, + $crate::device::Bound => $crate::device::Normal + }); + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __impl_device_context_into_aref { + ($src:ty, $device:tt) => { + impl ::core::convert::From<&$device<$src>> for $crate::types::ARef<$device> { + fn from(dev: &$device<$src>) -> Self { + (&**dev).into() + } + } + }; +} + +/// Implement [`core::convert::From`], such that all `&Device<Ctx>` can be converted to an +/// `ARef<Device>`. +#[macro_export] +macro_rules! impl_device_context_into_aref { + ($device:tt) => { + ::kernel::__impl_device_context_into_aref!($crate::device::CoreInternal, $device); + ::kernel::__impl_device_context_into_aref!($crate::device::Core, $device); + ::kernel::__impl_device_context_into_aref!($crate::device::Bound, $device); + }; +} + #[doc(hidden)] #[macro_export] macro_rules! dev_printk { ($method:ident, $dev:expr, $($f:tt)*) => { { - ($dev).$method(core::format_args!($($f)*)); + ($dev).$method(::core::format_args!($($f)*)); } } } @@ -219,9 +608,10 @@ macro_rules! dev_printk { /// Equivalent to the kernel's `dev_emerg` macro. /// /// Mimics the interface of [`std::print!`]. More information about the syntax is available from -/// [`core::fmt`] and `alloc::format!`. +/// [`core::fmt`] and [`std::format!`]. /// /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -244,9 +634,10 @@ macro_rules! dev_emerg { /// Equivalent to the kernel's `dev_alert` macro. /// /// Mimics the interface of [`std::print!`]. More information about the syntax is available from -/// [`core::fmt`] and `alloc::format!`. +/// [`core::fmt`] and [`std::format!`]. /// /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -269,9 +660,10 @@ macro_rules! dev_alert { /// Equivalent to the kernel's `dev_crit` macro. /// /// Mimics the interface of [`std::print!`]. More information about the syntax is available from -/// [`core::fmt`] and `alloc::format!`. +/// [`core::fmt`] and [`std::format!`]. /// /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -294,9 +686,10 @@ macro_rules! dev_crit { /// Equivalent to the kernel's `dev_err` macro. /// /// Mimics the interface of [`std::print!`]. More information about the syntax is available from -/// [`core::fmt`] and `alloc::format!`. +/// [`core::fmt`] and [`std::format!`]. /// /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -319,9 +712,10 @@ macro_rules! dev_err { /// Equivalent to the kernel's `dev_warn` macro. /// /// Mimics the interface of [`std::print!`]. More information about the syntax is available from -/// [`core::fmt`] and `alloc::format!`. +/// [`core::fmt`] and [`std::format!`]. /// /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -344,9 +738,10 @@ macro_rules! dev_warn { /// Equivalent to the kernel's `dev_notice` macro. /// /// Mimics the interface of [`std::print!`]. More information about the syntax is available from -/// [`core::fmt`] and `alloc::format!`. +/// [`core::fmt`] and [`std::format!`]. /// /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -369,9 +764,10 @@ macro_rules! dev_notice { /// Equivalent to the kernel's `dev_info` macro. /// /// Mimics the interface of [`std::print!`]. More information about the syntax is available from -/// [`core::fmt`] and `alloc::format!`. +/// [`core::fmt`] and [`std::format!`]. /// /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -394,9 +790,10 @@ macro_rules! dev_info { /// Equivalent to the kernel's `dev_dbg` macro, except that it doesn't support dynamic debug yet. /// /// Mimics the interface of [`std::print!`]. More information about the syntax is available from -/// [`core::fmt`] and `alloc::format!`. +/// [`core::fmt`] and [`std::format!`]. /// /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// diff --git a/rust/kernel/device/property.rs b/rust/kernel/device/property.rs new file mode 100644 index 000000000000..49ee12a906db --- /dev/null +++ b/rust/kernel/device/property.rs @@ -0,0 +1,631 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Unified device property interface. +//! +//! C header: [`include/linux/property.h`](srctree/include/linux/property.h) + +use core::{mem::MaybeUninit, ptr}; + +use super::private::Sealed; +use crate::{ + alloc::KVec, + bindings, + error::{to_result, Result}, + prelude::*, + str::{CStr, CString}, + types::{ARef, Opaque}, +}; + +/// A reference-counted fwnode_handle. +/// +/// This structure represents the Rust abstraction for a +/// C `struct fwnode_handle`. This implementation abstracts the usage of an +/// already existing C `struct fwnode_handle` within Rust code that we get +/// passed from the C side. +/// +/// # Invariants +/// +/// A `FwNode` instance represents a valid `struct fwnode_handle` created by the +/// C portion of the kernel. +/// +/// Instances of this type are always reference-counted, that is, a call to +/// `fwnode_handle_get` ensures that the allocation remains valid at least until +/// the matching call to `fwnode_handle_put`. +#[repr(transparent)] +pub struct FwNode(Opaque<bindings::fwnode_handle>); + +impl FwNode { + /// # Safety + /// + /// Callers must ensure that: + /// - The reference count was incremented at least once. + /// - They relinquish that increment. That is, if there is only one + /// increment, callers must not use the underlying object anymore -- it is + /// only safe to do so via the newly created `ARef<FwNode>`. + unsafe fn from_raw(raw: *mut bindings::fwnode_handle) -> ARef<Self> { + // SAFETY: As per the safety requirements of this function: + // - `NonNull::new_unchecked`: + // - `raw` is not null. + // - `ARef::from_raw`: + // - `raw` has an incremented refcount. + // - that increment is relinquished, i.e. it won't be decremented + // elsewhere. + // CAST: It is safe to cast from a `*mut fwnode_handle` to + // `*mut FwNode`, because `FwNode` is defined as a + // `#[repr(transparent)]` wrapper around `fwnode_handle`. + unsafe { ARef::from_raw(ptr::NonNull::new_unchecked(raw.cast())) } + } + + /// Obtain the raw `struct fwnode_handle *`. + pub(crate) fn as_raw(&self) -> *mut bindings::fwnode_handle { + self.0.get() + } + + /// Returns `true` if `&self` is an OF node, `false` otherwise. + pub fn is_of_node(&self) -> bool { + // SAFETY: The type invariant of `Self` guarantees that `self.as_raw() is a pointer to a + // valid `struct fwnode_handle`. + unsafe { bindings::is_of_node(self.as_raw()) } + } + + /// Returns an object that implements [`Display`](core::fmt::Display) for + /// printing the name of a node. + /// + /// This is an alternative to the default `Display` implementation, which + /// prints the full path. + pub fn display_name(&self) -> impl core::fmt::Display + '_ { + struct FwNodeDisplayName<'a>(&'a FwNode); + + impl core::fmt::Display for FwNodeDisplayName<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + // SAFETY: `self` is valid by its type invariant. + let name = unsafe { bindings::fwnode_get_name(self.0.as_raw()) }; + if name.is_null() { + return Ok(()); + } + // SAFETY: + // - `fwnode_get_name` returns null or a valid C string. + // - `name` was checked to be non-null. + let name = unsafe { CStr::from_char_ptr(name) }; + write!(f, "{name}") + } + } + + FwNodeDisplayName(self) + } + + /// Checks if property is present or not. + pub fn property_present(&self, name: &CStr) -> bool { + // SAFETY: By the invariant of `CStr`, `name` is null-terminated. + unsafe { bindings::fwnode_property_present(self.as_raw().cast_const(), name.as_char_ptr()) } + } + + /// Returns firmware property `name` boolean value. + pub fn property_read_bool(&self, name: &CStr) -> bool { + // SAFETY: + // - `name` is non-null and null-terminated. + // - `self.as_raw()` is valid because `self` is valid. + unsafe { bindings::fwnode_property_read_bool(self.as_raw(), name.as_char_ptr()) } + } + + /// Returns the index of matching string `match_str` for firmware string + /// property `name`. + pub fn property_match_string(&self, name: &CStr, match_str: &CStr) -> Result<usize> { + // SAFETY: + // - `name` and `match_str` are non-null and null-terminated. + // - `self.as_raw` is valid because `self` is valid. + let ret = unsafe { + bindings::fwnode_property_match_string( + self.as_raw(), + name.as_char_ptr(), + match_str.as_char_ptr(), + ) + }; + to_result(ret)?; + Ok(ret as usize) + } + + /// Returns firmware property `name` integer array values in a [`KVec`]. + pub fn property_read_array_vec<'fwnode, 'name, T: PropertyInt>( + &'fwnode self, + name: &'name CStr, + len: usize, + ) -> Result<PropertyGuard<'fwnode, 'name, KVec<T>>> { + let mut val: KVec<T> = KVec::with_capacity(len, GFP_KERNEL)?; + + let res = T::read_array_from_fwnode_property(self, name, val.spare_capacity_mut()); + let res = match res { + Ok(_) => { + // SAFETY: + // - `len` is equal to `val.capacity - val.len`, because + // `val.capacity` is `len` and `val.len` is zero. + // - All elements within the interval [`0`, `len`) were initialized + // by `read_array_from_fwnode_property`. + unsafe { val.inc_len(len) } + Ok(val) + } + Err(e) => Err(e), + }; + Ok(PropertyGuard { + inner: res, + fwnode: self, + name, + }) + } + + /// Returns integer array length for firmware property `name`. + pub fn property_count_elem<T: PropertyInt>(&self, name: &CStr) -> Result<usize> { + T::read_array_len_from_fwnode_property(self, name) + } + + /// Returns the value of firmware property `name`. + /// + /// This method is generic over the type of value to read. The types that + /// can be read are strings, integers and arrays of integers. + /// + /// Reading a [`KVec`] of integers is done with the separate + /// method [`Self::property_read_array_vec`], because it takes an + /// additional `len` argument. + /// + /// Reading a boolean is done with the separate method + /// [`Self::property_read_bool`], because this operation is infallible. + /// + /// For more precise documentation about what types can be read, see + /// the [implementors of Property][Property#implementors] and [its + /// implementations on foreign types][Property#foreign-impls]. + /// + /// # Examples + /// + /// ``` + /// # use kernel::{c_str, device::{Device, property::FwNode}, str::CString}; + /// fn examples(dev: &Device) -> Result { + /// let fwnode = dev.fwnode().ok_or(ENOENT)?; + /// let b: u32 = fwnode.property_read(c_str!("some-number")).required_by(dev)?; + /// if let Some(s) = fwnode.property_read::<CString>(c_str!("some-str")).optional() { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + pub fn property_read<'fwnode, 'name, T: Property>( + &'fwnode self, + name: &'name CStr, + ) -> PropertyGuard<'fwnode, 'name, T> { + PropertyGuard { + inner: T::read_from_fwnode_property(self, name), + fwnode: self, + name, + } + } + + /// Returns first matching named child node handle. + pub fn get_child_by_name(&self, name: &CStr) -> Option<ARef<Self>> { + // SAFETY: `self` and `name` are valid by their type invariants. + let child = + unsafe { bindings::fwnode_get_named_child_node(self.as_raw(), name.as_char_ptr()) }; + if child.is_null() { + return None; + } + // SAFETY: + // - `fwnode_get_named_child_node` returns a pointer with its refcount + // incremented. + // - That increment is relinquished, i.e. the underlying object is not + // used anymore except via the newly created `ARef`. + Some(unsafe { Self::from_raw(child) }) + } + + /// Returns an iterator over a node's children. + pub fn children<'a>(&'a self) -> impl Iterator<Item = ARef<FwNode>> + 'a { + let mut prev: Option<ARef<FwNode>> = None; + + core::iter::from_fn(move || { + let prev_ptr = match prev.take() { + None => ptr::null_mut(), + Some(prev) => { + // We will pass `prev` to `fwnode_get_next_child_node`, + // which decrements its refcount, so we use + // `ARef::into_raw` to avoid decrementing the refcount + // twice. + let prev = ARef::into_raw(prev); + prev.as_ptr().cast() + } + }; + // SAFETY: + // - `self.as_raw()` is valid by its type invariant. + // - `prev_ptr` may be null, which is allowed and corresponds to + // getting the first child. Otherwise, `prev_ptr` is valid, as it + // is the stored return value from the previous invocation. + // - `prev_ptr` has its refount incremented. + // - The increment of `prev_ptr` is relinquished, i.e. the + // underlying object won't be used anymore. + let next = unsafe { bindings::fwnode_get_next_child_node(self.as_raw(), prev_ptr) }; + if next.is_null() { + return None; + } + // SAFETY: + // - `next` is valid because `fwnode_get_next_child_node` returns a + // pointer with its refcount incremented. + // - That increment is relinquished, i.e. the underlying object + // won't be used anymore, except via the newly created + // `ARef<Self>`. + let next = unsafe { FwNode::from_raw(next) }; + prev = Some(next.clone()); + Some(next) + }) + } + + /// Finds a reference with arguments. + pub fn property_get_reference_args( + &self, + prop: &CStr, + nargs: NArgs<'_>, + index: u32, + ) -> Result<FwNodeReferenceArgs> { + let mut out_args = FwNodeReferenceArgs::default(); + + let (nargs_prop, nargs) = match nargs { + NArgs::Prop(nargs_prop) => (nargs_prop.as_char_ptr(), 0), + NArgs::N(nargs) => (ptr::null(), nargs), + }; + + // SAFETY: + // - `self.0.get()` is valid. + // - `prop.as_char_ptr()` is valid and zero-terminated. + // - `nargs_prop` is valid and zero-terminated if `nargs` + // is zero, otherwise it is allowed to be a null-pointer. + // - The function upholds the type invariants of `out_args`, + // namely: + // - It may fill the field `fwnode` with a valid pointer, + // in which case its refcount is incremented. + // - It may modify the field `nargs`, in which case it + // initializes at least as many elements in `args`. + let ret = unsafe { + bindings::fwnode_property_get_reference_args( + self.0.get(), + prop.as_char_ptr(), + nargs_prop, + nargs, + index, + &mut out_args.0, + ) + }; + to_result(ret)?; + + Ok(out_args) + } +} + +/// The number of arguments to request [`FwNodeReferenceArgs`]. +pub enum NArgs<'a> { + /// The name of the property of the reference indicating the number of + /// arguments. + Prop(&'a CStr), + /// The known number of arguments. + N(u32), +} + +/// The return value of [`FwNode::property_get_reference_args`]. +/// +/// This structure represents the Rust abstraction for a C +/// `struct fwnode_reference_args` which was initialized by the C side. +/// +/// # Invariants +/// +/// If the field `fwnode` is valid, it owns an increment of its refcount. +/// +/// The field `args` contains at least as many initialized elements as indicated +/// by the field `nargs`. +#[repr(transparent)] +#[derive(Default)] +pub struct FwNodeReferenceArgs(bindings::fwnode_reference_args); + +impl Drop for FwNodeReferenceArgs { + fn drop(&mut self) { + if !self.0.fwnode.is_null() { + // SAFETY: + // - By the type invariants of `FwNodeReferenceArgs`, its field + // `fwnode` owns an increment of its refcount. + // - That increment is relinquished. The underlying object won't be + // used anymore because we are dropping it. + let _ = unsafe { FwNode::from_raw(self.0.fwnode) }; + } + } +} + +impl FwNodeReferenceArgs { + /// Returns the slice of reference arguments. + pub fn as_slice(&self) -> &[u64] { + // SAFETY: As per the safety invariant of `FwNodeReferenceArgs`, `nargs` + // is the minimum number of elements in `args` that is valid. + unsafe { core::slice::from_raw_parts(self.0.args.as_ptr(), self.0.nargs as usize) } + } + + /// Returns the number of reference arguments. + pub fn len(&self) -> usize { + self.0.nargs as usize + } + + /// Returns `true` if there are no reference arguments. + pub fn is_empty(&self) -> bool { + self.0.nargs == 0 + } +} + +impl core::fmt::Debug for FwNodeReferenceArgs { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:?}", self.as_slice()) + } +} + +// SAFETY: Instances of `FwNode` are always reference-counted. +unsafe impl crate::types::AlwaysRefCounted for FwNode { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the + // refcount is non-zero. + unsafe { bindings::fwnode_handle_get(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: ptr::NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is + // non-zero. + unsafe { bindings::fwnode_handle_put(obj.cast().as_ptr()) } + } +} + +enum Node<'a> { + Borrowed(&'a FwNode), + Owned(ARef<FwNode>), +} + +impl core::fmt::Display for FwNode { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + // The logic here is the same as the one in lib/vsprintf.c + // (fwnode_full_name_string). + + // SAFETY: `self.as_raw()` is valid by its type invariant. + let num_parents = unsafe { bindings::fwnode_count_parents(self.as_raw()) }; + + for depth in (0..=num_parents).rev() { + let fwnode = if depth == 0 { + Node::Borrowed(self) + } else { + // SAFETY: `self.as_raw()` is valid. + let ptr = unsafe { bindings::fwnode_get_nth_parent(self.as_raw(), depth) }; + // SAFETY: + // - The depth passed to `fwnode_get_nth_parent` is + // within the valid range, so the returned pointer is + // not null. + // - The reference count was incremented by + // `fwnode_get_nth_parent`. + // - That increment is relinquished to + // `FwNode::from_raw`. + Node::Owned(unsafe { FwNode::from_raw(ptr) }) + }; + // Take a reference to the owned or borrowed `FwNode`. + let fwnode: &FwNode = match &fwnode { + Node::Borrowed(f) => f, + Node::Owned(f) => f, + }; + + // SAFETY: `fwnode` is valid by its type invariant. + let prefix = unsafe { bindings::fwnode_get_name_prefix(fwnode.as_raw()) }; + if !prefix.is_null() { + // SAFETY: `fwnode_get_name_prefix` returns null or a + // valid C string. + let prefix = unsafe { CStr::from_char_ptr(prefix) }; + write!(f, "{prefix}")?; + } + write!(f, "{}", fwnode.display_name())?; + } + + Ok(()) + } +} + +/// Implemented for types that can be read as properties. +/// +/// This is implemented for strings, integers and arrays of integers. It's used +/// to make [`FwNode::property_read`] generic over the type of property being +/// read. There are also two dedicated methods to read other types, because they +/// require more specialized function signatures: +/// - [`property_read_bool`](FwNode::property_read_bool) +/// - [`property_read_array_vec`](FwNode::property_read_array_vec) +/// +/// It must be public, because it appears in the signatures of other public +/// functions, but its methods shouldn't be used outside the kernel crate. +pub trait Property: Sized + Sealed { + /// Used to make [`FwNode::property_read`] generic. + fn read_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<Self>; +} + +impl Sealed for CString {} + +impl Property for CString { + fn read_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<Self> { + let mut str: *mut u8 = ptr::null_mut(); + let pstr: *mut _ = &mut str; + + // SAFETY: + // - `name` is non-null and null-terminated. + // - `fwnode.as_raw` is valid because `fwnode` is valid. + let ret = unsafe { + bindings::fwnode_property_read_string(fwnode.as_raw(), name.as_char_ptr(), pstr.cast()) + }; + to_result(ret)?; + + // SAFETY: + // - `pstr` is a valid pointer to a NUL-terminated C string. + // - It is valid for at least as long as `fwnode`, but it's only used + // within the current function. + // - The memory it points to is not mutated during that time. + let str = unsafe { CStr::from_char_ptr(*pstr) }; + Ok(str.try_into()?) + } +} + +/// Implemented for all integers that can be read as properties. +/// +/// This helper trait is needed on top of the existing [`Property`] +/// trait to associate the integer types of various sizes with their +/// corresponding `fwnode_property_read_*_array` functions. +/// +/// It must be public, because it appears in the signatures of other public +/// functions, but its methods shouldn't be used outside the kernel crate. +pub trait PropertyInt: Copy + Sealed { + /// Reads a property array. + fn read_array_from_fwnode_property<'a>( + fwnode: &FwNode, + name: &CStr, + out: &'a mut [MaybeUninit<Self>], + ) -> Result<&'a mut [Self]>; + + /// Reads the length of a property array. + fn read_array_len_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<usize>; +} +// This macro generates implementations of the traits `Property` and +// `PropertyInt` for integers of various sizes. Its input is a list +// of pairs separated by commas. The first element of the pair is the +// type of the integer, the second one is the name of its corresponding +// `fwnode_property_read_*_array` function. +macro_rules! impl_property_for_int { + ($($int:ty: $f:ident),* $(,)?) => { $( + impl Sealed for $int {} + impl<const N: usize> Sealed for [$int; N] {} + + impl PropertyInt for $int { + fn read_array_from_fwnode_property<'a>( + fwnode: &FwNode, + name: &CStr, + out: &'a mut [MaybeUninit<Self>], + ) -> Result<&'a mut [Self]> { + // SAFETY: + // - `fwnode`, `name` and `out` are all valid by their type + // invariants. + // - `out.len()` is a valid bound for the memory pointed to by + // `out.as_mut_ptr()`. + // CAST: It's ok to cast from `*mut MaybeUninit<$int>` to a + // `*mut $int` because they have the same memory layout. + let ret = unsafe { + bindings::$f( + fwnode.as_raw(), + name.as_char_ptr(), + out.as_mut_ptr().cast(), + out.len(), + ) + }; + to_result(ret)?; + // SAFETY: Transmuting from `&'a mut [MaybeUninit<Self>]` to + // `&'a mut [Self]` is sound, because the previous call to a + // `fwnode_property_read_*_array` function (which didn't fail) + // fully initialized the slice. + Ok(unsafe { core::mem::transmute::<&mut [MaybeUninit<Self>], &mut [Self]>(out) }) + } + + fn read_array_len_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<usize> { + // SAFETY: + // - `fwnode` and `name` are valid by their type invariants. + // - It's ok to pass a null pointer to the + // `fwnode_property_read_*_array` functions if `nval` is zero. + // This will return the length of the array. + let ret = unsafe { + bindings::$f( + fwnode.as_raw(), + name.as_char_ptr(), + ptr::null_mut(), + 0, + ) + }; + to_result(ret)?; + Ok(ret as usize) + } + } + + impl Property for $int { + fn read_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<Self> { + let val: [_; 1] = <[$int; 1]>::read_from_fwnode_property(fwnode, name)?; + Ok(val[0]) + } + } + + impl<const N: usize> Property for [$int; N] { + fn read_from_fwnode_property(fwnode: &FwNode, name: &CStr) -> Result<Self> { + let mut val: [MaybeUninit<$int>; N] = [const { MaybeUninit::uninit() }; N]; + + <$int>::read_array_from_fwnode_property(fwnode, name, &mut val)?; + + // SAFETY: `val` is always initialized when + // `fwnode_property_read_*_array` is successful. + Ok(val.map(|v| unsafe { v.assume_init() })) + } + } + )* }; +} +impl_property_for_int! { + u8: fwnode_property_read_u8_array, + u16: fwnode_property_read_u16_array, + u32: fwnode_property_read_u32_array, + u64: fwnode_property_read_u64_array, + i8: fwnode_property_read_u8_array, + i16: fwnode_property_read_u16_array, + i32: fwnode_property_read_u32_array, + i64: fwnode_property_read_u64_array, +} + +/// A helper for reading device properties. +/// +/// Use [`Self::required_by`] if a missing property is considered a bug and +/// [`Self::optional`] otherwise. +/// +/// For convenience, [`Self::or`] and [`Self::or_default`] are provided. +pub struct PropertyGuard<'fwnode, 'name, T> { + /// The result of reading the property. + inner: Result<T>, + /// The fwnode of the property, used for logging in the "required" case. + fwnode: &'fwnode FwNode, + /// The name of the property, used for logging in the "required" case. + name: &'name CStr, +} + +impl<T> PropertyGuard<'_, '_, T> { + /// Access the property, indicating it is required. + /// + /// If the property is not present, the error is automatically logged. If a + /// missing property is not an error, use [`Self::optional`] instead. The + /// device is required to associate the log with it. + pub fn required_by(self, dev: &super::Device) -> Result<T> { + if self.inner.is_err() { + dev_err!( + dev, + "{}: property '{}' is missing\n", + self.fwnode, + self.name + ); + } + self.inner + } + + /// Access the property, indicating it is optional. + /// + /// In contrast to [`Self::required_by`], no error message is logged if + /// the property is not present. + pub fn optional(self) -> Option<T> { + self.inner.ok() + } + + /// Access the property or the specified default value. + /// + /// Do not pass a sentinel value as default to detect a missing property. + /// Use [`Self::required_by`] or [`Self::optional`] instead. + pub fn or(self, default: T) -> T { + self.inner.unwrap_or(default) + } +} + +impl<T: Default> PropertyGuard<'_, '_, T> { + /// Access the property or a default value. + /// + /// Use [`Self::or`] to specify a custom default value. + pub fn or_default(self) -> T { + self.inner.unwrap_or_default() + } +} diff --git a/rust/kernel/device_id.rs b/rust/kernel/device_id.rs new file mode 100644 index 000000000000..70d57814ff79 --- /dev/null +++ b/rust/kernel/device_id.rs @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic implementation of device IDs. +//! +//! Each bus / subsystem that matches device and driver through a bus / subsystem specific ID is +//! expected to implement [`RawDeviceId`]. + +use core::mem::MaybeUninit; + +/// Marker trait to indicate a Rust device ID type represents a corresponding C device ID type. +/// +/// This is meant to be implemented by buses/subsystems so that they can use [`IdTable`] to +/// guarantee (at compile-time) zero-termination of device id tables provided by drivers. +/// +/// # Safety +/// +/// Implementers must ensure that `Self` is layout-compatible with [`RawDeviceId::RawType`]; +/// i.e. it's safe to transmute to `RawDeviceId`. +/// +/// This requirement is needed so `IdArray::new` can convert `Self` to `RawType` when building +/// the ID table. +/// +/// Ideally, this should be achieved using a const function that does conversion instead of +/// transmute; however, const trait functions relies on `const_trait_impl` unstable feature, +/// which is broken/gone in Rust 1.73. +pub unsafe trait RawDeviceId { + /// The raw type that holds the device id. + /// + /// Id tables created from [`Self`] are going to hold this type in its zero-terminated array. + type RawType: Copy; +} + +/// Extension trait for [`RawDeviceId`] for devices that embed an index or context value. +/// +/// This is typically used when the device ID struct includes a field like `driver_data` +/// that is used to store a pointer-sized value (e.g., an index or context pointer). +/// +/// # Safety +/// +/// Implementers must ensure that `DRIVER_DATA_OFFSET` is the correct offset (in bytes) to +/// the context/data field (e.g., the `driver_data` field) within the raw device ID structure. +/// This field must be correctly sized to hold a `usize`. +/// +/// Ideally, the data should be added during `Self` to `RawType` conversion, +/// but there's currently no way to do it when using traits in const. +pub unsafe trait RawDeviceIdIndex: RawDeviceId { + /// The offset (in bytes) to the context/data field in the raw device ID. + const DRIVER_DATA_OFFSET: usize; + + /// The index stored at `DRIVER_DATA_OFFSET` of the implementor of the [`RawDeviceIdIndex`] + /// trait. + fn index(&self) -> usize; +} + +/// A zero-terminated device id array. +#[repr(C)] +pub struct RawIdArray<T: RawDeviceId, const N: usize> { + ids: [T::RawType; N], + sentinel: MaybeUninit<T::RawType>, +} + +impl<T: RawDeviceId, const N: usize> RawIdArray<T, N> { + #[doc(hidden)] + pub const fn size(&self) -> usize { + core::mem::size_of::<Self>() + } +} + +/// A zero-terminated device id array, followed by context data. +#[repr(C)] +pub struct IdArray<T: RawDeviceId, U, const N: usize> { + raw_ids: RawIdArray<T, N>, + id_infos: [U; N], +} + +impl<T: RawDeviceId, U, const N: usize> IdArray<T, U, N> { + /// Creates a new instance of the array. + /// + /// The contents are derived from the given identifiers and context information. + /// + /// # Safety + /// + /// `data_offset` as `None` is always safe. + /// If `data_offset` is `Some(data_offset)`, then: + /// - `data_offset` must be the correct offset (in bytes) to the context/data field + /// (e.g., the `driver_data` field) within the raw device ID structure. + /// - The field at `data_offset` must be correctly sized to hold a `usize`. + const unsafe fn build(ids: [(T, U); N], data_offset: Option<usize>) -> Self { + let mut raw_ids = [const { MaybeUninit::<T::RawType>::uninit() }; N]; + let mut infos = [const { MaybeUninit::uninit() }; N]; + + let mut i = 0usize; + while i < N { + // SAFETY: by the safety requirement of `RawDeviceId`, we're guaranteed that `T` is + // layout-wise compatible with `RawType`. + raw_ids[i] = unsafe { core::mem::transmute_copy(&ids[i].0) }; + if let Some(data_offset) = data_offset { + // SAFETY: by the safety requirement of this function, this would be effectively + // `raw_ids[i].driver_data = i;`. + unsafe { + raw_ids[i] + .as_mut_ptr() + .byte_add(data_offset) + .cast::<usize>() + .write(i); + } + } + + // SAFETY: this is effectively a move: `infos[i] = ids[i].1`. We make a copy here but + // later forget `ids`. + infos[i] = MaybeUninit::new(unsafe { core::ptr::read(&ids[i].1) }); + i += 1; + } + + core::mem::forget(ids); + + Self { + raw_ids: RawIdArray { + // SAFETY: this is effectively `array_assume_init`, which is unstable, so we use + // `transmute_copy` instead. We have initialized all elements of `raw_ids` so this + // `array_assume_init` is safe. + ids: unsafe { core::mem::transmute_copy(&raw_ids) }, + sentinel: MaybeUninit::zeroed(), + }, + // SAFETY: We have initialized all elements of `infos` so this `array_assume_init` is + // safe. + id_infos: unsafe { core::mem::transmute_copy(&infos) }, + } + } + + /// Creates a new instance of the array without writing index values. + /// + /// The contents are derived from the given identifiers and context information. + /// If the device implements [`RawDeviceIdIndex`], consider using [`IdArray::new`] instead. + pub const fn new_without_index(ids: [(T, U); N]) -> Self { + // SAFETY: Calling `Self::build` with `offset = None` is always safe, + // because no raw memory writes are performed in this case. + unsafe { Self::build(ids, None) } + } + + /// Reference to the contained [`RawIdArray`]. + pub const fn raw_ids(&self) -> &RawIdArray<T, N> { + &self.raw_ids + } +} + +impl<T: RawDeviceId + RawDeviceIdIndex, U, const N: usize> IdArray<T, U, N> { + /// Creates a new instance of the array. + /// + /// The contents are derived from the given identifiers and context information. + pub const fn new(ids: [(T, U); N]) -> Self { + // SAFETY: by the safety requirement of `RawDeviceIdIndex`, + // `T::DRIVER_DATA_OFFSET` is guaranteed to be the correct offset (in bytes) to + // a field within `T::RawType`. + unsafe { Self::build(ids, Some(T::DRIVER_DATA_OFFSET)) } + } +} + +/// A device id table. +/// +/// This trait is only implemented by `IdArray`. +/// +/// The purpose of this trait is to allow `&'static dyn IdArray<T, U>` to be in context when `N` in +/// `IdArray` doesn't matter. +pub trait IdTable<T: RawDeviceId, U> { + /// Obtain the pointer to the ID table. + fn as_ptr(&self) -> *const T::RawType; + + /// Obtain the pointer to the bus specific device ID from an index. + fn id(&self, index: usize) -> &T::RawType; + + /// Obtain the pointer to the driver-specific information from an index. + fn info(&self, index: usize) -> &U; +} + +impl<T: RawDeviceId, U, const N: usize> IdTable<T, U> for IdArray<T, U, N> { + fn as_ptr(&self) -> *const T::RawType { + // This cannot be `self.ids.as_ptr()`, as the return pointer must have correct provenance + // to access the sentinel. + core::ptr::from_ref(self).cast() + } + + fn id(&self, index: usize) -> &T::RawType { + &self.raw_ids.ids[index] + } + + fn info(&self, index: usize) -> &U { + &self.id_infos[index] + } +} + +/// Create device table alias for modpost. +#[macro_export] +macro_rules! module_device_table { + ($table_type: literal, $module_table_name:ident, $table_name:ident) => { + #[rustfmt::skip] + #[export_name = + concat!("__mod_device_table__", $table_type, + "__", module_path!(), + "_", line!(), + "_", stringify!($table_name)) + ] + static $module_table_name: [::core::mem::MaybeUninit<u8>; $table_name.raw_ids().size()] = + unsafe { ::core::mem::transmute_copy($table_name.raw_ids()) }; + }; +} diff --git a/rust/kernel/devres.rs b/rust/kernel/devres.rs new file mode 100644 index 000000000000..d04e3fcebafb --- /dev/null +++ b/rust/kernel/devres.rs @@ -0,0 +1,369 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Devres abstraction +//! +//! [`Devres`] represents an abstraction for the kernel devres (device resource management) +//! implementation. + +use crate::{ + alloc::Flags, + bindings, + device::{Bound, Device}, + error::{to_result, Error, Result}, + ffi::c_void, + prelude::*, + revocable::{Revocable, RevocableGuard}, + sync::{rcu, Completion}, + types::{ARef, ForeignOwnable, Opaque, ScopeGuard}, +}; + +use pin_init::Wrapper; + +/// [`Devres`] inner data accessed from [`Devres::callback`]. +#[pin_data] +struct Inner<T: Send> { + #[pin] + data: Revocable<T>, + /// Tracks whether [`Devres::callback`] has been completed. + #[pin] + devm: Completion, + /// Tracks whether revoking [`Self::data`] has been completed. + #[pin] + revoke: Completion, +} + +/// This abstraction is meant to be used by subsystems to containerize [`Device`] bound resources to +/// manage their lifetime. +/// +/// [`Device`] bound resources should be freed when either the resource goes out of scope or the +/// [`Device`] is unbound respectively, depending on what happens first. In any case, it is always +/// guaranteed that revoking the device resource is completed before the corresponding [`Device`] +/// is unbound. +/// +/// To achieve that [`Devres`] registers a devres callback on creation, which is called once the +/// [`Device`] is unbound, revoking access to the encapsulated resource (see also [`Revocable`]). +/// +/// After the [`Devres`] has been unbound it is not possible to access the encapsulated resource +/// anymore. +/// +/// [`Devres`] users should make sure to simply free the corresponding backing resource in `T`'s +/// [`Drop`] implementation. +/// +/// # Examples +/// +/// ```no_run +/// # use kernel::{bindings, device::{Bound, Device}, devres::Devres, io::{Io, IoRaw}}; +/// # use core::ops::Deref; +/// +/// // See also [`pci::Bar`] for a real example. +/// struct IoMem<const SIZE: usize>(IoRaw<SIZE>); +/// +/// impl<const SIZE: usize> IoMem<SIZE> { +/// /// # Safety +/// /// +/// /// [`paddr`, `paddr` + `SIZE`) must be a valid MMIO region that is mappable into the CPUs +/// /// virtual address space. +/// unsafe fn new(paddr: usize) -> Result<Self>{ +/// // SAFETY: By the safety requirements of this function [`paddr`, `paddr` + `SIZE`) is +/// // valid for `ioremap`. +/// let addr = unsafe { bindings::ioremap(paddr as bindings::phys_addr_t, SIZE) }; +/// if addr.is_null() { +/// return Err(ENOMEM); +/// } +/// +/// Ok(IoMem(IoRaw::new(addr as usize, SIZE)?)) +/// } +/// } +/// +/// impl<const SIZE: usize> Drop for IoMem<SIZE> { +/// fn drop(&mut self) { +/// // SAFETY: `self.0.addr()` is guaranteed to be properly mapped by `Self::new`. +/// unsafe { bindings::iounmap(self.0.addr() as *mut c_void); }; +/// } +/// } +/// +/// impl<const SIZE: usize> Deref for IoMem<SIZE> { +/// type Target = Io<SIZE>; +/// +/// fn deref(&self) -> &Self::Target { +/// // SAFETY: The memory range stored in `self` has been properly mapped in `Self::new`. +/// unsafe { Io::from_raw(&self.0) } +/// } +/// } +/// # fn no_run(dev: &Device<Bound>) -> Result<(), Error> { +/// // SAFETY: Invalid usage for example purposes. +/// let iomem = unsafe { IoMem::<{ core::mem::size_of::<u32>() }>::new(0xBAAAAAAD)? }; +/// let devres = KBox::pin_init(Devres::new(dev, iomem), GFP_KERNEL)?; +/// +/// let res = devres.try_access().ok_or(ENXIO)?; +/// res.write8(0x42, 0x0); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Invariants +/// +/// [`Self::inner`] is guaranteed to be initialized and is always accessed read-only. +#[pin_data(PinnedDrop)] +pub struct Devres<T: Send> { + dev: ARef<Device>, + /// Pointer to [`Self::devres_callback`]. + /// + /// Has to be stored, since Rust does not guarantee to always return the same address for a + /// function. However, the C API uses the address as a key. + callback: unsafe extern "C" fn(*mut c_void), + /// Contains all the fields shared with [`Self::callback`]. + // TODO: Replace with `UnsafePinned`, once available. + // + // Subsequently, the `drop_in_place()` in `Devres::drop` and `Devres::new` as well as the + // explicit `Send` and `Sync' impls can be removed. + #[pin] + inner: Opaque<Inner<T>>, + _add_action: (), +} + +impl<T: Send> Devres<T> { + /// Creates a new [`Devres`] instance of the given `data`. + /// + /// The `data` encapsulated within the returned `Devres` instance' `data` will be + /// (revoked)[`Revocable`] once the device is detached. + pub fn new<'a, E>( + dev: &'a Device<Bound>, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<Self, Error> + 'a + where + T: 'a, + Error: From<E>, + { + let callback = Self::devres_callback; + + try_pin_init!(&this in Self { + dev: dev.into(), + callback, + // INVARIANT: `inner` is properly initialized. + inner <- Opaque::pin_init(try_pin_init!(Inner { + devm <- Completion::new(), + revoke <- Completion::new(), + data <- Revocable::new(data), + })), + // TODO: Replace with "initializer code blocks" [1] once available. + // + // [1] https://github.com/Rust-for-Linux/pin-init/pull/69 + _add_action: { + // SAFETY: `this` is a valid pointer to uninitialized memory. + let inner = unsafe { &raw mut (*this.as_ptr()).inner }; + + // SAFETY: + // - `dev.as_raw()` is a pointer to a valid bound device. + // - `inner` is guaranteed to be a valid for the duration of the lifetime of `Self`. + // - `devm_add_action()` is guaranteed not to call `callback` until `this` has been + // properly initialized, because we require `dev` (i.e. the *bound* device) to + // live at least as long as the returned `impl PinInit<Self, Error>`. + to_result(unsafe { + bindings::devm_add_action(dev.as_raw(), Some(callback), inner.cast()) + }).inspect_err(|_| { + let inner = Opaque::cast_into(inner); + + // SAFETY: `inner` is a valid pointer to an `Inner<T>` and valid for both reads + // and writes. + unsafe { core::ptr::drop_in_place(inner) }; + })?; + }, + }) + } + + fn inner(&self) -> &Inner<T> { + // SAFETY: By the type invairants of `Self`, `inner` is properly initialized and always + // accessed read-only. + unsafe { &*self.inner.get() } + } + + fn data(&self) -> &Revocable<T> { + &self.inner().data + } + + #[allow(clippy::missing_safety_doc)] + unsafe extern "C" fn devres_callback(ptr: *mut kernel::ffi::c_void) { + // SAFETY: In `Self::new` we've passed a valid pointer to `Inner` to `devm_add_action()`, + // hence `ptr` must be a valid pointer to `Inner`. + let inner = unsafe { &*ptr.cast::<Inner<T>>() }; + + // Ensure that `inner` can't be used anymore after we signal completion of this callback. + let inner = ScopeGuard::new_with_data(inner, |inner| inner.devm.complete_all()); + + if !inner.data.revoke() { + // If `revoke()` returns false, it means that `Devres::drop` already started revoking + // `data` for us. Hence we have to wait until `Devres::drop` signals that it + // completed revoking `data`. + inner.revoke.wait_for_completion(); + } + } + + fn remove_action(&self) -> bool { + // SAFETY: + // - `self.dev` is a valid `Device`, + // - the `action` and `data` pointers are the exact same ones as given to + // `devm_add_action()` previously, + (unsafe { + bindings::devm_remove_action_nowarn( + self.dev.as_raw(), + Some(self.callback), + core::ptr::from_ref(self.inner()).cast_mut().cast(), + ) + } == 0) + } + + /// Return a reference of the [`Device`] this [`Devres`] instance has been created with. + pub fn device(&self) -> &Device { + &self.dev + } + + /// Obtain `&'a T`, bypassing the [`Revocable`]. + /// + /// This method allows to directly obtain a `&'a T`, bypassing the [`Revocable`], by presenting + /// a `&'a Device<Bound>` of the same [`Device`] this [`Devres`] instance has been created with. + /// + /// # Errors + /// + /// An error is returned if `dev` does not match the same [`Device`] this [`Devres`] instance + /// has been created with. + /// + /// # Examples + /// + /// ```no_run + /// # #![cfg(CONFIG_PCI)] + /// # use kernel::{device::Core, devres::Devres, pci}; + /// + /// fn from_core(dev: &pci::Device<Core>, devres: Devres<pci::Bar<0x4>>) -> Result { + /// let bar = devres.access(dev.as_ref())?; + /// + /// let _ = bar.read32(0x0); + /// + /// // might_sleep() + /// + /// bar.write32(0x42, 0x0); + /// + /// Ok(()) + /// } + /// ``` + pub fn access<'a>(&'a self, dev: &'a Device<Bound>) -> Result<&'a T> { + if self.dev.as_raw() != dev.as_raw() { + return Err(EINVAL); + } + + // SAFETY: `dev` being the same device as the device this `Devres` has been created for + // proves that `self.data` hasn't been revoked and is guaranteed to not be revoked as long + // as `dev` lives; `dev` lives at least as long as `self`. + Ok(unsafe { self.data().access() }) + } + + /// [`Devres`] accessor for [`Revocable::try_access`]. + pub fn try_access(&self) -> Option<RevocableGuard<'_, T>> { + self.data().try_access() + } + + /// [`Devres`] accessor for [`Revocable::try_access_with`]. + pub fn try_access_with<R, F: FnOnce(&T) -> R>(&self, f: F) -> Option<R> { + self.data().try_access_with(f) + } + + /// [`Devres`] accessor for [`Revocable::try_access_with_guard`]. + pub fn try_access_with_guard<'a>(&'a self, guard: &'a rcu::Guard) -> Option<&'a T> { + self.data().try_access_with_guard(guard) + } +} + +// SAFETY: `Devres` can be send to any task, if `T: Send`. +unsafe impl<T: Send> Send for Devres<T> {} + +// SAFETY: `Devres` can be shared with any task, if `T: Sync`. +unsafe impl<T: Send + Sync> Sync for Devres<T> {} + +#[pinned_drop] +impl<T: Send> PinnedDrop for Devres<T> { + fn drop(self: Pin<&mut Self>) { + // SAFETY: When `drop` runs, it is guaranteed that nobody is accessing the revocable data + // anymore, hence it is safe not to wait for the grace period to finish. + if unsafe { self.data().revoke_nosync() } { + // We revoked `self.data` before the devres action did, hence try to remove it. + if !self.remove_action() { + // We could not remove the devres action, which means that it now runs concurrently, + // hence signal that `self.data` has been revoked by us successfully. + self.inner().revoke.complete_all(); + + // Wait for `Self::devres_callback` to be done using this object. + self.inner().devm.wait_for_completion(); + } + } else { + // `Self::devres_callback` revokes `self.data` for us, hence wait for it to be done + // using this object. + self.inner().devm.wait_for_completion(); + } + + // INVARIANT: At this point it is guaranteed that `inner` can't be accessed any more. + // + // SAFETY: `inner` is valid for dropping. + unsafe { core::ptr::drop_in_place(self.inner.get()) }; + } +} + +/// Consume `data` and [`Drop::drop`] `data` once `dev` is unbound. +fn register_foreign<P>(dev: &Device<Bound>, data: P) -> Result +where + P: ForeignOwnable + Send + 'static, +{ + let ptr = data.into_foreign(); + + #[allow(clippy::missing_safety_doc)] + unsafe extern "C" fn callback<P: ForeignOwnable>(ptr: *mut kernel::ffi::c_void) { + // SAFETY: `ptr` is the pointer to the `ForeignOwnable` leaked above and hence valid. + drop(unsafe { P::from_foreign(ptr.cast()) }); + } + + // SAFETY: + // - `dev.as_raw()` is a pointer to a valid and bound device. + // - `ptr` is a valid pointer the `ForeignOwnable` devres takes ownership of. + to_result(unsafe { + // `devm_add_action_or_reset()` also calls `callback` on failure, such that the + // `ForeignOwnable` is released eventually. + bindings::devm_add_action_or_reset(dev.as_raw(), Some(callback::<P>), ptr.cast()) + }) +} + +/// Encapsulate `data` in a [`KBox`] and [`Drop::drop`] `data` once `dev` is unbound. +/// +/// # Examples +/// +/// ```no_run +/// use kernel::{device::{Bound, Device}, devres}; +/// +/// /// Registration of e.g. a class device, IRQ, etc. +/// struct Registration; +/// +/// impl Registration { +/// fn new() -> Self { +/// // register +/// +/// Self +/// } +/// } +/// +/// impl Drop for Registration { +/// fn drop(&mut self) { +/// // unregister +/// } +/// } +/// +/// fn from_bound_context(dev: &Device<Bound>) -> Result { +/// devres::register(dev, Registration::new(), GFP_KERNEL) +/// } +/// ``` +pub fn register<T, E>(dev: &Device<Bound>, data: impl PinInit<T, E>, flags: Flags) -> Result +where + T: Send + 'static, + Error: From<E>, +{ + let data = KBox::pin_init(data, flags)?; + + register_foreign(dev, data) +} diff --git a/rust/kernel/dma.rs b/rust/kernel/dma.rs new file mode 100644 index 000000000000..2bc8ab51ec28 --- /dev/null +++ b/rust/kernel/dma.rs @@ -0,0 +1,664 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Direct memory access (DMA). +//! +//! C header: [`include/linux/dma-mapping.h`](srctree/include/linux/dma-mapping.h) + +use crate::{ + bindings, build_assert, device, + device::{Bound, Core}, + error::{to_result, Result}, + prelude::*, + transmute::{AsBytes, FromBytes}, + types::ARef, +}; + +/// Trait to be implemented by DMA capable bus devices. +/// +/// The [`dma::Device`](Device) trait should be implemented by bus specific device representations, +/// where the underlying bus is DMA capable, such as [`pci::Device`](::kernel::pci::Device) or +/// [`platform::Device`](::kernel::platform::Device). +pub trait Device: AsRef<device::Device<Core>> { + /// Set up the device's DMA streaming addressing capabilities. + /// + /// This method is usually called once from `probe()` as soon as the device capabilities are + /// known. + /// + /// # Safety + /// + /// This method must not be called concurrently with any DMA allocation or mapping primitives, + /// such as [`CoherentAllocation::alloc_attrs`]. + unsafe fn dma_set_mask(&self, mask: DmaMask) -> Result { + // SAFETY: + // - By the type invariant of `device::Device`, `self.as_ref().as_raw()` is valid. + // - The safety requirement of this function guarantees that there are no concurrent calls + // to DMA allocation and mapping primitives using this mask. + to_result(unsafe { bindings::dma_set_mask(self.as_ref().as_raw(), mask.value()) }) + } + + /// Set up the device's DMA coherent addressing capabilities. + /// + /// This method is usually called once from `probe()` as soon as the device capabilities are + /// known. + /// + /// # Safety + /// + /// This method must not be called concurrently with any DMA allocation or mapping primitives, + /// such as [`CoherentAllocation::alloc_attrs`]. + unsafe fn dma_set_coherent_mask(&self, mask: DmaMask) -> Result { + // SAFETY: + // - By the type invariant of `device::Device`, `self.as_ref().as_raw()` is valid. + // - The safety requirement of this function guarantees that there are no concurrent calls + // to DMA allocation and mapping primitives using this mask. + to_result(unsafe { bindings::dma_set_coherent_mask(self.as_ref().as_raw(), mask.value()) }) + } + + /// Set up the device's DMA addressing capabilities. + /// + /// This is a combination of [`Device::dma_set_mask`] and [`Device::dma_set_coherent_mask`]. + /// + /// This method is usually called once from `probe()` as soon as the device capabilities are + /// known. + /// + /// # Safety + /// + /// This method must not be called concurrently with any DMA allocation or mapping primitives, + /// such as [`CoherentAllocation::alloc_attrs`]. + unsafe fn dma_set_mask_and_coherent(&self, mask: DmaMask) -> Result { + // SAFETY: + // - By the type invariant of `device::Device`, `self.as_ref().as_raw()` is valid. + // - The safety requirement of this function guarantees that there are no concurrent calls + // to DMA allocation and mapping primitives using this mask. + to_result(unsafe { + bindings::dma_set_mask_and_coherent(self.as_ref().as_raw(), mask.value()) + }) + } +} + +/// A DMA mask that holds a bitmask with the lowest `n` bits set. +/// +/// Use [`DmaMask::new`] or [`DmaMask::try_new`] to construct a value. Values +/// are guaranteed to never exceed the bit width of `u64`. +/// +/// This is the Rust equivalent of the C macro `DMA_BIT_MASK()`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct DmaMask(u64); + +impl DmaMask { + /// Constructs a `DmaMask` with the lowest `n` bits set to `1`. + /// + /// For `n <= 64`, sets exactly the lowest `n` bits. + /// For `n > 64`, results in a build error. + /// + /// # Examples + /// + /// ``` + /// use kernel::dma::DmaMask; + /// + /// let mask0 = DmaMask::new::<0>(); + /// assert_eq!(mask0.value(), 0); + /// + /// let mask1 = DmaMask::new::<1>(); + /// assert_eq!(mask1.value(), 0b1); + /// + /// let mask64 = DmaMask::new::<64>(); + /// assert_eq!(mask64.value(), u64::MAX); + /// + /// // Build failure. + /// // let mask_overflow = DmaMask::new::<100>(); + /// ``` + #[inline] + pub const fn new<const N: u32>() -> Self { + let Ok(mask) = Self::try_new(N) else { + build_error!("Invalid DMA Mask."); + }; + + mask + } + + /// Constructs a `DmaMask` with the lowest `n` bits set to `1`. + /// + /// For `n <= 64`, sets exactly the lowest `n` bits. + /// For `n > 64`, returns [`EINVAL`]. + /// + /// # Examples + /// + /// ``` + /// use kernel::dma::DmaMask; + /// + /// let mask0 = DmaMask::try_new(0)?; + /// assert_eq!(mask0.value(), 0); + /// + /// let mask1 = DmaMask::try_new(1)?; + /// assert_eq!(mask1.value(), 0b1); + /// + /// let mask64 = DmaMask::try_new(64)?; + /// assert_eq!(mask64.value(), u64::MAX); + /// + /// let mask_overflow = DmaMask::try_new(100); + /// assert!(mask_overflow.is_err()); + /// # Ok::<(), Error>(()) + /// ``` + #[inline] + pub const fn try_new(n: u32) -> Result<Self> { + Ok(Self(match n { + 0 => 0, + 1..=64 => u64::MAX >> (64 - n), + _ => return Err(EINVAL), + })) + } + + /// Returns the underlying `u64` bitmask value. + #[inline] + pub const fn value(&self) -> u64 { + self.0 + } +} + +/// Possible attributes associated with a DMA mapping. +/// +/// They can be combined with the operators `|`, `&`, and `!`. +/// +/// Values can be used from the [`attrs`] module. +/// +/// # Examples +/// +/// ``` +/// # use kernel::device::{Bound, Device}; +/// use kernel::dma::{attrs::*, CoherentAllocation}; +/// +/// # fn test(dev: &Device<Bound>) -> Result { +/// let attribs = DMA_ATTR_FORCE_CONTIGUOUS | DMA_ATTR_NO_WARN; +/// let c: CoherentAllocation<u64> = +/// CoherentAllocation::alloc_attrs(dev, 4, GFP_KERNEL, attribs)?; +/// # Ok::<(), Error>(()) } +/// ``` +#[derive(Clone, Copy, PartialEq)] +#[repr(transparent)] +pub struct Attrs(u32); + +impl Attrs { + /// Get the raw representation of this attribute. + pub(crate) fn as_raw(self) -> crate::ffi::c_ulong { + self.0 as crate::ffi::c_ulong + } + + /// Check whether `flags` is contained in `self`. + pub fn contains(self, flags: Attrs) -> bool { + (self & flags) == flags + } +} + +impl core::ops::BitOr for Attrs { + type Output = Self; + fn bitor(self, rhs: Self) -> Self::Output { + Self(self.0 | rhs.0) + } +} + +impl core::ops::BitAnd for Attrs { + type Output = Self; + fn bitand(self, rhs: Self) -> Self::Output { + Self(self.0 & rhs.0) + } +} + +impl core::ops::Not for Attrs { + type Output = Self; + fn not(self) -> Self::Output { + Self(!self.0) + } +} + +/// DMA mapping attributes. +pub mod attrs { + use super::Attrs; + + /// Specifies that reads and writes to the mapping may be weakly ordered, that is that reads + /// and writes may pass each other. + pub const DMA_ATTR_WEAK_ORDERING: Attrs = Attrs(bindings::DMA_ATTR_WEAK_ORDERING); + + /// Specifies that writes to the mapping may be buffered to improve performance. + pub const DMA_ATTR_WRITE_COMBINE: Attrs = Attrs(bindings::DMA_ATTR_WRITE_COMBINE); + + /// Lets the platform to avoid creating a kernel virtual mapping for the allocated buffer. + pub const DMA_ATTR_NO_KERNEL_MAPPING: Attrs = Attrs(bindings::DMA_ATTR_NO_KERNEL_MAPPING); + + /// Allows platform code to skip synchronization of the CPU cache for the given buffer assuming + /// that it has been already transferred to 'device' domain. + pub const DMA_ATTR_SKIP_CPU_SYNC: Attrs = Attrs(bindings::DMA_ATTR_SKIP_CPU_SYNC); + + /// Forces contiguous allocation of the buffer in physical memory. + pub const DMA_ATTR_FORCE_CONTIGUOUS: Attrs = Attrs(bindings::DMA_ATTR_FORCE_CONTIGUOUS); + + /// Hints DMA-mapping subsystem that it's probably not worth the time to try + /// to allocate memory to in a way that gives better TLB efficiency. + pub const DMA_ATTR_ALLOC_SINGLE_PAGES: Attrs = Attrs(bindings::DMA_ATTR_ALLOC_SINGLE_PAGES); + + /// This tells the DMA-mapping subsystem to suppress allocation failure reports (similarly to + /// `__GFP_NOWARN`). + pub const DMA_ATTR_NO_WARN: Attrs = Attrs(bindings::DMA_ATTR_NO_WARN); + + /// Indicates that the buffer is fully accessible at an elevated privilege level (and + /// ideally inaccessible or at least read-only at lesser-privileged levels). + pub const DMA_ATTR_PRIVILEGED: Attrs = Attrs(bindings::DMA_ATTR_PRIVILEGED); +} + +/// An abstraction of the `dma_alloc_coherent` API. +/// +/// This is an abstraction around the `dma_alloc_coherent` API which is used to allocate and map +/// large coherent DMA regions. +/// +/// A [`CoherentAllocation`] instance contains a pointer to the allocated region (in the +/// processor's virtual address space) and the device address which can be given to the device +/// as the DMA address base of the region. The region is released once [`CoherentAllocation`] +/// is dropped. +/// +/// # Invariants +/// +/// - For the lifetime of an instance of [`CoherentAllocation`], the `cpu_addr` is a valid pointer +/// to an allocated region of coherent memory and `dma_handle` is the DMA address base of the +/// region. +/// - The size in bytes of the allocation is equal to `size_of::<T> * count`. +/// - `size_of::<T> * count` fits into a `usize`. +// TODO +// +// DMA allocations potentially carry device resources (e.g.IOMMU mappings), hence for soundness +// reasons DMA allocation would need to be embedded in a `Devres` container, in order to ensure +// that device resources can never survive device unbind. +// +// However, it is neither desirable nor necessary to protect the allocated memory of the DMA +// allocation from surviving device unbind; it would require RCU read side critical sections to +// access the memory, which may require subsequent unnecessary copies. +// +// Hence, find a way to revoke the device resources of a `CoherentAllocation`, but not the +// entire `CoherentAllocation` including the allocated memory itself. +pub struct CoherentAllocation<T: AsBytes + FromBytes> { + dev: ARef<device::Device>, + dma_handle: bindings::dma_addr_t, + count: usize, + cpu_addr: *mut T, + dma_attrs: Attrs, +} + +impl<T: AsBytes + FromBytes> CoherentAllocation<T> { + /// Allocates a region of `size_of::<T> * count` of coherent memory. + /// + /// # Examples + /// + /// ``` + /// # use kernel::device::{Bound, Device}; + /// use kernel::dma::{attrs::*, CoherentAllocation}; + /// + /// # fn test(dev: &Device<Bound>) -> Result { + /// let c: CoherentAllocation<u64> = + /// CoherentAllocation::alloc_attrs(dev, 4, GFP_KERNEL, DMA_ATTR_NO_WARN)?; + /// # Ok::<(), Error>(()) } + /// ``` + pub fn alloc_attrs( + dev: &device::Device<Bound>, + count: usize, + gfp_flags: kernel::alloc::Flags, + dma_attrs: Attrs, + ) -> Result<CoherentAllocation<T>> { + build_assert!( + core::mem::size_of::<T>() > 0, + "It doesn't make sense for the allocated type to be a ZST" + ); + + let size = count + .checked_mul(core::mem::size_of::<T>()) + .ok_or(EOVERFLOW)?; + let mut dma_handle = 0; + // SAFETY: Device pointer is guaranteed as valid by the type invariant on `Device`. + let ret = unsafe { + bindings::dma_alloc_attrs( + dev.as_raw(), + size, + &mut dma_handle, + gfp_flags.as_raw(), + dma_attrs.as_raw(), + ) + }; + if ret.is_null() { + return Err(ENOMEM); + } + // INVARIANT: + // - We just successfully allocated a coherent region which is accessible for + // `count` elements, hence the cpu address is valid. We also hold a refcounted reference + // to the device. + // - The allocated `size` is equal to `size_of::<T> * count`. + // - The allocated `size` fits into a `usize`. + Ok(Self { + dev: dev.into(), + dma_handle, + count, + cpu_addr: ret.cast::<T>(), + dma_attrs, + }) + } + + /// Performs the same functionality as [`CoherentAllocation::alloc_attrs`], except the + /// `dma_attrs` is 0 by default. + pub fn alloc_coherent( + dev: &device::Device<Bound>, + count: usize, + gfp_flags: kernel::alloc::Flags, + ) -> Result<CoherentAllocation<T>> { + CoherentAllocation::alloc_attrs(dev, count, gfp_flags, Attrs(0)) + } + + /// Returns the number of elements `T` in this allocation. + /// + /// Note that this is not the size of the allocation in bytes, which is provided by + /// [`Self::size`]. + pub fn count(&self) -> usize { + self.count + } + + /// Returns the size in bytes of this allocation. + pub fn size(&self) -> usize { + // INVARIANT: The type invariant of `Self` guarantees that `size_of::<T> * count` fits into + // a `usize`. + self.count * core::mem::size_of::<T>() + } + + /// Returns the base address to the allocated region in the CPU's virtual address space. + pub fn start_ptr(&self) -> *const T { + self.cpu_addr + } + + /// Returns the base address to the allocated region in the CPU's virtual address space as + /// a mutable pointer. + pub fn start_ptr_mut(&mut self) -> *mut T { + self.cpu_addr + } + + /// Returns a DMA handle which may be given to the device as the DMA address base of + /// the region. + pub fn dma_handle(&self) -> bindings::dma_addr_t { + self.dma_handle + } + + /// Returns a DMA handle starting at `offset` (in units of `T`) which may be given to the + /// device as the DMA address base of the region. + /// + /// Returns `EINVAL` if `offset` is not within the bounds of the allocation. + pub fn dma_handle_with_offset(&self, offset: usize) -> Result<bindings::dma_addr_t> { + if offset >= self.count { + Err(EINVAL) + } else { + // INVARIANT: The type invariant of `Self` guarantees that `size_of::<T> * count` fits + // into a `usize`, and `offset` is inferior to `count`. + Ok(self.dma_handle + (offset * core::mem::size_of::<T>()) as bindings::dma_addr_t) + } + } + + /// Common helper to validate a range applied from the allocated region in the CPU's virtual + /// address space. + fn validate_range(&self, offset: usize, count: usize) -> Result { + if offset.checked_add(count).ok_or(EOVERFLOW)? > self.count { + return Err(EINVAL); + } + Ok(()) + } + + /// Returns the data from the region starting from `offset` as a slice. + /// `offset` and `count` are in units of `T`, not the number of bytes. + /// + /// For ringbuffer type of r/w access or use-cases where the pointer to the live data is needed, + /// [`CoherentAllocation::start_ptr`] or [`CoherentAllocation::start_ptr_mut`] could be used + /// instead. + /// + /// # Safety + /// + /// * Callers must ensure that the device does not read/write to/from memory while the returned + /// slice is live. + /// * Callers must ensure that this call does not race with a write to the same region while + /// the returned slice is live. + pub unsafe fn as_slice(&self, offset: usize, count: usize) -> Result<&[T]> { + self.validate_range(offset, count)?; + // SAFETY: + // - The pointer is valid due to type invariant on `CoherentAllocation`, + // we've just checked that the range and index is within bounds. The immutability of the + // data is also guaranteed by the safety requirements of the function. + // - `offset + count` can't overflow since it is smaller than `self.count` and we've checked + // that `self.count` won't overflow early in the constructor. + Ok(unsafe { core::slice::from_raw_parts(self.cpu_addr.add(offset), count) }) + } + + /// Performs the same functionality as [`CoherentAllocation::as_slice`], except that a mutable + /// slice is returned. + /// + /// # Safety + /// + /// * Callers must ensure that the device does not read/write to/from memory while the returned + /// slice is live. + /// * Callers must ensure that this call does not race with a read or write to the same region + /// while the returned slice is live. + pub unsafe fn as_slice_mut(&mut self, offset: usize, count: usize) -> Result<&mut [T]> { + self.validate_range(offset, count)?; + // SAFETY: + // - The pointer is valid due to type invariant on `CoherentAllocation`, + // we've just checked that the range and index is within bounds. The immutability of the + // data is also guaranteed by the safety requirements of the function. + // - `offset + count` can't overflow since it is smaller than `self.count` and we've checked + // that `self.count` won't overflow early in the constructor. + Ok(unsafe { core::slice::from_raw_parts_mut(self.cpu_addr.add(offset), count) }) + } + + /// Writes data to the region starting from `offset`. `offset` is in units of `T`, not the + /// number of bytes. + /// + /// # Safety + /// + /// * Callers must ensure that the device does not read/write to/from memory while the returned + /// slice is live. + /// * Callers must ensure that this call does not race with a read or write to the same region + /// that overlaps with this write. + /// + /// # Examples + /// + /// ``` + /// # fn test(alloc: &mut kernel::dma::CoherentAllocation<u8>) -> Result { + /// let somedata: [u8; 4] = [0xf; 4]; + /// let buf: &[u8] = &somedata; + /// // SAFETY: There is no concurrent HW operation on the device and no other R/W access to the + /// // region. + /// unsafe { alloc.write(buf, 0)?; } + /// # Ok::<(), Error>(()) } + /// ``` + pub unsafe fn write(&mut self, src: &[T], offset: usize) -> Result { + self.validate_range(offset, src.len())?; + // SAFETY: + // - The pointer is valid due to type invariant on `CoherentAllocation` + // and we've just checked that the range and index is within bounds. + // - `offset + count` can't overflow since it is smaller than `self.count` and we've checked + // that `self.count` won't overflow early in the constructor. + unsafe { + core::ptr::copy_nonoverlapping(src.as_ptr(), self.cpu_addr.add(offset), src.len()) + }; + Ok(()) + } + + /// Returns a pointer to an element from the region with bounds checking. `offset` is in + /// units of `T`, not the number of bytes. + /// + /// Public but hidden since it should only be used from [`dma_read`] and [`dma_write`] macros. + #[doc(hidden)] + pub fn item_from_index(&self, offset: usize) -> Result<*mut T> { + if offset >= self.count { + return Err(EINVAL); + } + // SAFETY: + // - The pointer is valid due to type invariant on `CoherentAllocation` + // and we've just checked that the range and index is within bounds. + // - `offset` can't overflow since it is smaller than `self.count` and we've checked + // that `self.count` won't overflow early in the constructor. + Ok(unsafe { self.cpu_addr.add(offset) }) + } + + /// Reads the value of `field` and ensures that its type is [`FromBytes`]. + /// + /// # Safety + /// + /// This must be called from the [`dma_read`] macro which ensures that the `field` pointer is + /// validated beforehand. + /// + /// Public but hidden since it should only be used from [`dma_read`] macro. + #[doc(hidden)] + pub unsafe fn field_read<F: FromBytes>(&self, field: *const F) -> F { + // SAFETY: + // - By the safety requirements field is valid. + // - Using read_volatile() here is not sound as per the usual rules, the usage here is + // a special exception with the following notes in place. When dealing with a potential + // race from a hardware or code outside kernel (e.g. user-space program), we need that + // read on a valid memory is not UB. Currently read_volatile() is used for this, and the + // rationale behind is that it should generate the same code as READ_ONCE() which the + // kernel already relies on to avoid UB on data races. Note that the usage of + // read_volatile() is limited to this particular case, it cannot be used to prevent + // the UB caused by racing between two kernel functions nor do they provide atomicity. + unsafe { field.read_volatile() } + } + + /// Writes a value to `field` and ensures that its type is [`AsBytes`]. + /// + /// # Safety + /// + /// This must be called from the [`dma_write`] macro which ensures that the `field` pointer is + /// validated beforehand. + /// + /// Public but hidden since it should only be used from [`dma_write`] macro. + #[doc(hidden)] + pub unsafe fn field_write<F: AsBytes>(&self, field: *mut F, val: F) { + // SAFETY: + // - By the safety requirements field is valid. + // - Using write_volatile() here is not sound as per the usual rules, the usage here is + // a special exception with the following notes in place. When dealing with a potential + // race from a hardware or code outside kernel (e.g. user-space program), we need that + // write on a valid memory is not UB. Currently write_volatile() is used for this, and the + // rationale behind is that it should generate the same code as WRITE_ONCE() which the + // kernel already relies on to avoid UB on data races. Note that the usage of + // write_volatile() is limited to this particular case, it cannot be used to prevent + // the UB caused by racing between two kernel functions nor do they provide atomicity. + unsafe { field.write_volatile(val) } + } +} + +/// Note that the device configured to do DMA must be halted before this object is dropped. +impl<T: AsBytes + FromBytes> Drop for CoherentAllocation<T> { + fn drop(&mut self) { + let size = self.count * core::mem::size_of::<T>(); + // SAFETY: Device pointer is guaranteed as valid by the type invariant on `Device`. + // The cpu address, and the dma handle are valid due to the type invariants on + // `CoherentAllocation`. + unsafe { + bindings::dma_free_attrs( + self.dev.as_raw(), + size, + self.cpu_addr.cast(), + self.dma_handle, + self.dma_attrs.as_raw(), + ) + } + } +} + +// SAFETY: It is safe to send a `CoherentAllocation` to another thread if `T` +// can be sent to another thread. +unsafe impl<T: AsBytes + FromBytes + Send> Send for CoherentAllocation<T> {} + +/// Reads a field of an item from an allocated region of structs. +/// +/// # Examples +/// +/// ``` +/// use kernel::device::Device; +/// use kernel::dma::{attrs::*, CoherentAllocation}; +/// +/// struct MyStruct { field: u32, } +/// +/// // SAFETY: All bit patterns are acceptable values for `MyStruct`. +/// unsafe impl kernel::transmute::FromBytes for MyStruct{}; +/// // SAFETY: Instances of `MyStruct` have no uninitialized portions. +/// unsafe impl kernel::transmute::AsBytes for MyStruct{}; +/// +/// # fn test(alloc: &kernel::dma::CoherentAllocation<MyStruct>) -> Result { +/// let whole = kernel::dma_read!(alloc[2]); +/// let field = kernel::dma_read!(alloc[1].field); +/// # Ok::<(), Error>(()) } +/// ``` +#[macro_export] +macro_rules! dma_read { + ($dma:expr, $idx: expr, $($field:tt)*) => {{ + (|| -> ::core::result::Result<_, $crate::error::Error> { + let item = $crate::dma::CoherentAllocation::item_from_index(&$dma, $idx)?; + // SAFETY: `item_from_index` ensures that `item` is always a valid pointer and can be + // dereferenced. The compiler also further validates the expression on whether `field` + // is a member of `item` when expanded by the macro. + unsafe { + let ptr_field = ::core::ptr::addr_of!((*item) $($field)*); + ::core::result::Result::Ok( + $crate::dma::CoherentAllocation::field_read(&$dma, ptr_field) + ) + } + })() + }}; + ($dma:ident [ $idx:expr ] $($field:tt)* ) => { + $crate::dma_read!($dma, $idx, $($field)*) + }; + ($($dma:ident).* [ $idx:expr ] $($field:tt)* ) => { + $crate::dma_read!($($dma).*, $idx, $($field)*) + }; +} + +/// Writes to a field of an item from an allocated region of structs. +/// +/// # Examples +/// +/// ``` +/// use kernel::device::Device; +/// use kernel::dma::{attrs::*, CoherentAllocation}; +/// +/// struct MyStruct { member: u32, } +/// +/// // SAFETY: All bit patterns are acceptable values for `MyStruct`. +/// unsafe impl kernel::transmute::FromBytes for MyStruct{}; +/// // SAFETY: Instances of `MyStruct` have no uninitialized portions. +/// unsafe impl kernel::transmute::AsBytes for MyStruct{}; +/// +/// # fn test(alloc: &kernel::dma::CoherentAllocation<MyStruct>) -> Result { +/// kernel::dma_write!(alloc[2].member = 0xf); +/// kernel::dma_write!(alloc[1] = MyStruct { member: 0xf }); +/// # Ok::<(), Error>(()) } +/// ``` +#[macro_export] +macro_rules! dma_write { + ($dma:ident [ $idx:expr ] $($field:tt)*) => {{ + $crate::dma_write!($dma, $idx, $($field)*) + }}; + ($($dma:ident).* [ $idx:expr ] $($field:tt)* ) => {{ + $crate::dma_write!($($dma).*, $idx, $($field)*) + }}; + ($dma:expr, $idx: expr, = $val:expr) => { + (|| -> ::core::result::Result<_, $crate::error::Error> { + let item = $crate::dma::CoherentAllocation::item_from_index(&$dma, $idx)?; + // SAFETY: `item_from_index` ensures that `item` is always a valid item. + unsafe { $crate::dma::CoherentAllocation::field_write(&$dma, item, $val) } + ::core::result::Result::Ok(()) + })() + }; + ($dma:expr, $idx: expr, $(.$field:ident)* = $val:expr) => { + (|| -> ::core::result::Result<_, $crate::error::Error> { + let item = $crate::dma::CoherentAllocation::item_from_index(&$dma, $idx)?; + // SAFETY: `item_from_index` ensures that `item` is always a valid pointer and can be + // dereferenced. The compiler also further validates the expression on whether `field` + // is a member of `item` when expanded by the macro. + unsafe { + let ptr_field = ::core::ptr::addr_of_mut!((*item) $(.$field)*); + $crate::dma::CoherentAllocation::field_write(&$dma, ptr_field, $val) + } + ::core::result::Result::Ok(()) + })() + }; +} diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs new file mode 100644 index 000000000000..279e3af20682 --- /dev/null +++ b/rust/kernel/driver.rs @@ -0,0 +1,318 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic support for drivers of different buses (e.g., PCI, Platform, Amba, etc.). +//! +//! This documentation describes how to implement a bus specific driver API and how to align it with +//! the design of (bus specific) devices. +//! +//! Note: Readers are expected to know the content of the documentation of [`Device`] and +//! [`DeviceContext`]. +//! +//! # Driver Trait +//! +//! The main driver interface is defined by a bus specific driver trait. For instance: +//! +//! ```ignore +//! pub trait Driver: Send { +//! /// The type holding information about each device ID supported by the driver. +//! type IdInfo: 'static; +//! +//! /// The table of OF device ids supported by the driver. +//! const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; +//! +//! /// The table of ACPI device ids supported by the driver. +//! const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None; +//! +//! /// Driver probe. +//! fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>; +//! +//! /// Driver unbind (optional). +//! fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { +//! let _ = (dev, this); +//! } +//! } +//! ``` +//! +//! For specific examples see [`auxiliary::Driver`], [`pci::Driver`] and [`platform::Driver`]. +//! +//! The `probe()` callback should return a `Result<Pin<KBox<Self>>>`, i.e. the driver's private +//! data. The bus abstraction should store the pointer in the corresponding bus device. The generic +//! [`Device`] infrastructure provides common helpers for this purpose on its +//! [`Device<CoreInternal>`] implementation. +//! +//! All driver callbacks should provide a reference to the driver's private data. Once the driver +//! is unbound from the device, the bus abstraction should take back the ownership of the driver's +//! private data from the corresponding [`Device`] and [`drop`] it. +//! +//! All driver callbacks should provide a [`Device<Core>`] reference (see also [`device::Core`]). +//! +//! # Adapter +//! +//! The adapter implementation of a bus represents the abstraction layer between the C bus +//! callbacks and the Rust bus callbacks. It therefore has to be generic over an implementation of +//! the [driver trait](#driver-trait). +//! +//! ```ignore +//! pub struct Adapter<T: Driver>; +//! ``` +//! +//! There's a common [`Adapter`] trait that can be implemented to inherit common driver +//! infrastructure, such as finding the ID info from an [`of::IdTable`] or [`acpi::IdTable`]. +//! +//! # Driver Registration +//! +//! In order to register C driver types (such as `struct platform_driver`) the [adapter](#adapter) +//! should implement the [`RegistrationOps`] trait. +//! +//! This trait implementation can be used to create the actual registration with the common +//! [`Registration`] type. +//! +//! Typically, bus abstractions want to provide a bus specific `module_bus_driver!` macro, which +//! creates a kernel module with exactly one [`Registration`] for the bus specific adapter. +//! +//! The generic driver infrastructure provides a helper for this with the [`module_driver`] macro. +//! +//! # Device IDs +//! +//! Besides the common device ID types, such as [`of::DeviceId`] and [`acpi::DeviceId`], most buses +//! may need to implement their own device ID types. +//! +//! For this purpose the generic infrastructure in [`device_id`] should be used. +//! +//! [`auxiliary::Driver`]: kernel::auxiliary::Driver +//! [`Core`]: device::Core +//! [`Device`]: device::Device +//! [`Device<Core>`]: device::Device<device::Core> +//! [`Device<CoreInternal>`]: device::Device<device::CoreInternal> +//! [`DeviceContext`]: device::DeviceContext +//! [`device_id`]: kernel::device_id +//! [`module_driver`]: kernel::module_driver +//! [`pci::Driver`]: kernel::pci::Driver +//! [`platform::Driver`]: kernel::platform::Driver + +use crate::error::{Error, Result}; +use crate::{acpi, device, of, str::CStr, try_pin_init, types::Opaque, ThisModule}; +use core::pin::Pin; +use pin_init::{pin_data, pinned_drop, PinInit}; + +/// The [`RegistrationOps`] trait serves as generic interface for subsystems (e.g., PCI, Platform, +/// Amba, etc.) to provide the corresponding subsystem specific implementation to register / +/// unregister a driver of the particular type (`RegType`). +/// +/// For instance, the PCI subsystem would set `RegType` to `bindings::pci_driver` and call +/// `bindings::__pci_register_driver` from `RegistrationOps::register` and +/// `bindings::pci_unregister_driver` from `RegistrationOps::unregister`. +/// +/// # Safety +/// +/// A call to [`RegistrationOps::unregister`] for a given instance of `RegType` is only valid if a +/// preceding call to [`RegistrationOps::register`] has been successful. +pub unsafe trait RegistrationOps { + /// The type that holds information about the registration. This is typically a struct defined + /// by the C portion of the kernel. + type RegType: Default; + + /// Registers a driver. + /// + /// # Safety + /// + /// On success, `reg` must remain pinned and valid until the matching call to + /// [`RegistrationOps::unregister`]. + unsafe fn register( + reg: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result; + + /// Unregisters a driver previously registered with [`RegistrationOps::register`]. + /// + /// # Safety + /// + /// Must only be called after a preceding successful call to [`RegistrationOps::register`] for + /// the same `reg`. + unsafe fn unregister(reg: &Opaque<Self::RegType>); +} + +/// A [`Registration`] is a generic type that represents the registration of some driver type (e.g. +/// `bindings::pci_driver`). Therefore a [`Registration`] must be initialized with a type that +/// implements the [`RegistrationOps`] trait, such that the generic `T::register` and +/// `T::unregister` calls result in the subsystem specific registration calls. +/// +///Once the `Registration` structure is dropped, the driver is unregistered. +#[pin_data(PinnedDrop)] +pub struct Registration<T: RegistrationOps> { + #[pin] + reg: Opaque<T::RegType>, +} + +// SAFETY: `Registration` has no fields or methods accessible via `&Registration`, so it is safe to +// share references to it with multiple threads as nothing can be done. +unsafe impl<T: RegistrationOps> Sync for Registration<T> {} + +// SAFETY: Both registration and unregistration are implemented in C and safe to be performed from +// any thread, so `Registration` is `Send`. +unsafe impl<T: RegistrationOps> Send for Registration<T> {} + +impl<T: RegistrationOps> Registration<T> { + /// Creates a new instance of the registration object. + pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> { + try_pin_init!(Self { + reg <- Opaque::try_ffi_init(|ptr: *mut T::RegType| { + // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write. + unsafe { ptr.write(T::RegType::default()) }; + + // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write, and it has + // just been initialised above, so it's also valid for read. + let drv = unsafe { &*(ptr as *const Opaque<T::RegType>) }; + + // SAFETY: `drv` is guaranteed to be pinned until `T::unregister`. + unsafe { T::register(drv, name, module) } + }), + }) + } +} + +#[pinned_drop] +impl<T: RegistrationOps> PinnedDrop for Registration<T> { + fn drop(self: Pin<&mut Self>) { + // SAFETY: The existence of `self` guarantees that `self.reg` has previously been + // successfully registered with `T::register` + unsafe { T::unregister(&self.reg) }; + } +} + +/// Declares a kernel module that exposes a single driver. +/// +/// It is meant to be used as a helper by other subsystems so they can more easily expose their own +/// macros. +#[macro_export] +macro_rules! module_driver { + (<$gen_type:ident>, $driver_ops:ty, { type: $type:ty, $($f:tt)* }) => { + type Ops<$gen_type> = $driver_ops; + + #[$crate::prelude::pin_data] + struct DriverModule { + #[pin] + _driver: $crate::driver::Registration<Ops<$type>>, + } + + impl $crate::InPlaceModule for DriverModule { + fn init( + module: &'static $crate::ThisModule + ) -> impl ::pin_init::PinInit<Self, $crate::error::Error> { + $crate::try_pin_init!(Self { + _driver <- $crate::driver::Registration::new( + <Self as $crate::ModuleMetadata>::NAME, + module, + ), + }) + } + } + + $crate::prelude::module! { + type: DriverModule, + $($f)* + } + } +} + +/// The bus independent adapter to match a drivers and a devices. +/// +/// This trait should be implemented by the bus specific adapter, which represents the connection +/// of a device and a driver. +/// +/// It provides bus independent functions for device / driver interactions. +pub trait Adapter { + /// The type holding driver private data about each device id supported by the driver. + type IdInfo: 'static; + + /// The [`acpi::IdTable`] of the corresponding driver + fn acpi_id_table() -> Option<acpi::IdTable<Self::IdInfo>>; + + /// Returns the driver's private data from the matching entry in the [`acpi::IdTable`], if any. + /// + /// If this returns `None`, it means there is no match with an entry in the [`acpi::IdTable`]. + fn acpi_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { + #[cfg(not(CONFIG_ACPI))] + { + let _ = dev; + None + } + + #[cfg(CONFIG_ACPI)] + { + let table = Self::acpi_id_table()?; + + // SAFETY: + // - `table` has static lifetime, hence it's valid for read, + // - `dev` is guaranteed to be valid while it's alive, and so is `dev.as_raw()`. + let raw_id = unsafe { bindings::acpi_match_device(table.as_ptr(), dev.as_raw()) }; + + if raw_id.is_null() { + None + } else { + // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct acpi_device_id` + // and does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*raw_id.cast::<acpi::DeviceId>() }; + + Some(table.info(<acpi::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id))) + } + } + } + + /// The [`of::IdTable`] of the corresponding driver. + fn of_id_table() -> Option<of::IdTable<Self::IdInfo>>; + + /// Returns the driver's private data from the matching entry in the [`of::IdTable`], if any. + /// + /// If this returns `None`, it means there is no match with an entry in the [`of::IdTable`]. + fn of_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { + #[cfg(not(CONFIG_OF))] + { + let _ = dev; + None + } + + #[cfg(CONFIG_OF)] + { + let table = Self::of_id_table()?; + + // SAFETY: + // - `table` has static lifetime, hence it's valid for read, + // - `dev` is guaranteed to be valid while it's alive, and so is `dev.as_raw()`. + let raw_id = unsafe { bindings::of_match_device(table.as_ptr(), dev.as_raw()) }; + + if raw_id.is_null() { + None + } else { + // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct of_device_id` + // and does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*raw_id.cast::<of::DeviceId>() }; + + Some( + table.info(<of::DeviceId as crate::device_id::RawDeviceIdIndex>::index( + id, + )), + ) + } + } + } + + /// Returns the driver's private data from the matching entry of any of the ID tables, if any. + /// + /// If this returns `None`, it means that there is no match in any of the ID tables directly + /// associated with a [`device::Device`]. + fn id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { + let id = Self::acpi_id_info(dev); + if id.is_some() { + return id; + } + + let id = Self::of_id_info(dev); + if id.is_some() { + return id; + } + + None + } +} diff --git a/rust/kernel/drm/device.rs b/rust/kernel/drm/device.rs new file mode 100644 index 000000000000..f8f1db5eeb0f --- /dev/null +++ b/rust/kernel/drm/device.rs @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM device. +//! +//! C header: [`include/drm/drm_device.h`](srctree/include/drm/drm_device.h) + +use crate::{ + alloc::allocator::Kmalloc, + bindings, device, drm, + drm::driver::AllocImpl, + error::from_err_ptr, + error::Result, + prelude::*, + types::{ARef, AlwaysRefCounted, Opaque}, +}; +use core::{alloc::Layout, mem, ops::Deref, ptr, ptr::NonNull}; + +#[cfg(CONFIG_DRM_LEGACY)] +macro_rules! drm_legacy_fields { + ( $($field:ident: $val:expr),* $(,)? ) => { + bindings::drm_driver { + $( $field: $val ),*, + firstopen: None, + preclose: None, + dma_ioctl: None, + dma_quiescent: None, + context_dtor: None, + irq_handler: None, + irq_preinstall: None, + irq_postinstall: None, + irq_uninstall: None, + get_vblank_counter: None, + enable_vblank: None, + disable_vblank: None, + dev_priv_size: 0, + } + } +} + +#[cfg(not(CONFIG_DRM_LEGACY))] +macro_rules! drm_legacy_fields { + ( $($field:ident: $val:expr),* $(,)? ) => { + bindings::drm_driver { + $( $field: $val ),* + } + } +} + +/// A typed DRM device with a specific `drm::Driver` implementation. +/// +/// The device is always reference-counted. +/// +/// # Invariants +/// +/// `self.dev` is a valid instance of a `struct device`. +#[repr(C)] +pub struct Device<T: drm::Driver> { + dev: Opaque<bindings::drm_device>, + data: T::Data, +} + +impl<T: drm::Driver> Device<T> { + const VTABLE: bindings::drm_driver = drm_legacy_fields! { + load: None, + open: Some(drm::File::<T::File>::open_callback), + postclose: Some(drm::File::<T::File>::postclose_callback), + unload: None, + release: Some(Self::release), + master_set: None, + master_drop: None, + debugfs_init: None, + gem_create_object: T::Object::ALLOC_OPS.gem_create_object, + prime_handle_to_fd: T::Object::ALLOC_OPS.prime_handle_to_fd, + prime_fd_to_handle: T::Object::ALLOC_OPS.prime_fd_to_handle, + gem_prime_import: T::Object::ALLOC_OPS.gem_prime_import, + gem_prime_import_sg_table: T::Object::ALLOC_OPS.gem_prime_import_sg_table, + dumb_create: T::Object::ALLOC_OPS.dumb_create, + dumb_map_offset: T::Object::ALLOC_OPS.dumb_map_offset, + show_fdinfo: None, + fbdev_probe: None, + + major: T::INFO.major, + minor: T::INFO.minor, + patchlevel: T::INFO.patchlevel, + name: T::INFO.name.as_char_ptr().cast_mut(), + desc: T::INFO.desc.as_char_ptr().cast_mut(), + + driver_features: drm::driver::FEAT_GEM, + ioctls: T::IOCTLS.as_ptr(), + num_ioctls: T::IOCTLS.len() as i32, + fops: &Self::GEM_FOPS, + }; + + const GEM_FOPS: bindings::file_operations = drm::gem::create_fops(); + + /// Create a new `drm::Device` for a `drm::Driver`. + pub fn new(dev: &device::Device, data: impl PinInit<T::Data, Error>) -> Result<ARef<Self>> { + // `__drm_dev_alloc` uses `kmalloc()` to allocate memory, hence ensure a `kmalloc()` + // compatible `Layout`. + let layout = Kmalloc::aligned_layout(Layout::new::<Self>()); + + // SAFETY: + // - `VTABLE`, as a `const` is pinned to the read-only section of the compilation, + // - `dev` is valid by its type invarants, + let raw_drm: *mut Self = unsafe { + bindings::__drm_dev_alloc( + dev.as_raw(), + &Self::VTABLE, + layout.size(), + mem::offset_of!(Self, dev), + ) + } + .cast(); + let raw_drm = NonNull::new(from_err_ptr(raw_drm)?).ok_or(ENOMEM)?; + + // SAFETY: `raw_drm` is a valid pointer to `Self`. + let raw_data = unsafe { ptr::addr_of_mut!((*raw_drm.as_ptr()).data) }; + + // SAFETY: + // - `raw_data` is a valid pointer to uninitialized memory. + // - `raw_data` will not move until it is dropped. + unsafe { data.__pinned_init(raw_data) }.inspect_err(|_| { + // SAFETY: `raw_drm` is a valid pointer to `Self`, given that `__drm_dev_alloc` was + // successful. + let drm_dev = unsafe { Self::into_drm_device(raw_drm) }; + + // SAFETY: `__drm_dev_alloc()` was successful, hence `drm_dev` must be valid and the + // refcount must be non-zero. + unsafe { bindings::drm_dev_put(drm_dev) }; + })?; + + // SAFETY: The reference count is one, and now we take ownership of that reference as a + // `drm::Device`. + Ok(unsafe { ARef::from_raw(raw_drm) }) + } + + pub(crate) fn as_raw(&self) -> *mut bindings::drm_device { + self.dev.get() + } + + /// # Safety + /// + /// `ptr` must be a valid pointer to a `struct device` embedded in `Self`. + unsafe fn from_drm_device(ptr: *const bindings::drm_device) -> *mut Self { + // SAFETY: By the safety requirements of this function `ptr` is a valid pointer to a + // `struct drm_device` embedded in `Self`. + unsafe { crate::container_of!(Opaque::cast_from(ptr), Self, dev) }.cast_mut() + } + + /// # Safety + /// + /// `ptr` must be a valid pointer to `Self`. + unsafe fn into_drm_device(ptr: NonNull<Self>) -> *mut bindings::drm_device { + // SAFETY: By the safety requirements of this function, `ptr` is a valid pointer to `Self`. + unsafe { &raw mut (*ptr.as_ptr()).dev }.cast() + } + + /// Not intended to be called externally, except via declare_drm_ioctls!() + /// + /// # Safety + /// + /// Callers must ensure that `ptr` is valid, non-null, and has a non-zero reference count, + /// i.e. it must be ensured that the reference count of the C `struct drm_device` `ptr` points + /// to can't drop to zero, for the duration of this function call and the entire duration when + /// the returned reference exists. + /// + /// Additionally, callers must ensure that the `struct device`, `ptr` is pointing to, is + /// embedded in `Self`. + #[doc(hidden)] + pub unsafe fn from_raw<'a>(ptr: *const bindings::drm_device) -> &'a Self { + // SAFETY: By the safety requirements of this function `ptr` is a valid pointer to a + // `struct drm_device` embedded in `Self`. + let ptr = unsafe { Self::from_drm_device(ptr) }; + + // SAFETY: `ptr` is valid by the safety requirements of this function. + unsafe { &*ptr.cast() } + } + + extern "C" fn release(ptr: *mut bindings::drm_device) { + // SAFETY: `ptr` is a valid pointer to a `struct drm_device` and embedded in `Self`. + let this = unsafe { Self::from_drm_device(ptr) }; + + // SAFETY: + // - When `release` runs it is guaranteed that there is no further access to `this`. + // - `this` is valid for dropping. + unsafe { core::ptr::drop_in_place(this) }; + } +} + +impl<T: drm::Driver> Deref for Device<T> { + type Target = T::Data; + + fn deref(&self) -> &Self::Target { + &self.data + } +} + +// SAFETY: DRM device objects are always reference counted and the get/put functions +// satisfy the requirements. +unsafe impl<T: drm::Driver> AlwaysRefCounted for Device<T> { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::drm_dev_get(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: `obj` is a valid pointer to `Self`. + let drm_dev = unsafe { Self::into_drm_device(obj) }; + + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::drm_dev_put(drm_dev) }; + } +} + +impl<T: drm::Driver> AsRef<device::Device> for Device<T> { + fn as_ref(&self) -> &device::Device { + // SAFETY: `bindings::drm_device::dev` is valid as long as the DRM device itself is valid, + // which is guaranteed by the type invariant. + unsafe { device::Device::from_raw((*self.as_raw()).dev) } + } +} + +// SAFETY: A `drm::Device` can be released from any thread. +unsafe impl<T: drm::Driver> Send for Device<T> {} + +// SAFETY: A `drm::Device` can be shared among threads because all immutable methods are protected +// by the synchronization in `struct drm_device`. +unsafe impl<T: drm::Driver> Sync for Device<T> {} diff --git a/rust/kernel/drm/driver.rs b/rust/kernel/drm/driver.rs new file mode 100644 index 000000000000..d2dad77274c4 --- /dev/null +++ b/rust/kernel/drm/driver.rs @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM driver core. +//! +//! C header: [`include/drm/drm_drv.h`](srctree/include/drm/drm_drv.h) + +use crate::{ + bindings, device, devres, drm, + error::{to_result, Result}, + prelude::*, + types::ARef, +}; +use macros::vtable; + +/// Driver use the GEM memory manager. This should be set for all modern drivers. +pub(crate) const FEAT_GEM: u32 = bindings::drm_driver_feature_DRIVER_GEM; + +/// Information data for a DRM Driver. +pub struct DriverInfo { + /// Driver major version. + pub major: i32, + /// Driver minor version. + pub minor: i32, + /// Driver patchlevel version. + pub patchlevel: i32, + /// Driver name. + pub name: &'static CStr, + /// Driver description. + pub desc: &'static CStr, +} + +/// Internal memory management operation set, normally created by memory managers (e.g. GEM). +pub struct AllocOps { + pub(crate) gem_create_object: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + size: usize, + ) -> *mut bindings::drm_gem_object, + >, + pub(crate) prime_handle_to_fd: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + file_priv: *mut bindings::drm_file, + handle: u32, + flags: u32, + prime_fd: *mut core::ffi::c_int, + ) -> core::ffi::c_int, + >, + pub(crate) prime_fd_to_handle: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + file_priv: *mut bindings::drm_file, + prime_fd: core::ffi::c_int, + handle: *mut u32, + ) -> core::ffi::c_int, + >, + pub(crate) gem_prime_import: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + dma_buf: *mut bindings::dma_buf, + ) -> *mut bindings::drm_gem_object, + >, + pub(crate) gem_prime_import_sg_table: Option< + unsafe extern "C" fn( + dev: *mut bindings::drm_device, + attach: *mut bindings::dma_buf_attachment, + sgt: *mut bindings::sg_table, + ) -> *mut bindings::drm_gem_object, + >, + pub(crate) dumb_create: Option< + unsafe extern "C" fn( + file_priv: *mut bindings::drm_file, + dev: *mut bindings::drm_device, + args: *mut bindings::drm_mode_create_dumb, + ) -> core::ffi::c_int, + >, + pub(crate) dumb_map_offset: Option< + unsafe extern "C" fn( + file_priv: *mut bindings::drm_file, + dev: *mut bindings::drm_device, + handle: u32, + offset: *mut u64, + ) -> core::ffi::c_int, + >, +} + +/// Trait for memory manager implementations. Implemented internally. +pub trait AllocImpl: super::private::Sealed + drm::gem::IntoGEMObject { + /// The C callback operations for this memory manager. + const ALLOC_OPS: AllocOps; +} + +/// The DRM `Driver` trait. +/// +/// This trait must be implemented by drivers in order to create a `struct drm_device` and `struct +/// drm_driver` to be registered in the DRM subsystem. +#[vtable] +pub trait Driver { + /// Context data associated with the DRM driver + type Data: Sync + Send; + + /// The type used to manage memory for this driver. + type Object: AllocImpl; + + /// The type used to represent a DRM File (client) + type File: drm::file::DriverFile; + + /// Driver metadata + const INFO: DriverInfo; + + /// IOCTL list. See `kernel::drm::ioctl::declare_drm_ioctls!{}`. + const IOCTLS: &'static [drm::ioctl::DrmIoctlDescriptor]; +} + +/// The registration type of a `drm::Device`. +/// +/// Once the `Registration` structure is dropped, the device is unregistered. +pub struct Registration<T: Driver>(ARef<drm::Device<T>>); + +impl<T: Driver> Registration<T> { + /// Creates a new [`Registration`] and registers it. + fn new(drm: &drm::Device<T>, flags: usize) -> Result<Self> { + // SAFETY: `drm.as_raw()` is valid by the invariants of `drm::Device`. + to_result(unsafe { bindings::drm_dev_register(drm.as_raw(), flags) })?; + + Ok(Self(drm.into())) + } + + /// Same as [`Registration::new`}, but transfers ownership of the [`Registration`] to + /// [`devres::register`]. + pub fn new_foreign_owned( + drm: &drm::Device<T>, + dev: &device::Device<device::Bound>, + flags: usize, + ) -> Result + where + T: 'static, + { + if drm.as_ref().as_raw() != dev.as_raw() { + return Err(EINVAL); + } + + let reg = Registration::<T>::new(drm, flags)?; + + devres::register(dev, reg, GFP_KERNEL) + } + + /// Returns a reference to the `Device` instance for this registration. + pub fn device(&self) -> &drm::Device<T> { + &self.0 + } +} + +// SAFETY: `Registration` doesn't offer any methods or access to fields when shared between +// threads, hence it's safe to share it. +unsafe impl<T: Driver> Sync for Registration<T> {} + +// SAFETY: Registration with and unregistration from the DRM subsystem can happen from any thread. +unsafe impl<T: Driver> Send for Registration<T> {} + +impl<T: Driver> Drop for Registration<T> { + fn drop(&mut self) { + // SAFETY: Safe by the invariant of `ARef<drm::Device<T>>`. The existence of this + // `Registration` also guarantees the this `drm::Device` is actually registered. + unsafe { bindings::drm_dev_unregister(self.0.as_raw()) }; + } +} diff --git a/rust/kernel/drm/file.rs b/rust/kernel/drm/file.rs new file mode 100644 index 000000000000..8c46f8d51951 --- /dev/null +++ b/rust/kernel/drm/file.rs @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM File objects. +//! +//! C header: [`include/drm/drm_file.h`](srctree/include/drm/drm_file.h) + +use crate::{bindings, drm, error::Result, prelude::*, types::Opaque}; +use core::marker::PhantomData; +use core::pin::Pin; + +/// Trait that must be implemented by DRM drivers to represent a DRM File (a client instance). +pub trait DriverFile { + /// The parent `Driver` implementation for this `DriverFile`. + type Driver: drm::Driver; + + /// Open a new file (called when a client opens the DRM device). + fn open(device: &drm::Device<Self::Driver>) -> Result<Pin<KBox<Self>>>; +} + +/// An open DRM File. +/// +/// # Invariants +/// +/// `self.0` is a valid instance of a `struct drm_file`. +#[repr(transparent)] +pub struct File<T: DriverFile>(Opaque<bindings::drm_file>, PhantomData<T>); + +impl<T: DriverFile> File<T> { + #[doc(hidden)] + /// Not intended to be called externally, except via declare_drm_ioctls!() + /// + /// # Safety + /// + /// `raw_file` must be a valid pointer to an open `struct drm_file`, opened through `T::open`. + pub unsafe fn from_raw<'a>(ptr: *mut bindings::drm_file) -> &'a File<T> { + // SAFETY: `raw_file` is valid by the safety requirements of this function. + unsafe { &*ptr.cast() } + } + + pub(super) fn as_raw(&self) -> *mut bindings::drm_file { + self.0.get() + } + + fn driver_priv(&self) -> *mut T { + // SAFETY: By the type invariants of `Self`, `self.as_raw()` is always valid. + unsafe { (*self.as_raw()).driver_priv }.cast() + } + + /// Return a pinned reference to the driver file structure. + pub fn inner(&self) -> Pin<&T> { + // SAFETY: By the type invariant the pointer `self.as_raw()` points to a valid and opened + // `struct drm_file`, hence `driver_priv` has been properly initialized by `open_callback`. + unsafe { Pin::new_unchecked(&*(self.driver_priv())) } + } + + /// The open callback of a `struct drm_file`. + pub(crate) extern "C" fn open_callback( + raw_dev: *mut bindings::drm_device, + raw_file: *mut bindings::drm_file, + ) -> core::ffi::c_int { + // SAFETY: A callback from `struct drm_driver::open` guarantees that + // - `raw_dev` is valid pointer to a `struct drm_device`, + // - the corresponding `struct drm_device` has been registered. + let drm = unsafe { drm::Device::from_raw(raw_dev) }; + + // SAFETY: `raw_file` is a valid pointer to a `struct drm_file`. + let file = unsafe { File::<T>::from_raw(raw_file) }; + + let inner = match T::open(drm) { + Err(e) => { + return e.to_errno(); + } + Ok(i) => i, + }; + + // SAFETY: This pointer is treated as pinned, and the Drop guarantee is upheld in + // `postclose_callback()`. + let driver_priv = KBox::into_raw(unsafe { Pin::into_inner_unchecked(inner) }); + + // SAFETY: By the type invariants of `Self`, `self.as_raw()` is always valid. + unsafe { (*file.as_raw()).driver_priv = driver_priv.cast() }; + + 0 + } + + /// The postclose callback of a `struct drm_file`. + pub(crate) extern "C" fn postclose_callback( + _raw_dev: *mut bindings::drm_device, + raw_file: *mut bindings::drm_file, + ) { + // SAFETY: This reference won't escape this function + let file = unsafe { File::<T>::from_raw(raw_file) }; + + // SAFETY: `file.driver_priv` has been created in `open_callback` through `KBox::into_raw`. + let _ = unsafe { KBox::from_raw(file.driver_priv()) }; + } +} + +impl<T: DriverFile> super::private::Sealed for File<T> {} diff --git a/rust/kernel/drm/gem/mod.rs b/rust/kernel/drm/gem/mod.rs new file mode 100644 index 000000000000..b9f3248876ba --- /dev/null +++ b/rust/kernel/drm/gem/mod.rs @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM GEM API +//! +//! C header: [`include/drm/drm_gem.h`](srctree/include/drm/drm_gem.h) + +use crate::{ + alloc::flags::*, + bindings, drm, + drm::driver::{AllocImpl, AllocOps}, + error::{to_result, Result}, + prelude::*, + types::{ARef, AlwaysRefCounted, Opaque}, +}; +use core::{mem, ops::Deref, ptr::NonNull}; + +/// GEM object functions, which must be implemented by drivers. +pub trait BaseDriverObject<T: BaseObject>: Sync + Send + Sized { + /// Create a new driver data object for a GEM object of a given size. + fn new(dev: &drm::Device<T::Driver>, size: usize) -> impl PinInit<Self, Error>; + + /// Open a new handle to an existing object, associated with a File. + fn open( + _obj: &<<T as IntoGEMObject>::Driver as drm::Driver>::Object, + _file: &drm::File<<<T as IntoGEMObject>::Driver as drm::Driver>::File>, + ) -> Result { + Ok(()) + } + + /// Close a handle to an existing object, associated with a File. + fn close( + _obj: &<<T as IntoGEMObject>::Driver as drm::Driver>::Object, + _file: &drm::File<<<T as IntoGEMObject>::Driver as drm::Driver>::File>, + ) { + } +} + +/// Trait that represents a GEM object subtype +pub trait IntoGEMObject: Sized + super::private::Sealed + AlwaysRefCounted { + /// Owning driver for this type + type Driver: drm::Driver; + + /// Returns a reference to the raw `drm_gem_object` structure, which must be valid as long as + /// this owning object is valid. + fn as_raw(&self) -> *mut bindings::drm_gem_object; + + /// Converts a pointer to a `struct drm_gem_object` into a reference to `Self`. + /// + /// # Safety + /// + /// - `self_ptr` must be a valid pointer to `Self`. + /// - The caller promises that holding the immutable reference returned by this function does + /// not violate rust's data aliasing rules and remains valid throughout the lifetime of `'a`. + unsafe fn from_raw<'a>(self_ptr: *mut bindings::drm_gem_object) -> &'a Self; +} + +// SAFETY: All gem objects are refcounted. +unsafe impl<T: IntoGEMObject> AlwaysRefCounted for T { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::drm_gem_object_get(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: We either hold the only refcount on `obj`, or one of many - meaning that no one + // else could possibly hold a mutable reference to `obj` and thus this immutable reference + // is safe. + let obj = unsafe { obj.as_ref() }.as_raw(); + + // SAFETY: + // - The safety requirements guarantee that the refcount is non-zero. + // - We hold no references to `obj` now, making it safe for us to potentially deallocate it. + unsafe { bindings::drm_gem_object_put(obj) }; + } +} + +/// Trait which must be implemented by drivers using base GEM objects. +pub trait DriverObject: BaseDriverObject<Object<Self>> { + /// Parent `Driver` for this object. + type Driver: drm::Driver; +} + +extern "C" fn open_callback<T: BaseDriverObject<U>, U: BaseObject>( + raw_obj: *mut bindings::drm_gem_object, + raw_file: *mut bindings::drm_file, +) -> core::ffi::c_int { + // SAFETY: `open_callback` is only ever called with a valid pointer to a `struct drm_file`. + let file = unsafe { + drm::File::<<<U as IntoGEMObject>::Driver as drm::Driver>::File>::from_raw(raw_file) + }; + // SAFETY: `open_callback` is specified in the AllocOps structure for `Object<T>`, ensuring that + // `raw_obj` is indeed contained within a `Object<T>`. + let obj = unsafe { + <<<U as IntoGEMObject>::Driver as drm::Driver>::Object as IntoGEMObject>::from_raw(raw_obj) + }; + + match T::open(obj, file) { + Err(e) => e.to_errno(), + Ok(()) => 0, + } +} + +extern "C" fn close_callback<T: BaseDriverObject<U>, U: BaseObject>( + raw_obj: *mut bindings::drm_gem_object, + raw_file: *mut bindings::drm_file, +) { + // SAFETY: `open_callback` is only ever called with a valid pointer to a `struct drm_file`. + let file = unsafe { + drm::File::<<<U as IntoGEMObject>::Driver as drm::Driver>::File>::from_raw(raw_file) + }; + // SAFETY: `close_callback` is specified in the AllocOps structure for `Object<T>`, ensuring + // that `raw_obj` is indeed contained within a `Object<T>`. + let obj = unsafe { + <<<U as IntoGEMObject>::Driver as drm::Driver>::Object as IntoGEMObject>::from_raw(raw_obj) + }; + + T::close(obj, file); +} + +impl<T: DriverObject> IntoGEMObject for Object<T> { + type Driver = T::Driver; + + fn as_raw(&self) -> *mut bindings::drm_gem_object { + self.obj.get() + } + + unsafe fn from_raw<'a>(self_ptr: *mut bindings::drm_gem_object) -> &'a Self { + // SAFETY: `obj` is guaranteed to be in an `Object<T>` via the safety contract of this + // function + unsafe { &*crate::container_of!(Opaque::cast_from(self_ptr), Object<T>, obj) } + } +} + +/// Base operations shared by all GEM object classes +pub trait BaseObject: IntoGEMObject { + /// Returns the size of the object in bytes. + fn size(&self) -> usize { + // SAFETY: `self.as_raw()` is guaranteed to be a pointer to a valid `struct drm_gem_object`. + unsafe { (*self.as_raw()).size } + } + + /// Creates a new handle for the object associated with a given `File` + /// (or returns an existing one). + fn create_handle( + &self, + file: &drm::File<<<Self as IntoGEMObject>::Driver as drm::Driver>::File>, + ) -> Result<u32> { + let mut handle: u32 = 0; + // SAFETY: The arguments are all valid per the type invariants. + to_result(unsafe { + bindings::drm_gem_handle_create(file.as_raw().cast(), self.as_raw(), &mut handle) + })?; + Ok(handle) + } + + /// Looks up an object by its handle for a given `File`. + fn lookup_handle( + file: &drm::File<<<Self as IntoGEMObject>::Driver as drm::Driver>::File>, + handle: u32, + ) -> Result<ARef<Self>> { + // SAFETY: The arguments are all valid per the type invariants. + let ptr = unsafe { bindings::drm_gem_object_lookup(file.as_raw().cast(), handle) }; + if ptr.is_null() { + return Err(ENOENT); + } + + // SAFETY: + // - A `drm::Driver` can only have a single `File` implementation. + // - `file` uses the same `drm::Driver` as `Self`. + // - Therefore, we're guaranteed that `ptr` must be a gem object embedded within `Self`. + // - And we check if the pointer is null befoe calling from_raw(), ensuring that `ptr` is a + // valid pointer to an initialized `Self`. + let obj = unsafe { Self::from_raw(ptr) }; + + // SAFETY: + // - We take ownership of the reference of `drm_gem_object_lookup()`. + // - Our `NonNull` comes from an immutable reference, thus ensuring it is a valid pointer to + // `Self`. + Ok(unsafe { ARef::from_raw(obj.into()) }) + } + + /// Creates an mmap offset to map the object from userspace. + fn create_mmap_offset(&self) -> Result<u64> { + // SAFETY: The arguments are valid per the type invariant. + to_result(unsafe { bindings::drm_gem_create_mmap_offset(self.as_raw()) })?; + + // SAFETY: The arguments are valid per the type invariant. + Ok(unsafe { bindings::drm_vma_node_offset_addr(&raw mut (*self.as_raw()).vma_node) }) + } +} + +impl<T: IntoGEMObject> BaseObject for T {} + +/// A base GEM object. +/// +/// Invariants +/// +/// - `self.obj` is a valid instance of a `struct drm_gem_object`. +/// - `self.dev` is always a valid pointer to a `struct drm_device`. +#[repr(C)] +#[pin_data] +pub struct Object<T: DriverObject + Send + Sync> { + obj: Opaque<bindings::drm_gem_object>, + dev: NonNull<drm::Device<T::Driver>>, + #[pin] + data: T, +} + +impl<T: DriverObject> Object<T> { + /// The size of this object's structure. + pub const SIZE: usize = mem::size_of::<Self>(); + + const OBJECT_FUNCS: bindings::drm_gem_object_funcs = bindings::drm_gem_object_funcs { + free: Some(Self::free_callback), + open: Some(open_callback::<T, Object<T>>), + close: Some(close_callback::<T, Object<T>>), + print_info: None, + export: None, + pin: None, + unpin: None, + get_sg_table: None, + vmap: None, + vunmap: None, + mmap: None, + status: None, + vm_ops: core::ptr::null_mut(), + evict: None, + rss: None, + }; + + /// Create a new GEM object. + pub fn new(dev: &drm::Device<T::Driver>, size: usize) -> Result<ARef<Self>> { + let obj: Pin<KBox<Self>> = KBox::pin_init( + try_pin_init!(Self { + obj: Opaque::new(bindings::drm_gem_object::default()), + data <- T::new(dev, size), + // INVARIANT: The drm subsystem guarantees that the `struct drm_device` will live + // as long as the GEM object lives. + dev: dev.into(), + }), + GFP_KERNEL, + )?; + + // SAFETY: `obj.as_raw()` is guaranteed to be valid by the initialization above. + unsafe { (*obj.as_raw()).funcs = &Self::OBJECT_FUNCS }; + + // SAFETY: The arguments are all valid per the type invariants. + to_result(unsafe { bindings::drm_gem_object_init(dev.as_raw(), obj.obj.get(), size) })?; + + // SAFETY: We never move out of `Self`. + let ptr = KBox::into_raw(unsafe { Pin::into_inner_unchecked(obj) }); + + // SAFETY: `ptr` comes from `KBox::into_raw` and hence can't be NULL. + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + + // SAFETY: We take over the initial reference count from `drm_gem_object_init()`. + Ok(unsafe { ARef::from_raw(ptr) }) + } + + /// Returns the `Device` that owns this GEM object. + pub fn dev(&self) -> &drm::Device<T::Driver> { + // SAFETY: The DRM subsystem guarantees that the `struct drm_device` will live as long as + // the GEM object lives, hence the pointer must be valid. + unsafe { self.dev.as_ref() } + } + + fn as_raw(&self) -> *mut bindings::drm_gem_object { + self.obj.get() + } + + extern "C" fn free_callback(obj: *mut bindings::drm_gem_object) { + let ptr: *mut Opaque<bindings::drm_gem_object> = obj.cast(); + + // SAFETY: All of our objects are of type `Object<T>`. + let this = unsafe { crate::container_of!(ptr, Self, obj) }; + + // SAFETY: The C code only ever calls this callback with a valid pointer to a `struct + // drm_gem_object`. + unsafe { bindings::drm_gem_object_release(obj) }; + + // SAFETY: All of our objects are allocated via `KBox`, and we're in the + // free callback which guarantees this object has zero remaining references, + // so we can drop it. + let _ = unsafe { KBox::from_raw(this) }; + } +} + +impl<T: DriverObject> super::private::Sealed for Object<T> {} + +impl<T: DriverObject> Deref for Object<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.data + } +} + +impl<T: DriverObject> AllocImpl for Object<T> { + const ALLOC_OPS: AllocOps = AllocOps { + gem_create_object: None, + prime_handle_to_fd: None, + prime_fd_to_handle: None, + gem_prime_import: None, + gem_prime_import_sg_table: None, + dumb_create: None, + dumb_map_offset: None, + }; +} + +pub(super) const fn create_fops() -> bindings::file_operations { + // SAFETY: As by the type invariant, it is safe to initialize `bindings::file_operations` + // zeroed. + let mut fops: bindings::file_operations = unsafe { core::mem::zeroed() }; + + fops.owner = core::ptr::null_mut(); + fops.open = Some(bindings::drm_open); + fops.release = Some(bindings::drm_release); + fops.unlocked_ioctl = Some(bindings::drm_ioctl); + #[cfg(CONFIG_COMPAT)] + { + fops.compat_ioctl = Some(bindings::drm_compat_ioctl); + } + fops.poll = Some(bindings::drm_poll); + fops.read = Some(bindings::drm_read); + fops.llseek = Some(bindings::noop_llseek); + fops.mmap = Some(bindings::drm_gem_mmap); + fops.fop_flags = bindings::FOP_UNSIGNED_OFFSET; + + fops +} diff --git a/rust/kernel/drm/ioctl.rs b/rust/kernel/drm/ioctl.rs new file mode 100644 index 000000000000..8431cdcd3ae0 --- /dev/null +++ b/rust/kernel/drm/ioctl.rs @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM IOCTL definitions. +//! +//! C header: [`include/drm/drm_ioctl.h`](srctree/include/drm/drm_ioctl.h) + +use crate::ioctl; + +const BASE: u32 = uapi::DRM_IOCTL_BASE as u32; + +/// Construct a DRM ioctl number with no argument. +#[allow(non_snake_case)] +#[inline(always)] +pub const fn IO(nr: u32) -> u32 { + ioctl::_IO(BASE, nr) +} + +/// Construct a DRM ioctl number with a read-only argument. +#[allow(non_snake_case)] +#[inline(always)] +pub const fn IOR<T>(nr: u32) -> u32 { + ioctl::_IOR::<T>(BASE, nr) +} + +/// Construct a DRM ioctl number with a write-only argument. +#[allow(non_snake_case)] +#[inline(always)] +pub const fn IOW<T>(nr: u32) -> u32 { + ioctl::_IOW::<T>(BASE, nr) +} + +/// Construct a DRM ioctl number with a read-write argument. +#[allow(non_snake_case)] +#[inline(always)] +pub const fn IOWR<T>(nr: u32) -> u32 { + ioctl::_IOWR::<T>(BASE, nr) +} + +/// Descriptor type for DRM ioctls. Use the `declare_drm_ioctls!{}` macro to construct them. +pub type DrmIoctlDescriptor = bindings::drm_ioctl_desc; + +/// This is for ioctl which are used for rendering, and require that the file descriptor is either +/// for a render node, or if it’s a legacy/primary node, then it must be authenticated. +pub const AUTH: u32 = bindings::drm_ioctl_flags_DRM_AUTH; + +/// This must be set for any ioctl which can change the modeset or display state. Userspace must +/// call the ioctl through a primary node, while it is the active master. +/// +/// Note that read-only modeset ioctl can also be called by unauthenticated clients, or when a +/// master is not the currently active one. +pub const MASTER: u32 = bindings::drm_ioctl_flags_DRM_MASTER; + +/// Anything that could potentially wreak a master file descriptor needs to have this flag set. +/// +/// Current that’s only for the SETMASTER and DROPMASTER ioctl, which e.g. logind can call to +/// force a non-behaving master (display compositor) into compliance. +/// +/// This is equivalent to callers with the SYSADMIN capability. +pub const ROOT_ONLY: u32 = bindings::drm_ioctl_flags_DRM_ROOT_ONLY; + +/// This is used for all ioctl needed for rendering only, for drivers which support render nodes. +/// This should be all new render drivers, and hence it should be always set for any ioctl with +/// `AUTH` set. Note though that read-only query ioctl might have this set, but have not set +/// DRM_AUTH because they do not require authentication. +pub const RENDER_ALLOW: u32 = bindings::drm_ioctl_flags_DRM_RENDER_ALLOW; + +/// Internal structures used by the `declare_drm_ioctls!{}` macro. Do not use directly. +#[doc(hidden)] +pub mod internal { + pub use bindings::drm_device; + pub use bindings::drm_file; + pub use bindings::drm_ioctl_desc; +} + +/// Declare the DRM ioctls for a driver. +/// +/// Each entry in the list should have the form: +/// +/// `(ioctl_number, argument_type, flags, user_callback),` +/// +/// `argument_type` is the type name within the `bindings` crate. +/// `user_callback` should have the following prototype: +/// +/// ```ignore +/// fn foo(device: &kernel::drm::Device<Self>, +/// data: &Opaque<uapi::argument_type>, +/// file: &kernel::drm::File<Self::File>, +/// ) -> Result<u32> +/// ``` +/// where `Self` is the drm::drv::Driver implementation these ioctls are being declared within. +/// +/// # Examples +/// +/// ```ignore +/// kernel::declare_drm_ioctls! { +/// (FOO_GET_PARAM, drm_foo_get_param, ioctl::RENDER_ALLOW, my_get_param_handler), +/// } +/// ``` +/// +#[macro_export] +macro_rules! declare_drm_ioctls { + ( $(($cmd:ident, $struct:ident, $flags:expr, $func:expr)),* $(,)? ) => { + const IOCTLS: &'static [$crate::drm::ioctl::DrmIoctlDescriptor] = { + use $crate::uapi::*; + const _:() = { + let i: u32 = $crate::uapi::DRM_COMMAND_BASE; + // Assert that all the IOCTLs are in the right order and there are no gaps, + // and that the size of the specified type is correct. + $( + let cmd: u32 = $crate::macros::concat_idents!(DRM_IOCTL_, $cmd); + ::core::assert!(i == $crate::ioctl::_IOC_NR(cmd)); + ::core::assert!(core::mem::size_of::<$crate::uapi::$struct>() == + $crate::ioctl::_IOC_SIZE(cmd)); + let i: u32 = i + 1; + )* + }; + + let ioctls = &[$( + $crate::drm::ioctl::internal::drm_ioctl_desc { + cmd: $crate::macros::concat_idents!(DRM_IOCTL_, $cmd) as u32, + func: { + #[allow(non_snake_case)] + unsafe extern "C" fn $cmd( + raw_dev: *mut $crate::drm::ioctl::internal::drm_device, + raw_data: *mut ::core::ffi::c_void, + raw_file: *mut $crate::drm::ioctl::internal::drm_file, + ) -> core::ffi::c_int { + // SAFETY: + // - The DRM core ensures the device lives while callbacks are being + // called. + // - The DRM device must have been registered when we're called through + // an IOCTL. + // + // FIXME: Currently there is nothing enforcing that the types of the + // dev/file match the current driver these ioctls are being declared + // for, and it's not clear how to enforce this within the type system. + let dev = $crate::drm::device::Device::from_raw(raw_dev); + // SAFETY: The ioctl argument has size `_IOC_SIZE(cmd)`, which we + // asserted above matches the size of this type, and all bit patterns of + // UAPI structs must be valid. + let data = unsafe { + &*(raw_data as *const $crate::types::Opaque<$crate::uapi::$struct>) + }; + // SAFETY: This is just the DRM file structure + let file = unsafe { $crate::drm::File::from_raw(raw_file) }; + + match $func(dev, data, file) { + Err(e) => e.to_errno(), + Ok(i) => i.try_into() + .unwrap_or($crate::error::code::ERANGE.to_errno()), + } + } + Some($cmd) + }, + flags: $flags, + name: $crate::c_str!(::core::stringify!($cmd)).as_char_ptr(), + } + ),*]; + ioctls + }; + }; +} diff --git a/rust/kernel/drm/mod.rs b/rust/kernel/drm/mod.rs new file mode 100644 index 000000000000..1b82b6945edf --- /dev/null +++ b/rust/kernel/drm/mod.rs @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +//! DRM subsystem abstractions. + +pub mod device; +pub mod driver; +pub mod file; +pub mod gem; +pub mod ioctl; + +pub use self::device::Device; +pub use self::driver::Driver; +pub use self::driver::DriverInfo; +pub use self::driver::Registration; +pub use self::file::File; + +pub(crate) mod private { + pub trait Sealed {} +} diff --git a/rust/kernel/error.rs b/rust/kernel/error.rs index 5fece574ec02..a41de293dcd1 100644 --- a/rust/kernel/error.rs +++ b/rust/kernel/error.rs @@ -4,11 +4,12 @@ //! //! C header: [`include/uapi/asm-generic/errno-base.h`](srctree/include/uapi/asm-generic/errno-base.h) -use crate::{alloc::AllocError, str::CStr}; +use crate::{ + alloc::{layout::LayoutError, AllocError}, + fmt, + str::CStr, +}; -use core::alloc::LayoutError; - -use core::fmt; use core::num::NonZeroI32; use core::num::TryFromIntError; use core::str::Utf8Error; @@ -63,6 +64,8 @@ pub mod code { declare_err!(EPIPE, "Broken pipe."); declare_err!(EDOM, "Math argument out of domain of func."); declare_err!(ERANGE, "Math result not representable."); + declare_err!(EOVERFLOW, "Value too large for defined data type."); + declare_err!(ETIMEDOUT, "Connection timed out."); declare_err!(ERESTARTSYS, "Restart the system call."); declare_err!(ERESTARTNOINTR, "System call was interrupted by a signal and will be restarted."); declare_err!(ERESTARTNOHAND, "Restart if no handler."); @@ -101,19 +104,16 @@ impl Error { /// It is a bug to pass an out-of-range `errno`. `EINVAL` would /// be returned in such a case. pub fn from_errno(errno: crate::ffi::c_int) -> Error { - if errno < -(bindings::MAX_ERRNO as i32) || errno >= 0 { + if let Some(error) = Self::try_from_errno(errno) { + error + } else { // TODO: Make it a `WARN_ONCE` once available. crate::pr_warn!( - "attempted to create `Error` with out of range `errno`: {}", + "attempted to create `Error` with out of range `errno`: {}\n", errno ); - return code::EINVAL; + code::EINVAL } - - // INVARIANT: The check above ensures the type invariant - // will hold. - // SAFETY: `errno` is checked above to be in a valid range. - unsafe { Error::from_errno_unchecked(errno) } } /// Creates an [`Error`] from a kernel error code. @@ -154,7 +154,7 @@ impl Error { /// Returns the error encoded as a pointer. pub fn to_ptr<T>(self) -> *mut T { // SAFETY: `self.0` is a valid error due to its invariant. - unsafe { bindings::ERR_PTR(self.0.get() as _) as *mut _ } + unsafe { bindings::ERR_PTR(self.0.get() as crate::ffi::c_long).cast() } } /// Returns a string representing the error, if one exists. @@ -189,7 +189,7 @@ impl fmt::Debug for Error { Some(name) => f .debug_tuple( // SAFETY: These strings are ASCII-only. - unsafe { core::str::from_utf8_unchecked(name) }, + unsafe { core::str::from_utf8_unchecked(name.to_bytes()) }, ) .finish(), } @@ -220,8 +220,8 @@ impl From<LayoutError> for Error { } } -impl From<core::fmt::Error> for Error { - fn from(_: core::fmt::Error) -> Error { +impl From<fmt::Error> for Error { + fn from(_: fmt::Error) -> Error { code::EINVAL } } @@ -250,8 +250,129 @@ impl From<core::convert::Infallible> for Error { /// [`Error`] as its error type. /// /// Note that even if a function does not return anything when it succeeds, -/// it should still be modeled as returning a `Result` rather than +/// it should still be modeled as returning a [`Result`] rather than /// just an [`Error`]. +/// +/// Calling a function that returns [`Result`] forces the caller to handle +/// the returned [`Result`]. +/// +/// This can be done "manually" by using [`match`]. Using [`match`] to decode +/// the [`Result`] is similar to C where all the return value decoding and the +/// error handling is done explicitly by writing handling code for each +/// error to cover. Using [`match`] the error and success handling can be +/// implemented in all detail as required. For example (inspired by +/// [`samples/rust/rust_minimal.rs`]): +/// +/// ``` +/// # #[allow(clippy::single_match)] +/// fn example() -> Result { +/// let mut numbers = KVec::new(); +/// +/// match numbers.push(72, GFP_KERNEL) { +/// Err(e) => { +/// pr_err!("Error pushing 72: {e:?}"); +/// return Err(e.into()); +/// } +/// // Do nothing, continue. +/// Ok(()) => (), +/// } +/// +/// match numbers.push(108, GFP_KERNEL) { +/// Err(e) => { +/// pr_err!("Error pushing 108: {e:?}"); +/// return Err(e.into()); +/// } +/// // Do nothing, continue. +/// Ok(()) => (), +/// } +/// +/// match numbers.push(200, GFP_KERNEL) { +/// Err(e) => { +/// pr_err!("Error pushing 200: {e:?}"); +/// return Err(e.into()); +/// } +/// // Do nothing, continue. +/// Ok(()) => (), +/// } +/// +/// Ok(()) +/// } +/// # example()?; +/// # Ok::<(), Error>(()) +/// ``` +/// +/// An alternative to be more concise is the [`if let`] syntax: +/// +/// ``` +/// fn example() -> Result { +/// let mut numbers = KVec::new(); +/// +/// if let Err(e) = numbers.push(72, GFP_KERNEL) { +/// pr_err!("Error pushing 72: {e:?}"); +/// return Err(e.into()); +/// } +/// +/// if let Err(e) = numbers.push(108, GFP_KERNEL) { +/// pr_err!("Error pushing 108: {e:?}"); +/// return Err(e.into()); +/// } +/// +/// if let Err(e) = numbers.push(200, GFP_KERNEL) { +/// pr_err!("Error pushing 200: {e:?}"); +/// return Err(e.into()); +/// } +/// +/// Ok(()) +/// } +/// # example()?; +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Instead of these verbose [`match`]/[`if let`], the [`?`] operator can +/// be used to handle the [`Result`]. Using the [`?`] operator is often +/// the best choice to handle [`Result`] in a non-verbose way as done in +/// [`samples/rust/rust_minimal.rs`]: +/// +/// ``` +/// fn example() -> Result { +/// let mut numbers = KVec::new(); +/// +/// numbers.push(72, GFP_KERNEL)?; +/// numbers.push(108, GFP_KERNEL)?; +/// numbers.push(200, GFP_KERNEL)?; +/// +/// Ok(()) +/// } +/// # example()?; +/// # Ok::<(), Error>(()) +/// ``` +/// +/// Another possibility is to call [`unwrap()`](Result::unwrap) or +/// [`expect()`](Result::expect). However, use of these functions is +/// *heavily discouraged* in the kernel because they trigger a Rust +/// [`panic!`] if an error happens, which may destabilize the system or +/// entirely break it as a result -- just like the C [`BUG()`] macro. +/// Please see the documentation for the C macro [`BUG()`] for guidance +/// on when to use these functions. +/// +/// Alternatively, depending on the use case, using [`unwrap_or()`], +/// [`unwrap_or_else()`], [`unwrap_or_default()`] or [`unwrap_unchecked()`] +/// might be an option, as well. +/// +/// For even more details, please see the [Rust documentation]. +/// +/// [`match`]: https://doc.rust-lang.org/reference/expressions/match-expr.html +/// [`samples/rust/rust_minimal.rs`]: srctree/samples/rust/rust_minimal.rs +/// [`if let`]: https://doc.rust-lang.org/reference/expressions/if-expr.html#if-let-expressions +/// [`?`]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#the-question-mark-operator +/// [`unwrap()`]: Result::unwrap +/// [`expect()`]: Result::expect +/// [`BUG()`]: https://docs.kernel.org/process/deprecated.html#bug-and-bug-on +/// [`unwrap_or()`]: Result::unwrap_or +/// [`unwrap_or_else()`]: Result::unwrap_or_else +/// [`unwrap_or_default()`]: Result::unwrap_or_default +/// [`unwrap_unchecked()`]: Result::unwrap_unchecked +/// [Rust documentation]: https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html pub type Result<T = (), E = Error> = core::result::Result<T, E>; /// Converts an integer as returned by a C kernel function to an error if it's negative, and diff --git a/rust/kernel/faux.rs b/rust/kernel/faux.rs new file mode 100644 index 000000000000..7fe2dd197e37 --- /dev/null +++ b/rust/kernel/faux.rs @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: GPL-2.0-only + +//! Abstractions for the faux bus. +//! +//! This module provides bindings for working with faux devices in kernel modules. +//! +//! C header: [`include/linux/device/faux.h`](srctree/include/linux/device/faux.h) + +use crate::{bindings, device, error::code::*, prelude::*}; +use core::ptr::{addr_of_mut, null, null_mut, NonNull}; + +/// The registration of a faux device. +/// +/// This type represents the registration of a [`struct faux_device`]. When an instance of this type +/// is dropped, its respective faux device will be unregistered from the system. +/// +/// # Invariants +/// +/// `self.0` always holds a valid pointer to an initialized and registered [`struct faux_device`]. +/// +/// [`struct faux_device`]: srctree/include/linux/device/faux.h +pub struct Registration(NonNull<bindings::faux_device>); + +impl Registration { + /// Create and register a new faux device with the given name. + #[inline] + pub fn new(name: &CStr, parent: Option<&device::Device>) -> Result<Self> { + // SAFETY: + // - `name` is copied by this function into its own storage + // - `faux_ops` is safe to leave NULL according to the C API + // - `parent` can be either NULL or a pointer to a `struct device`, and `faux_device_create` + // will take a reference to `parent` using `device_add` - ensuring that it remains valid + // for the lifetime of the faux device. + let dev = unsafe { + bindings::faux_device_create( + name.as_char_ptr(), + parent.map_or(null_mut(), |p| p.as_raw()), + null(), + ) + }; + + // The above function will return either a valid device, or NULL on failure + // INVARIANT: The device will remain registered until faux_device_destroy() is called, which + // happens in our Drop implementation. + Ok(Self(NonNull::new(dev).ok_or(ENODEV)?)) + } + + fn as_raw(&self) -> *mut bindings::faux_device { + self.0.as_ptr() + } +} + +impl AsRef<device::Device> for Registration { + fn as_ref(&self) -> &device::Device { + // SAFETY: The underlying `device` in `faux_device` is guaranteed by the C API to be + // a valid initialized `device`. + unsafe { device::Device::from_raw(addr_of_mut!((*self.as_raw()).dev)) } + } +} + +impl Drop for Registration { + #[inline] + fn drop(&mut self) { + // SAFETY: `self.0` is a valid registered faux_device via our type invariants. + unsafe { bindings::faux_device_destroy(self.as_raw()) } + } +} + +// SAFETY: The faux device API is thread-safe as guaranteed by the device core, as long as +// faux_device_destroy() is guaranteed to only be called once - which is guaranteed by our type not +// having Copy/Clone. +unsafe impl Send for Registration {} + +// SAFETY: The faux device API is thread-safe as guaranteed by the device core, as long as +// faux_device_destroy() is guaranteed to only be called once - which is guaranteed by our type not +// having Copy/Clone. +unsafe impl Sync for Registration {} diff --git a/rust/kernel/firmware.rs b/rust/kernel/firmware.rs index c5162fdc95ff..1abab5b2f052 100644 --- a/rust/kernel/firmware.rs +++ b/rust/kernel/firmware.rs @@ -4,7 +4,7 @@ //! //! C header: [`include/linux/firmware.h`](srctree/include/linux/firmware.h) -use crate::{bindings, device::Device, error::Error, error::Result, str::CStr}; +use crate::{bindings, device::Device, error::Error, error::Result, ffi, str::CStr}; use core::ptr::NonNull; /// # Invariants @@ -12,7 +12,11 @@ use core::ptr::NonNull; /// One of the following: `bindings::request_firmware`, `bindings::firmware_request_nowarn`, /// `bindings::firmware_request_platform`, `bindings::request_firmware_direct`. struct FwFunc( - unsafe extern "C" fn(*mut *const bindings::firmware, *const u8, *mut bindings::device) -> i32, + unsafe extern "C" fn( + *mut *const bindings::firmware, + *const ffi::c_char, + *mut bindings::device, + ) -> i32, ); impl FwFunc { @@ -58,10 +62,11 @@ impl Firmware { fn request_internal(name: &CStr, dev: &Device, func: FwFunc) -> Result<Self> { let mut fw: *mut bindings::firmware = core::ptr::null_mut(); let pfw: *mut *mut bindings::firmware = &mut fw; + let pfw: *mut *const bindings::firmware = pfw.cast(); // SAFETY: `pfw` is a valid pointer to a NULL initialized `bindings::firmware` pointer. // `name` and `dev` are valid as by their type invariants. - let ret = unsafe { func.0(pfw as _, name.as_char_ptr(), dev.as_raw()) }; + let ret = unsafe { func.0(pfw, name.as_char_ptr(), dev.as_raw()) }; if ret != 0 { return Err(Error::from_errno(ret)); } @@ -115,3 +120,219 @@ unsafe impl Send for Firmware {} // SAFETY: `Firmware` only holds a pointer to a C `struct firmware`, references to which are safe to // be used from any thread. unsafe impl Sync for Firmware {} + +/// Create firmware .modinfo entries. +/// +/// This macro is the counterpart of the C macro `MODULE_FIRMWARE()`, but instead of taking a +/// simple string literals, which is already covered by the `firmware` field of +/// [`crate::prelude::module!`], it allows the caller to pass a builder type, based on the +/// [`ModInfoBuilder`], which can create the firmware modinfo strings in a more flexible way. +/// +/// Drivers should extend the [`ModInfoBuilder`] with their own driver specific builder type. +/// +/// The `builder` argument must be a type which implements the following function. +/// +/// `const fn create(module_name: &'static CStr) -> ModInfoBuilder` +/// +/// `create` should pass the `module_name` to the [`ModInfoBuilder`] and, with the help of +/// it construct the corresponding firmware modinfo. +/// +/// Typically, such contracts would be enforced by a trait, however traits do not (yet) support +/// const functions. +/// +/// # Examples +/// +/// ``` +/// # mod module_firmware_test { +/// # use kernel::firmware; +/// # use kernel::prelude::*; +/// # +/// # struct MyModule; +/// # +/// # impl kernel::Module for MyModule { +/// # fn init(_module: &'static ThisModule) -> Result<Self> { +/// # Ok(Self) +/// # } +/// # } +/// # +/// # +/// struct Builder<const N: usize>; +/// +/// impl<const N: usize> Builder<N> { +/// const DIR: &'static str = "vendor/chip/"; +/// const FILES: [&'static str; 3] = [ "foo", "bar", "baz" ]; +/// +/// const fn create(module_name: &'static kernel::str::CStr) -> firmware::ModInfoBuilder<N> { +/// let mut builder = firmware::ModInfoBuilder::new(module_name); +/// +/// let mut i = 0; +/// while i < Self::FILES.len() { +/// builder = builder.new_entry() +/// .push(Self::DIR) +/// .push(Self::FILES[i]) +/// .push(".bin"); +/// +/// i += 1; +/// } +/// +/// builder +/// } +/// } +/// +/// module! { +/// type: MyModule, +/// name: "module_firmware_test", +/// authors: ["Rust for Linux"], +/// description: "module_firmware! test module", +/// license: "GPL", +/// } +/// +/// kernel::module_firmware!(Builder); +/// # } +/// ``` +#[macro_export] +macro_rules! module_firmware { + // The argument is the builder type without the const generic, since it's deferred from within + // this macro. Hence, we can neither use `expr` nor `ty`. + ($($builder:tt)*) => { + const _: () = { + const __MODULE_FIRMWARE_PREFIX: &'static $crate::str::CStr = if cfg!(MODULE) { + $crate::c_str!("") + } else { + <LocalModule as $crate::ModuleMetadata>::NAME + }; + + #[link_section = ".modinfo"] + #[used(compiler)] + static __MODULE_FIRMWARE: [u8; $($builder)*::create(__MODULE_FIRMWARE_PREFIX) + .build_length()] = $($builder)*::create(__MODULE_FIRMWARE_PREFIX).build(); + }; + }; +} + +/// Builder for firmware module info. +/// +/// [`ModInfoBuilder`] is a helper component to flexibly compose firmware paths strings for the +/// .modinfo section in const context. +/// +/// Therefore the [`ModInfoBuilder`] provides the methods [`ModInfoBuilder::new_entry`] and +/// [`ModInfoBuilder::push`], where the latter is used to push path components and the former to +/// mark the beginning of a new path string. +/// +/// [`ModInfoBuilder`] is meant to be used in combination with [`kernel::module_firmware!`]. +/// +/// The const generic `N` as well as the `module_name` parameter of [`ModInfoBuilder::new`] is an +/// internal implementation detail and supplied through the above macro. +pub struct ModInfoBuilder<const N: usize> { + buf: [u8; N], + n: usize, + module_name: &'static CStr, +} + +impl<const N: usize> ModInfoBuilder<N> { + /// Create an empty builder instance. + pub const fn new(module_name: &'static CStr) -> Self { + Self { + buf: [0; N], + n: 0, + module_name, + } + } + + const fn push_internal(mut self, bytes: &[u8]) -> Self { + let mut j = 0; + + if N == 0 { + self.n += bytes.len(); + return self; + } + + while j < bytes.len() { + if self.n < N { + self.buf[self.n] = bytes[j]; + } + self.n += 1; + j += 1; + } + self + } + + /// Push an additional path component. + /// + /// Append path components to the [`ModInfoBuilder`] instance. Paths need to be separated + /// with [`ModInfoBuilder::new_entry`]. + /// + /// # Examples + /// + /// ``` + /// use kernel::firmware::ModInfoBuilder; + /// + /// # const DIR: &str = "vendor/chip/"; + /// # const fn no_run<const N: usize>(builder: ModInfoBuilder<N>) { + /// let builder = builder.new_entry() + /// .push(DIR) + /// .push("foo.bin") + /// .new_entry() + /// .push(DIR) + /// .push("bar.bin"); + /// # } + /// ``` + pub const fn push(self, s: &str) -> Self { + // Check whether there has been an initial call to `next_entry()`. + if N != 0 && self.n == 0 { + crate::build_error!("Must call next_entry() before push()."); + } + + self.push_internal(s.as_bytes()) + } + + const fn push_module_name(self) -> Self { + let mut this = self; + let module_name = this.module_name; + + if !this.module_name.is_empty() { + this = this.push_internal(module_name.as_bytes_with_nul()); + + if N != 0 { + // Re-use the space taken by the NULL terminator and swap it with the '.' separator. + this.buf[this.n - 1] = b'.'; + } + } + + this + } + + /// Prepare the [`ModInfoBuilder`] for the next entry. + /// + /// This method acts as a separator between module firmware path entries. + /// + /// Must be called before constructing a new entry with subsequent calls to + /// [`ModInfoBuilder::push`]. + /// + /// See [`ModInfoBuilder::push`] for an example. + pub const fn new_entry(self) -> Self { + self.push_internal(b"\0") + .push_module_name() + .push_internal(b"firmware=") + } + + /// Build the byte array. + pub const fn build(self) -> [u8; N] { + // Add the final NULL terminator. + let this = self.push_internal(b"\0"); + + if this.n == N { + this.buf + } else { + crate::build_error!("Length mismatch."); + } + } +} + +impl ModInfoBuilder<0> { + /// Return the length of the byte array to build. + pub const fn build_length(self) -> usize { + // Compensate for the NULL terminator added by `build`. + self.n + 1 + } +} diff --git a/rust/kernel/fmt.rs b/rust/kernel/fmt.rs new file mode 100644 index 000000000000..0306e8388968 --- /dev/null +++ b/rust/kernel/fmt.rs @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Formatting utilities. +//! +//! This module is intended to be used in place of `core::fmt` in kernel code. + +pub use core::fmt::{Arguments, Debug, Display, Error, Formatter, Result, Write}; diff --git a/rust/kernel/fs/file.rs b/rust/kernel/fs/file.rs index e03dbe14d62a..35fd5db35c46 100644 --- a/rust/kernel/fs/file.rs +++ b/rust/kernel/fs/file.rs @@ -219,12 +219,13 @@ unsafe impl AlwaysRefCounted for File { /// must be on the same thread as this file. /// /// [`assume_no_fdget_pos`]: LocalFile::assume_no_fdget_pos +#[repr(transparent)] pub struct LocalFile { inner: Opaque<bindings::file>, } // SAFETY: The type invariants guarantee that `LocalFile` is always ref-counted. This implementation -// makes `ARef<File>` own a normal refcount. +// makes `ARef<LocalFile>` own a normal refcount. unsafe impl AlwaysRefCounted for LocalFile { #[inline] fn inc_ref(&self) { @@ -235,7 +236,8 @@ unsafe impl AlwaysRefCounted for LocalFile { #[inline] unsafe fn dec_ref(obj: ptr::NonNull<LocalFile>) { // SAFETY: To call this method, the caller passes us ownership of a normal refcount, so we - // may drop it. The cast is okay since `File` has the same representation as `struct file`. + // may drop it. The cast is okay since `LocalFile` has the same representation as + // `struct file`. unsafe { bindings::fput(obj.cast().as_ptr()) } } } @@ -267,13 +269,13 @@ impl LocalFile { /// # Safety /// /// * The caller must ensure that `ptr` points at a valid file and that the file's refcount is - /// positive for the duration of 'a. + /// positive for the duration of `'a`. /// * The caller must ensure that if there is an active call to `fdget_pos` that did not take /// the `f_pos_lock` mutex, then that call is on the current thread. #[inline] pub unsafe fn from_raw_file<'a>(ptr: *const bindings::file) -> &'a LocalFile { // SAFETY: The caller guarantees that the pointer is not dangling and stays valid for the - // duration of 'a. The cast is okay because `File` is `repr(transparent)`. + // duration of `'a`. The cast is okay because `LocalFile` is `repr(transparent)`. // // INVARIANT: The caller guarantees that there are no problematic `fdget_pos` calls. unsafe { &*ptr.cast() } @@ -341,13 +343,13 @@ impl File { /// # Safety /// /// * The caller must ensure that `ptr` points at a valid file and that the file's refcount is - /// positive for the duration of 'a. + /// positive for the duration of `'a`. /// * The caller must ensure that if there are active `fdget_pos` calls on this file, then they /// took the `f_pos_lock` mutex. #[inline] pub unsafe fn from_raw_file<'a>(ptr: *const bindings::file) -> &'a File { // SAFETY: The caller guarantees that the pointer is not dangling and stays valid for the - // duration of 'a. The cast is okay because `File` is `repr(transparent)`. + // duration of `'a`. The cast is okay because `File` is `repr(transparent)`. // // INVARIANT: The caller guarantees that there are no problematic `fdget_pos` calls. unsafe { &*ptr.cast() } @@ -364,7 +366,7 @@ impl core::ops::Deref for File { // // By the type invariants, there are no `fdget_pos` calls that did not take the // `f_pos_lock` mutex. - unsafe { LocalFile::from_raw_file(self as *const File as *const bindings::file) } + unsafe { LocalFile::from_raw_file(core::ptr::from_ref(self).cast()) } } } @@ -392,6 +394,7 @@ pub struct FileDescriptorReservation { impl FileDescriptorReservation { /// Creates a new file descriptor reservation. + #[inline] pub fn get_unused_fd_flags(flags: u32) -> Result<Self> { // SAFETY: FFI call, there are no safety requirements on `flags`. let fd: i32 = unsafe { bindings::get_unused_fd_flags(flags) }; @@ -405,6 +408,7 @@ impl FileDescriptorReservation { } /// Returns the file descriptor number that was reserved. + #[inline] pub fn reserved_fd(&self) -> u32 { self.fd } @@ -413,6 +417,7 @@ impl FileDescriptorReservation { /// /// The previously reserved file descriptor is bound to `file`. This method consumes the /// [`FileDescriptorReservation`], so it will not be usable after this call. + #[inline] pub fn fd_install(self, file: ARef<File>) { // SAFETY: `self.fd` was previously returned by `get_unused_fd_flags`. We have not yet used // the fd, so it is still valid, and `current` still refers to the same task, as this type @@ -433,6 +438,7 @@ impl FileDescriptorReservation { } impl Drop for FileDescriptorReservation { + #[inline] fn drop(&mut self) { // SAFETY: By the type invariants of this type, `self.fd` was previously returned by // `get_unused_fd_flags`. We have not yet used the fd, so it is still valid, and `current` diff --git a/rust/kernel/generated_arch_reachable_asm.rs.S b/rust/kernel/generated_arch_reachable_asm.rs.S new file mode 100644 index 000000000000..3886a9ad3a99 --- /dev/null +++ b/rust/kernel/generated_arch_reachable_asm.rs.S @@ -0,0 +1,7 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#include <linux/bug.h> + +// Cut here. + +::kernel::concat_literals!(ARCH_WARN_REACHABLE) diff --git a/rust/kernel/generated_arch_warn_asm.rs.S b/rust/kernel/generated_arch_warn_asm.rs.S new file mode 100644 index 000000000000..409eb4c2d3a1 --- /dev/null +++ b/rust/kernel/generated_arch_warn_asm.rs.S @@ -0,0 +1,7 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#include <linux/bug.h> + +// Cut here. + +::kernel::concat_literals!(ARCH_WARN_ASM("{file}", "{line}", "{flags}", "{size}")) diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs index c962029f96e1..4949047af8d7 100644 --- a/rust/kernel/init.rs +++ b/rust/kernel/init.rs @@ -1,144 +1,82 @@ -// SPDX-License-Identifier: Apache-2.0 OR MIT +// SPDX-License-Identifier: GPL-2.0 -//! API to safely and fallibly initialize pinned `struct`s using in-place constructors. -//! -//! It also allows in-place initialization of big `struct`s that would otherwise produce a stack -//! overflow. +//! Extensions to the [`pin-init`] crate. //! //! Most `struct`s from the [`sync`] module need to be pinned, because they contain self-referential //! `struct`s from C. [Pinning][pinning] is Rust's way of ensuring data does not move. //! -//! # Overview -//! -//! To initialize a `struct` with an in-place constructor you will need two things: -//! - an in-place constructor, -//! - a memory location that can hold your `struct` (this can be the [stack], an [`Arc<T>`], -//! [`UniqueArc<T>`], [`KBox<T>`] or any other smart pointer that implements [`InPlaceInit`]). -//! -//! To get an in-place constructor there are generally three options: -//! - directly creating an in-place constructor using the [`pin_init!`] macro, -//! - a custom function/macro returning an in-place constructor provided by someone else, -//! - using the unsafe function [`pin_init_from_closure()`] to manually create an initializer. -//! -//! Aside from pinned initialization, this API also supports in-place construction without pinning, -//! the macros/types/functions are generally named like the pinned variants without the `pin` -//! prefix. -//! -//! # Examples -//! -//! ## Using the [`pin_init!`] macro -//! -//! If you want to use [`PinInit`], then you will have to annotate your `struct` with -//! `#[`[`pin_data`]`]`. It is a macro that uses `#[pin]` as a marker for -//! [structurally pinned fields]. After doing this, you can then create an in-place constructor via -//! [`pin_init!`]. The syntax is almost the same as normal `struct` initializers. The difference is -//! that you need to write `<-` instead of `:` for fields that you want to initialize in-place. +//! The [`pin-init`] crate is the way such structs are initialized on the Rust side. Please refer +//! to its documentation to better understand how to use it. Additionally, there are many examples +//! throughout the kernel, such as the types from the [`sync`] module. And the ones presented +//! below. //! -//! ```rust -//! # #![expect(clippy::disallowed_names)] -//! use kernel::sync::{new_mutex, Mutex}; -//! # use core::pin::Pin; -//! #[pin_data] -//! struct Foo { -//! #[pin] -//! a: Mutex<usize>, -//! b: u32, -//! } -//! -//! let foo = pin_init!(Foo { -//! a <- new_mutex!(42, "Foo::a"), -//! b: 24, -//! }); -//! ``` +//! [`sync`]: crate::sync +//! [pinning]: https://doc.rust-lang.org/std/pin/index.html +//! [`pin-init`]: https://rust.docs.kernel.org/pin_init/ //! -//! `foo` now is of the type [`impl PinInit<Foo>`]. We can now use any smart pointer that we like -//! (or just the stack) to actually initialize a `Foo`: +//! # [`Opaque<T>`] //! -//! ```rust -//! # #![expect(clippy::disallowed_names)] -//! # use kernel::sync::{new_mutex, Mutex}; -//! # use core::pin::Pin; -//! # #[pin_data] -//! # struct Foo { -//! # #[pin] -//! # a: Mutex<usize>, -//! # b: u32, -//! # } -//! # let foo = pin_init!(Foo { -//! # a <- new_mutex!(42, "Foo::a"), -//! # b: 24, -//! # }); -//! let foo: Result<Pin<KBox<Foo>>> = KBox::pin_init(foo, GFP_KERNEL); -//! ``` +//! For the special case where initializing a field is a single FFI-function call that cannot fail, +//! there exist the helper function [`Opaque::ffi_init`]. This function initialize a single +//! [`Opaque<T>`] field by just delegating to the supplied closure. You can use these in +//! combination with [`pin_init!`]. //! -//! For more information see the [`pin_init!`] macro. +//! [`Opaque<T>`]: crate::types::Opaque +//! [`Opaque::ffi_init`]: crate::types::Opaque::ffi_init +//! [`pin_init!`]: pin_init::pin_init //! -//! ## Using a custom function/macro that returns an initializer +//! # Examples //! -//! Many types from the kernel supply a function/macro that returns an initializer, because the -//! above method only works for types where you can access the fields. +//! ## General Examples //! //! ```rust -//! # use kernel::sync::{new_mutex, Arc, Mutex}; -//! let mtx: Result<Arc<Mutex<usize>>> = -//! Arc::pin_init(new_mutex!(42, "example::mtx"), GFP_KERNEL); -//! ``` +//! # #![expect(clippy::disallowed_names, clippy::undocumented_unsafe_blocks)] +//! use kernel::types::Opaque; +//! use pin_init::pin_init_from_closure; //! -//! To declare an init macro/function you just return an [`impl PinInit<T, E>`]: +//! // assume we have some `raw_foo` type in C: +//! #[repr(C)] +//! struct RawFoo([u8; 16]); +//! extern "C" { +//! fn init_foo(_: *mut RawFoo); +//! } //! -//! ```rust -//! # use kernel::{sync::Mutex, new_mutex, init::PinInit, try_pin_init}; //! #[pin_data] -//! struct DriverData { +//! struct Foo { //! #[pin] -//! status: Mutex<i32>, -//! buffer: KBox<[u8; 1_000_000]>, +//! raw: Opaque<RawFoo>, //! } //! -//! impl DriverData { -//! fn new() -> impl PinInit<Self, Error> { -//! try_pin_init!(Self { -//! status <- new_mutex!(0, "DriverData::status"), -//! buffer: KBox::init(kernel::init::zeroed(), GFP_KERNEL)?, -//! }) +//! impl Foo { +//! fn setup(self: Pin<&mut Self>) { +//! pr_info!("Setting up foo\n"); //! } //! } -//! ``` -//! -//! ## Manual creation of an initializer //! -//! Often when working with primitives the previous approaches are not sufficient. That is where -//! [`pin_init_from_closure()`] comes in. This `unsafe` function allows you to create a -//! [`impl PinInit<T, E>`] directly from a closure. Of course you have to ensure that the closure -//! actually does the initialization in the correct way. Here are the things to look out for -//! (we are calling the parameter to the closure `slot`): -//! - when the closure returns `Ok(())`, then it has completed the initialization successfully, so -//! `slot` now contains a valid bit pattern for the type `T`, -//! - when the closure returns `Err(e)`, then the caller may deallocate the memory at `slot`, so -//! you need to take care to clean up anything if your initialization fails mid-way, -//! - you may assume that `slot` will stay pinned even after the closure returns until `drop` of -//! `slot` gets called. +//! let foo = pin_init!(Foo { +//! raw <- unsafe { +//! Opaque::ffi_init(|s| { +//! // note that this cannot fail. +//! init_foo(s); +//! }) +//! }, +//! }).pin_chain(|foo| { +//! foo.setup(); +//! Ok(()) +//! }); +//! ``` //! //! ```rust //! # #![expect(unreachable_pub, clippy::disallowed_names)] -//! use kernel::{init, types::Opaque}; +//! use kernel::{prelude::*, types::Opaque}; //! use core::{ptr::addr_of_mut, marker::PhantomPinned, pin::Pin}; //! # mod bindings { -//! # #![expect(non_camel_case_types)] -//! # #![expect(clippy::missing_safety_doc)] +//! # #![expect(non_camel_case_types, clippy::missing_safety_doc)] //! # pub struct foo; //! # pub unsafe fn init_foo(_ptr: *mut foo) {} //! # pub unsafe fn destroy_foo(_ptr: *mut foo) {} //! # pub unsafe fn enable_foo(_ptr: *mut foo, _flags: u32) -> i32 { 0 } //! # } -//! # // `Error::from_errno` is `pub(crate)` in the `kernel` crate, thus provide a workaround. -//! # trait FromErrno { -//! # fn from_errno(errno: kernel::ffi::c_int) -> Error { -//! # // Dummy error that can be constructed outside the `kernel` crate. -//! # Error::from(core::fmt::Error) -//! # } -//! # } -//! # impl FromErrno for Error {} //! /// # Invariants //! /// //! /// `foo` is always initialized @@ -157,18 +95,18 @@ //! // enabled `foo`, //! // - when it returns `Err(e)`, then it has cleaned up before //! unsafe { -//! init::pin_init_from_closure(move |slot: *mut Self| { +//! pin_init::pin_init_from_closure(move |slot: *mut Self| { //! // `slot` contains uninit memory, avoid creating a reference. //! let foo = addr_of_mut!((*slot).foo); //! //! // Initialize the `foo` -//! bindings::init_foo(Opaque::raw_get(foo)); +//! bindings::init_foo(Opaque::cast_into(foo)); //! //! // Try to enable it. -//! let err = bindings::enable_foo(Opaque::raw_get(foo), flags); +//! let err = bindings::enable_foo(Opaque::cast_into(foo), flags); //! if err != 0 { //! // Enabling has failed, first clean up the foo and then return the error. -//! bindings::destroy_foo(Opaque::raw_get(foo)); +//! bindings::destroy_foo(Opaque::cast_into(foo)); //! return Err(Error::from_errno(err)); //! } //! @@ -187,385 +125,114 @@ //! } //! } //! ``` -//! -//! For the special case where initializing a field is a single FFI-function call that cannot fail, -//! there exist the helper function [`Opaque::ffi_init`]. This function initialize a single -//! [`Opaque`] field by just delegating to the supplied closure. You can use these in combination -//! with [`pin_init!`]. -//! -//! For more information on how to use [`pin_init_from_closure()`], take a look at the uses inside -//! the `kernel` crate. The [`sync`] module is a good starting point. -//! -//! [`sync`]: kernel::sync -//! [pinning]: https://doc.rust-lang.org/std/pin/index.html -//! [structurally pinned fields]: -//! https://doc.rust-lang.org/std/pin/index.html#pinning-is-structural-for-field -//! [stack]: crate::stack_pin_init -//! [`Arc<T>`]: crate::sync::Arc -//! [`impl PinInit<Foo>`]: PinInit -//! [`impl PinInit<T, E>`]: PinInit -//! [`impl Init<T, E>`]: Init -//! [`Opaque`]: kernel::types::Opaque -//! [`Opaque::ffi_init`]: kernel::types::Opaque::ffi_init -//! [`pin_data`]: ::macros::pin_data -//! [`pin_init!`]: crate::pin_init! use crate::{ - alloc::{AllocError, Flags, KBox}, + alloc::{AllocError, Flags}, error::{self, Error}, - sync::Arc, - sync::UniqueArc, - types::{Opaque, ScopeGuard}, -}; -use core::{ - cell::UnsafeCell, - convert::Infallible, - marker::PhantomData, - mem::MaybeUninit, - num::*, - pin::Pin, - ptr::{self, NonNull}, }; +use pin_init::{init_from_closure, pin_init_from_closure, Init, PinInit}; -#[doc(hidden)] -pub mod __internal; -#[doc(hidden)] -pub mod macros; +/// Smart pointer that can initialize memory in-place. +pub trait InPlaceInit<T>: Sized { + /// Pinned version of `Self`. + /// + /// If a type already implicitly pins its pointee, `Pin<Self>` is unnecessary. In this case use + /// `Self`, otherwise just use `Pin<Self>`. + type PinnedSelf; -/// Initialize and pin a type directly on the stack. -/// -/// # Examples -/// -/// ```rust -/// # #![expect(clippy::disallowed_names)] -/// # use kernel::{init, macros::pin_data, pin_init, stack_pin_init, init::*, sync::Mutex, new_mutex}; -/// # use core::pin::Pin; -/// #[pin_data] -/// struct Foo { -/// #[pin] -/// a: Mutex<usize>, -/// b: Bar, -/// } -/// -/// #[pin_data] -/// struct Bar { -/// x: u32, -/// } -/// -/// stack_pin_init!(let foo = pin_init!(Foo { -/// a <- new_mutex!(42), -/// b: Bar { -/// x: 64, -/// }, -/// })); -/// let foo: Pin<&mut Foo> = foo; -/// pr_info!("a: {}", &*foo.a.lock()); -/// ``` -/// -/// # Syntax -/// -/// A normal `let` binding with optional type annotation. The expression is expected to implement -/// [`PinInit`]/[`Init`] with the error type [`Infallible`]. If you want to use a different error -/// type, then use [`stack_try_pin_init!`]. -/// -/// [`stack_try_pin_init!`]: crate::stack_try_pin_init! -#[macro_export] -macro_rules! stack_pin_init { - (let $var:ident $(: $t:ty)? = $val:expr) => { - let val = $val; - let mut $var = ::core::pin::pin!($crate::init::__internal::StackInit$(::<$t>)?::uninit()); - let mut $var = match $crate::init::__internal::StackInit::init($var, val) { - Ok(res) => res, - Err(x) => { - let x: ::core::convert::Infallible = x; - match x {} - } + /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this + /// type. + /// + /// If `T: !Unpin` it will not be able to move afterwards. + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> + where + E: From<AllocError>; + + /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this + /// type. + /// + /// If `T: !Unpin` it will not be able to move afterwards. + fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> error::Result<Self::PinnedSelf> + where + Error: From<E>, + { + // SAFETY: We delegate to `init` and only change the error type. + let init = unsafe { + pin_init_from_closure(|slot| init.__pinned_init(slot).map_err(|e| Error::from(e))) }; - }; -} + Self::try_pin_init(init, flags) + } -/// Initialize and pin a type directly on the stack. -/// -/// # Examples -/// -/// ```rust,ignore -/// # #![expect(clippy::disallowed_names)] -/// # use kernel::{init, pin_init, stack_try_pin_init, init::*, sync::Mutex, new_mutex}; -/// # use macros::pin_data; -/// # use core::{alloc::AllocError, pin::Pin}; -/// #[pin_data] -/// struct Foo { -/// #[pin] -/// a: Mutex<usize>, -/// b: KBox<Bar>, -/// } -/// -/// struct Bar { -/// x: u32, -/// } -/// -/// stack_try_pin_init!(let foo: Result<Pin<&mut Foo>, AllocError> = pin_init!(Foo { -/// a <- new_mutex!(42), -/// b: KBox::new(Bar { -/// x: 64, -/// }, GFP_KERNEL)?, -/// })); -/// let foo = foo.unwrap(); -/// pr_info!("a: {}", &*foo.a.lock()); -/// ``` -/// -/// ```rust,ignore -/// # #![expect(clippy::disallowed_names)] -/// # use kernel::{init, pin_init, stack_try_pin_init, init::*, sync::Mutex, new_mutex}; -/// # use macros::pin_data; -/// # use core::{alloc::AllocError, pin::Pin}; -/// #[pin_data] -/// struct Foo { -/// #[pin] -/// a: Mutex<usize>, -/// b: KBox<Bar>, -/// } -/// -/// struct Bar { -/// x: u32, -/// } -/// -/// stack_try_pin_init!(let foo: Pin<&mut Foo> =? pin_init!(Foo { -/// a <- new_mutex!(42), -/// b: KBox::new(Bar { -/// x: 64, -/// }, GFP_KERNEL)?, -/// })); -/// pr_info!("a: {}", &*foo.a.lock()); -/// # Ok::<_, AllocError>(()) -/// ``` -/// -/// # Syntax -/// -/// A normal `let` binding with optional type annotation. The expression is expected to implement -/// [`PinInit`]/[`Init`]. This macro assigns a result to the given variable, adding a `?` after the -/// `=` will propagate this error. -#[macro_export] -macro_rules! stack_try_pin_init { - (let $var:ident $(: $t:ty)? = $val:expr) => { - let val = $val; - let mut $var = ::core::pin::pin!($crate::init::__internal::StackInit$(::<$t>)?::uninit()); - let mut $var = $crate::init::__internal::StackInit::init($var, val); - }; - (let $var:ident $(: $t:ty)? =? $val:expr) => { - let val = $val; - let mut $var = ::core::pin::pin!($crate::init::__internal::StackInit$(::<$t>)?::uninit()); - let mut $var = $crate::init::__internal::StackInit::init($var, val)?; - }; + /// Use the given initializer to in-place initialize a `T`. + fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>; + + /// Use the given initializer to in-place initialize a `T`. + fn init<E>(init: impl Init<T, E>, flags: Flags) -> error::Result<Self> + where + Error: From<E>, + { + // SAFETY: We delegate to `init` and only change the error type. + let init = unsafe { + init_from_closure(|slot| init.__pinned_init(slot).map_err(|e| Error::from(e))) + }; + Self::try_init(init, flags) + } } -/// Construct an in-place, pinned initializer for `struct`s. -/// -/// This macro defaults the error to [`Infallible`]. If you need [`Error`], then use -/// [`try_pin_init!`]. -/// -/// The syntax is almost identical to that of a normal `struct` initializer: -/// -/// ```rust -/// # use kernel::{init, pin_init, macros::pin_data, init::*}; -/// # use core::pin::Pin; -/// #[pin_data] -/// struct Foo { -/// a: usize, -/// b: Bar, -/// } -/// -/// #[pin_data] -/// struct Bar { -/// x: u32, -/// } -/// -/// # fn demo() -> impl PinInit<Foo> { -/// let a = 42; -/// -/// let initializer = pin_init!(Foo { -/// a, -/// b: Bar { -/// x: 64, -/// }, -/// }); -/// # initializer } -/// # KBox::pin_init(demo(), GFP_KERNEL).unwrap(); -/// ``` -/// -/// Arbitrary Rust expressions can be used to set the value of a variable. -/// -/// The fields are initialized in the order that they appear in the initializer. So it is possible -/// to read already initialized fields using raw pointers. -/// -/// IMPORTANT: You are not allowed to create references to fields of the struct inside of the -/// initializer. -/// -/// # Init-functions -/// -/// When working with this API it is often desired to let others construct your types without -/// giving access to all fields. This is where you would normally write a plain function `new` -/// that would return a new instance of your type. With this API that is also possible. -/// However, there are a few extra things to keep in mind. -/// -/// To create an initializer function, simply declare it like this: -/// -/// ```rust -/// # use kernel::{init, pin_init, init::*}; -/// # use core::pin::Pin; -/// # #[pin_data] -/// # struct Foo { -/// # a: usize, -/// # b: Bar, -/// # } -/// # #[pin_data] -/// # struct Bar { -/// # x: u32, -/// # } -/// impl Foo { -/// fn new() -> impl PinInit<Self> { -/// pin_init!(Self { -/// a: 42, -/// b: Bar { -/// x: 64, -/// }, -/// }) -/// } -/// } -/// ``` +/// Construct an in-place fallible initializer for `struct`s. /// -/// Users of `Foo` can now create it like this: +/// This macro defaults the error to [`Error`]. If you need [`Infallible`], then use +/// [`init!`]. /// -/// ```rust -/// # #![expect(clippy::disallowed_names)] -/// # use kernel::{init, pin_init, macros::pin_data, init::*}; -/// # use core::pin::Pin; -/// # #[pin_data] -/// # struct Foo { -/// # a: usize, -/// # b: Bar, -/// # } -/// # #[pin_data] -/// # struct Bar { -/// # x: u32, -/// # } -/// # impl Foo { -/// # fn new() -> impl PinInit<Self> { -/// # pin_init!(Self { -/// # a: 42, -/// # b: Bar { -/// # x: 64, -/// # }, -/// # }) -/// # } -/// # } -/// let foo = KBox::pin_init(Foo::new(), GFP_KERNEL); -/// ``` +/// The syntax is identical to [`try_pin_init!`]. If you want to specify a custom error, +/// append `? $type` after the `struct` initializer. +/// The safety caveats from [`try_pin_init!`] also apply: +/// - `unsafe` code must guarantee either full initialization or return an error and allow +/// deallocation of the memory. +/// - the fields are initialized in the order given in the initializer. +/// - no references to fields are allowed to be created inside of the initializer. /// -/// They can also easily embed it into their own `struct`s: +/// # Examples /// /// ```rust -/// # use kernel::{init, pin_init, macros::pin_data, init::*}; -/// # use core::pin::Pin; -/// # #[pin_data] -/// # struct Foo { -/// # a: usize, -/// # b: Bar, -/// # } -/// # #[pin_data] -/// # struct Bar { -/// # x: u32, -/// # } -/// # impl Foo { -/// # fn new() -> impl PinInit<Self> { -/// # pin_init!(Self { -/// # a: 42, -/// # b: Bar { -/// # x: 64, -/// # }, -/// # }) -/// # } -/// # } -/// #[pin_data] -/// struct FooContainer { -/// #[pin] -/// foo1: Foo, -/// #[pin] -/// foo2: Foo, -/// other: u32, +/// use kernel::error::Error; +/// use pin_init::init_zeroed; +/// struct BigBuf { +/// big: KBox<[u8; 1024 * 1024 * 1024]>, +/// small: [u8; 1024 * 1024], /// } /// -/// impl FooContainer { -/// fn new(other: u32) -> impl PinInit<Self> { -/// pin_init!(Self { -/// foo1 <- Foo::new(), -/// foo2 <- Foo::new(), -/// other, -/// }) +/// impl BigBuf { +/// fn new() -> impl Init<Self, Error> { +/// try_init!(Self { +/// big: KBox::init(init_zeroed(), GFP_KERNEL)?, +/// small: [0; 1024 * 1024], +/// }? Error) /// } /// } /// ``` /// -/// Here we see that when using `pin_init!` with `PinInit`, one needs to write `<-` instead of `:`. -/// This signifies that the given field is initialized in-place. As with `struct` initializers, just -/// writing the field (in this case `other`) without `:` or `<-` means `other: other,`. -/// -/// # Syntax -/// -/// As already mentioned in the examples above, inside of `pin_init!` a `struct` initializer with -/// the following modifications is expected: -/// - Fields that you want to initialize in-place have to use `<-` instead of `:`. -/// - In front of the initializer you can write `&this in` to have access to a [`NonNull<Self>`] -/// pointer named `this` inside of the initializer. -/// - Using struct update syntax one can place `..Zeroable::zeroed()` at the very end of the -/// struct, this initializes every field with 0 and then runs all initializers specified in the -/// body. This can only be done if [`Zeroable`] is implemented for the struct. -/// -/// For instance: -/// -/// ```rust -/// # use kernel::{macros::{Zeroable, pin_data}, pin_init}; -/// # use core::{ptr::addr_of_mut, marker::PhantomPinned}; -/// #[pin_data] -/// #[derive(Zeroable)] -/// struct Buf { -/// // `ptr` points into `buf`. -/// ptr: *mut u8, -/// buf: [u8; 64], -/// #[pin] -/// pin: PhantomPinned, -/// } -/// pin_init!(&this in Buf { -/// buf: [0; 64], -/// // SAFETY: TODO. -/// ptr: unsafe { addr_of_mut!((*this.as_ptr()).buf).cast() }, -/// pin: PhantomPinned, -/// }); -/// pin_init!(Buf { -/// buf: [1; 64], -/// ..Zeroable::zeroed() -/// }); -/// ``` -/// -/// [`try_pin_init!`]: kernel::try_pin_init -/// [`NonNull<Self>`]: core::ptr::NonNull -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `__internal` inside of `init/__internal.rs`. +/// [`Infallible`]: core::convert::Infallible +/// [`init!`]: pin_init::init +/// [`try_pin_init!`]: crate::try_pin_init! +/// [`Error`]: crate::error::Error #[macro_export] -macro_rules! pin_init { +macro_rules! try_init { ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { $($fields:tt)* }) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)?), - @fields($($fields)*), - @error(::core::convert::Infallible), - @data(PinData, use_data), - @has_data(HasPinData, __pin_data), - @construct_closure(pin_init_from_closure), - @munch_fields($($fields)*), - ) + ::pin_init::try_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? $crate::error::Error) + }; + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }? $err:ty) => { + ::pin_init::try_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? $err) }; } @@ -587,7 +254,9 @@ macro_rules! pin_init { /// # Examples /// /// ```rust -/// use kernel::{init::{self, PinInit}, error::Error}; +/// # #![feature(new_uninit)] +/// use kernel::error::Error; +/// use pin_init::init_zeroed; /// #[pin_data] /// struct BigBuf { /// big: KBox<[u8; 1024 * 1024 * 1024]>, @@ -598,844 +267,31 @@ macro_rules! pin_init { /// impl BigBuf { /// fn new() -> impl PinInit<Self, Error> { /// try_pin_init!(Self { -/// big: KBox::init(init::zeroed(), GFP_KERNEL)?, +/// big: KBox::init(init_zeroed(), GFP_KERNEL)?, /// small: [0; 1024 * 1024], /// ptr: core::ptr::null_mut(), /// }? Error) /// } /// } /// ``` -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `__internal` inside of `init/__internal.rs`. -#[macro_export] -macro_rules! try_pin_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)? ), - @fields($($fields)*), - @error($crate::error::Error), - @data(PinData, use_data), - @has_data(HasPinData, __pin_data), - @construct_closure(pin_init_from_closure), - @munch_fields($($fields)*), - ) - }; - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? $err:ty) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)? ), - @fields($($fields)*), - @error($err), - @data(PinData, use_data), - @has_data(HasPinData, __pin_data), - @construct_closure(pin_init_from_closure), - @munch_fields($($fields)*), - ) - }; -} - -/// Construct an in-place initializer for `struct`s. -/// -/// This macro defaults the error to [`Infallible`]. If you need [`Error`], then use -/// [`try_init!`]. -/// -/// The syntax is identical to [`pin_init!`] and its safety caveats also apply: -/// - `unsafe` code must guarantee either full initialization or return an error and allow -/// deallocation of the memory. -/// - the fields are initialized in the order given in the initializer. -/// - no references to fields are allowed to be created inside of the initializer. -/// -/// This initializer is for initializing data in-place that might later be moved. If you want to -/// pin-initialize, use [`pin_init!`]. -/// -/// [`try_init!`]: crate::try_init! -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `__internal` inside of `init/__internal.rs`. -#[macro_export] -macro_rules! init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)?), - @fields($($fields)*), - @error(::core::convert::Infallible), - @data(InitData, /*no use_data*/), - @has_data(HasInitData, __init_data), - @construct_closure(init_from_closure), - @munch_fields($($fields)*), - ) - } -} - -/// Construct an in-place fallible initializer for `struct`s. -/// -/// This macro defaults the error to [`Error`]. If you need [`Infallible`], then use -/// [`init!`]. -/// -/// The syntax is identical to [`try_pin_init!`]. If you want to specify a custom error, -/// append `? $type` after the `struct` initializer. -/// The safety caveats from [`try_pin_init!`] also apply: -/// - `unsafe` code must guarantee either full initialization or return an error and allow -/// deallocation of the memory. -/// - the fields are initialized in the order given in the initializer. -/// - no references to fields are allowed to be created inside of the initializer. -/// -/// # Examples /// -/// ```rust -/// use kernel::{alloc::KBox, init::{PinInit, zeroed}, error::Error}; -/// struct BigBuf { -/// big: KBox<[u8; 1024 * 1024 * 1024]>, -/// small: [u8; 1024 * 1024], -/// } -/// -/// impl BigBuf { -/// fn new() -> impl Init<Self, Error> { -/// try_init!(Self { -/// big: KBox::init(zeroed(), GFP_KERNEL)?, -/// small: [0; 1024 * 1024], -/// }? Error) -/// } -/// } -/// ``` -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `__internal` inside of `init/__internal.rs`. +/// [`Infallible`]: core::convert::Infallible +/// [`pin_init!`]: pin_init::pin_init +/// [`Error`]: crate::error::Error #[macro_export] -macro_rules! try_init { +macro_rules! try_pin_init { ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { $($fields:tt)* }) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)?), - @fields($($fields)*), - @error($crate::error::Error), - @data(InitData, /*no use_data*/), - @has_data(HasInitData, __init_data), - @construct_closure(init_from_closure), - @munch_fields($($fields)*), - ) + ::pin_init::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? $crate::error::Error) }; ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { $($fields:tt)* }? $err:ty) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)?), - @fields($($fields)*), - @error($err), - @data(InitData, /*no use_data*/), - @has_data(HasInitData, __init_data), - @construct_closure(init_from_closure), - @munch_fields($($fields)*), - ) + ::pin_init::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? $err) }; } - -/// Asserts that a field on a struct using `#[pin_data]` is marked with `#[pin]` ie. that it is -/// structurally pinned. -/// -/// # Example -/// -/// This will succeed: -/// ``` -/// use kernel::assert_pinned; -/// #[pin_data] -/// struct MyStruct { -/// #[pin] -/// some_field: u64, -/// } -/// -/// assert_pinned!(MyStruct, some_field, u64); -/// ``` -/// -/// This will fail: -// TODO: replace with `compile_fail` when supported. -/// ```ignore -/// use kernel::assert_pinned; -/// #[pin_data] -/// struct MyStruct { -/// some_field: u64, -/// } -/// -/// assert_pinned!(MyStruct, some_field, u64); -/// ``` -/// -/// Some uses of the macro may trigger the `can't use generic parameters from outer item` error. To -/// work around this, you may pass the `inline` parameter to the macro. The `inline` parameter can -/// only be used when the macro is invoked from a function body. -/// ``` -/// use kernel::assert_pinned; -/// #[pin_data] -/// struct Foo<T> { -/// #[pin] -/// elem: T, -/// } -/// -/// impl<T> Foo<T> { -/// fn project(self: Pin<&mut Self>) -> Pin<&mut T> { -/// assert_pinned!(Foo<T>, elem, T, inline); -/// -/// // SAFETY: The field is structurally pinned. -/// unsafe { self.map_unchecked_mut(|me| &mut me.elem) } -/// } -/// } -/// ``` -#[macro_export] -macro_rules! assert_pinned { - ($ty:ty, $field:ident, $field_ty:ty, inline) => { - let _ = move |ptr: *mut $field_ty| { - // SAFETY: This code is unreachable. - let data = unsafe { <$ty as $crate::init::__internal::HasPinData>::__pin_data() }; - let init = $crate::init::__internal::AlwaysFail::<$field_ty>::new(); - // SAFETY: This code is unreachable. - unsafe { data.$field(ptr, init) }.ok(); - }; - }; - - ($ty:ty, $field:ident, $field_ty:ty) => { - const _: () = { - $crate::assert_pinned!($ty, $field, $field_ty, inline); - }; - }; -} - -/// A pin-initializer for the type `T`. -/// -/// To use this initializer, you will need a suitable memory location that can hold a `T`. This can -/// be [`KBox<T>`], [`Arc<T>`], [`UniqueArc<T>`] or even the stack (see [`stack_pin_init!`]). Use -/// the [`InPlaceInit::pin_init`] function of a smart pointer like [`Arc<T>`] on this. -/// -/// Also see the [module description](self). -/// -/// # Safety -/// -/// When implementing this trait you will need to take great care. Also there are probably very few -/// cases where a manual implementation is necessary. Use [`pin_init_from_closure`] where possible. -/// -/// The [`PinInit::__pinned_init`] function: -/// - returns `Ok(())` if it initialized every field of `slot`, -/// - returns `Err(err)` if it encountered an error and then cleaned `slot`, this means: -/// - `slot` can be deallocated without UB occurring, -/// - `slot` does not need to be dropped, -/// - `slot` is not partially initialized. -/// - while constructing the `T` at `slot` it upholds the pinning invariants of `T`. -/// -/// [`Arc<T>`]: crate::sync::Arc -/// [`Arc::pin_init`]: crate::sync::Arc::pin_init -#[must_use = "An initializer must be used in order to create its value."] -pub unsafe trait PinInit<T: ?Sized, E = Infallible>: Sized { - /// Initializes `slot`. - /// - /// # Safety - /// - /// - `slot` is a valid pointer to uninitialized memory. - /// - the caller does not touch `slot` when `Err` is returned, they are only permitted to - /// deallocate. - /// - `slot` will not move until it is dropped, i.e. it will be pinned. - unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E>; - - /// First initializes the value using `self` then calls the function `f` with the initialized - /// value. - /// - /// If `f` returns an error the value is dropped and the initializer will forward the error. - /// - /// # Examples - /// - /// ```rust - /// # #![expect(clippy::disallowed_names)] - /// use kernel::{types::Opaque, init::pin_init_from_closure}; - /// #[repr(C)] - /// struct RawFoo([u8; 16]); - /// extern "C" { - /// fn init_foo(_: *mut RawFoo); - /// } - /// - /// #[pin_data] - /// struct Foo { - /// #[pin] - /// raw: Opaque<RawFoo>, - /// } - /// - /// impl Foo { - /// fn setup(self: Pin<&mut Self>) { - /// pr_info!("Setting up foo"); - /// } - /// } - /// - /// let foo = pin_init!(Foo { - /// // SAFETY: TODO. - /// raw <- unsafe { - /// Opaque::ffi_init(|s| { - /// init_foo(s); - /// }) - /// }, - /// }).pin_chain(|foo| { - /// foo.setup(); - /// Ok(()) - /// }); - /// ``` - fn pin_chain<F>(self, f: F) -> ChainPinInit<Self, F, T, E> - where - F: FnOnce(Pin<&mut T>) -> Result<(), E>, - { - ChainPinInit(self, f, PhantomData) - } -} - -/// An initializer returned by [`PinInit::pin_chain`]. -pub struct ChainPinInit<I, F, T: ?Sized, E>(I, F, __internal::Invariant<(E, KBox<T>)>); - -// SAFETY: The `__pinned_init` function is implemented such that it -// - returns `Ok(())` on successful initialization, -// - returns `Err(err)` on error and in this case `slot` will be dropped. -// - considers `slot` pinned. -unsafe impl<T: ?Sized, E, I, F> PinInit<T, E> for ChainPinInit<I, F, T, E> -where - I: PinInit<T, E>, - F: FnOnce(Pin<&mut T>) -> Result<(), E>, -{ - unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> { - // SAFETY: All requirements fulfilled since this function is `__pinned_init`. - unsafe { self.0.__pinned_init(slot)? }; - // SAFETY: The above call initialized `slot` and we still have unique access. - let val = unsafe { &mut *slot }; - // SAFETY: `slot` is considered pinned. - let val = unsafe { Pin::new_unchecked(val) }; - // SAFETY: `slot` was initialized above. - (self.1)(val).inspect_err(|_| unsafe { core::ptr::drop_in_place(slot) }) - } -} - -/// An initializer for `T`. -/// -/// To use this initializer, you will need a suitable memory location that can hold a `T`. This can -/// be [`KBox<T>`], [`Arc<T>`], [`UniqueArc<T>`] or even the stack (see [`stack_pin_init!`]). Use -/// the [`InPlaceInit::init`] function of a smart pointer like [`Arc<T>`] on this. Because -/// [`PinInit<T, E>`] is a super trait, you can use every function that takes it as well. -/// -/// Also see the [module description](self). -/// -/// # Safety -/// -/// When implementing this trait you will need to take great care. Also there are probably very few -/// cases where a manual implementation is necessary. Use [`init_from_closure`] where possible. -/// -/// The [`Init::__init`] function: -/// - returns `Ok(())` if it initialized every field of `slot`, -/// - returns `Err(err)` if it encountered an error and then cleaned `slot`, this means: -/// - `slot` can be deallocated without UB occurring, -/// - `slot` does not need to be dropped, -/// - `slot` is not partially initialized. -/// - while constructing the `T` at `slot` it upholds the pinning invariants of `T`. -/// -/// The `__pinned_init` function from the supertrait [`PinInit`] needs to execute the exact same -/// code as `__init`. -/// -/// Contrary to its supertype [`PinInit<T, E>`] the caller is allowed to -/// move the pointee after initialization. -/// -/// [`Arc<T>`]: crate::sync::Arc -#[must_use = "An initializer must be used in order to create its value."] -pub unsafe trait Init<T: ?Sized, E = Infallible>: PinInit<T, E> { - /// Initializes `slot`. - /// - /// # Safety - /// - /// - `slot` is a valid pointer to uninitialized memory. - /// - the caller does not touch `slot` when `Err` is returned, they are only permitted to - /// deallocate. - unsafe fn __init(self, slot: *mut T) -> Result<(), E>; - - /// First initializes the value using `self` then calls the function `f` with the initialized - /// value. - /// - /// If `f` returns an error the value is dropped and the initializer will forward the error. - /// - /// # Examples - /// - /// ```rust - /// # #![expect(clippy::disallowed_names)] - /// use kernel::{types::Opaque, init::{self, init_from_closure}}; - /// struct Foo { - /// buf: [u8; 1_000_000], - /// } - /// - /// impl Foo { - /// fn setup(&mut self) { - /// pr_info!("Setting up foo"); - /// } - /// } - /// - /// let foo = init!(Foo { - /// buf <- init::zeroed() - /// }).chain(|foo| { - /// foo.setup(); - /// Ok(()) - /// }); - /// ``` - fn chain<F>(self, f: F) -> ChainInit<Self, F, T, E> - where - F: FnOnce(&mut T) -> Result<(), E>, - { - ChainInit(self, f, PhantomData) - } -} - -/// An initializer returned by [`Init::chain`]. -pub struct ChainInit<I, F, T: ?Sized, E>(I, F, __internal::Invariant<(E, KBox<T>)>); - -// SAFETY: The `__init` function is implemented such that it -// - returns `Ok(())` on successful initialization, -// - returns `Err(err)` on error and in this case `slot` will be dropped. -unsafe impl<T: ?Sized, E, I, F> Init<T, E> for ChainInit<I, F, T, E> -where - I: Init<T, E>, - F: FnOnce(&mut T) -> Result<(), E>, -{ - unsafe fn __init(self, slot: *mut T) -> Result<(), E> { - // SAFETY: All requirements fulfilled since this function is `__init`. - unsafe { self.0.__pinned_init(slot)? }; - // SAFETY: The above call initialized `slot` and we still have unique access. - (self.1)(unsafe { &mut *slot }).inspect_err(|_| - // SAFETY: `slot` was initialized above. - unsafe { core::ptr::drop_in_place(slot) }) - } -} - -// SAFETY: `__pinned_init` behaves exactly the same as `__init`. -unsafe impl<T: ?Sized, E, I, F> PinInit<T, E> for ChainInit<I, F, T, E> -where - I: Init<T, E>, - F: FnOnce(&mut T) -> Result<(), E>, -{ - unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> { - // SAFETY: `__init` has less strict requirements compared to `__pinned_init`. - unsafe { self.__init(slot) } - } -} - -/// Creates a new [`PinInit<T, E>`] from the given closure. -/// -/// # Safety -/// -/// The closure: -/// - returns `Ok(())` if it initialized every field of `slot`, -/// - returns `Err(err)` if it encountered an error and then cleaned `slot`, this means: -/// - `slot` can be deallocated without UB occurring, -/// - `slot` does not need to be dropped, -/// - `slot` is not partially initialized. -/// - may assume that the `slot` does not move if `T: !Unpin`, -/// - while constructing the `T` at `slot` it upholds the pinning invariants of `T`. -#[inline] -pub const unsafe fn pin_init_from_closure<T: ?Sized, E>( - f: impl FnOnce(*mut T) -> Result<(), E>, -) -> impl PinInit<T, E> { - __internal::InitClosure(f, PhantomData) -} - -/// Creates a new [`Init<T, E>`] from the given closure. -/// -/// # Safety -/// -/// The closure: -/// - returns `Ok(())` if it initialized every field of `slot`, -/// - returns `Err(err)` if it encountered an error and then cleaned `slot`, this means: -/// - `slot` can be deallocated without UB occurring, -/// - `slot` does not need to be dropped, -/// - `slot` is not partially initialized. -/// - the `slot` may move after initialization. -/// - while constructing the `T` at `slot` it upholds the pinning invariants of `T`. -#[inline] -pub const unsafe fn init_from_closure<T: ?Sized, E>( - f: impl FnOnce(*mut T) -> Result<(), E>, -) -> impl Init<T, E> { - __internal::InitClosure(f, PhantomData) -} - -/// An initializer that leaves the memory uninitialized. -/// -/// The initializer is a no-op. The `slot` memory is not changed. -#[inline] -pub fn uninit<T, E>() -> impl Init<MaybeUninit<T>, E> { - // SAFETY: The memory is allowed to be uninitialized. - unsafe { init_from_closure(|_| Ok(())) } -} - -/// Initializes an array by initializing each element via the provided initializer. -/// -/// # Examples -/// -/// ```rust -/// use kernel::{alloc::KBox, error::Error, init::init_array_from_fn}; -/// let array: KBox<[usize; 1_000]> = -/// KBox::init::<Error>(init_array_from_fn(|i| i), GFP_KERNEL).unwrap(); -/// assert_eq!(array.len(), 1_000); -/// ``` -pub fn init_array_from_fn<I, const N: usize, T, E>( - mut make_init: impl FnMut(usize) -> I, -) -> impl Init<[T; N], E> -where - I: Init<T, E>, -{ - let init = move |slot: *mut [T; N]| { - let slot = slot.cast::<T>(); - // Counts the number of initialized elements and when dropped drops that many elements from - // `slot`. - let mut init_count = ScopeGuard::new_with_data(0, |i| { - // We now free every element that has been initialized before. - // SAFETY: The loop initialized exactly the values from 0..i and since we - // return `Err` below, the caller will consider the memory at `slot` as - // uninitialized. - unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(slot, i)) }; - }); - for i in 0..N { - let init = make_init(i); - // SAFETY: Since 0 <= `i` < N, it is still in bounds of `[T; N]`. - let ptr = unsafe { slot.add(i) }; - // SAFETY: The pointer is derived from `slot` and thus satisfies the `__init` - // requirements. - unsafe { init.__init(ptr) }?; - *init_count += 1; - } - init_count.dismiss(); - Ok(()) - }; - // SAFETY: The initializer above initializes every element of the array. On failure it drops - // any initialized elements and returns `Err`. - unsafe { init_from_closure(init) } -} - -/// Initializes an array by initializing each element via the provided initializer. -/// -/// # Examples -/// -/// ```rust -/// use kernel::{sync::{Arc, Mutex}, init::pin_init_array_from_fn, new_mutex}; -/// let array: Arc<[Mutex<usize>; 1_000]> = -/// Arc::pin_init(pin_init_array_from_fn(|i| new_mutex!(i)), GFP_KERNEL).unwrap(); -/// assert_eq!(array.len(), 1_000); -/// ``` -pub fn pin_init_array_from_fn<I, const N: usize, T, E>( - mut make_init: impl FnMut(usize) -> I, -) -> impl PinInit<[T; N], E> -where - I: PinInit<T, E>, -{ - let init = move |slot: *mut [T; N]| { - let slot = slot.cast::<T>(); - // Counts the number of initialized elements and when dropped drops that many elements from - // `slot`. - let mut init_count = ScopeGuard::new_with_data(0, |i| { - // We now free every element that has been initialized before. - // SAFETY: The loop initialized exactly the values from 0..i and since we - // return `Err` below, the caller will consider the memory at `slot` as - // uninitialized. - unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(slot, i)) }; - }); - for i in 0..N { - let init = make_init(i); - // SAFETY: Since 0 <= `i` < N, it is still in bounds of `[T; N]`. - let ptr = unsafe { slot.add(i) }; - // SAFETY: The pointer is derived from `slot` and thus satisfies the `__init` - // requirements. - unsafe { init.__pinned_init(ptr) }?; - *init_count += 1; - } - init_count.dismiss(); - Ok(()) - }; - // SAFETY: The initializer above initializes every element of the array. On failure it drops - // any initialized elements and returns `Err`. - unsafe { pin_init_from_closure(init) } -} - -// SAFETY: Every type can be initialized by-value. -unsafe impl<T, E> Init<T, E> for T { - unsafe fn __init(self, slot: *mut T) -> Result<(), E> { - // SAFETY: TODO. - unsafe { slot.write(self) }; - Ok(()) - } -} - -// SAFETY: Every type can be initialized by-value. `__pinned_init` calls `__init`. -unsafe impl<T, E> PinInit<T, E> for T { - unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> { - // SAFETY: TODO. - unsafe { self.__init(slot) } - } -} - -/// Smart pointer that can initialize memory in-place. -pub trait InPlaceInit<T>: Sized { - /// Pinned version of `Self`. - /// - /// If a type already implicitly pins its pointee, `Pin<Self>` is unnecessary. In this case use - /// `Self`, otherwise just use `Pin<Self>`. - type PinnedSelf; - - /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this - /// type. - /// - /// If `T: !Unpin` it will not be able to move afterwards. - fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> - where - E: From<AllocError>; - - /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this - /// type. - /// - /// If `T: !Unpin` it will not be able to move afterwards. - fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> error::Result<Self::PinnedSelf> - where - Error: From<E>, - { - // SAFETY: We delegate to `init` and only change the error type. - let init = unsafe { - pin_init_from_closure(|slot| init.__pinned_init(slot).map_err(|e| Error::from(e))) - }; - Self::try_pin_init(init, flags) - } - - /// Use the given initializer to in-place initialize a `T`. - fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> - where - E: From<AllocError>; - - /// Use the given initializer to in-place initialize a `T`. - fn init<E>(init: impl Init<T, E>, flags: Flags) -> error::Result<Self> - where - Error: From<E>, - { - // SAFETY: We delegate to `init` and only change the error type. - let init = unsafe { - init_from_closure(|slot| init.__pinned_init(slot).map_err(|e| Error::from(e))) - }; - Self::try_init(init, flags) - } -} - -impl<T> InPlaceInit<T> for Arc<T> { - type PinnedSelf = Self; - - #[inline] - fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> - where - E: From<AllocError>, - { - UniqueArc::try_pin_init(init, flags).map(|u| u.into()) - } - - #[inline] - fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> - where - E: From<AllocError>, - { - UniqueArc::try_init(init, flags).map(|u| u.into()) - } -} - -impl<T> InPlaceInit<T> for UniqueArc<T> { - type PinnedSelf = Pin<Self>; - - #[inline] - fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> - where - E: From<AllocError>, - { - UniqueArc::new_uninit(flags)?.write_pin_init(init) - } - - #[inline] - fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> - where - E: From<AllocError>, - { - UniqueArc::new_uninit(flags)?.write_init(init) - } -} - -/// Smart pointer containing uninitialized memory and that can write a value. -pub trait InPlaceWrite<T> { - /// The type `Self` turns into when the contents are initialized. - type Initialized; - - /// Use the given initializer to write a value into `self`. - /// - /// Does not drop the current value and considers it as uninitialized memory. - fn write_init<E>(self, init: impl Init<T, E>) -> Result<Self::Initialized, E>; - - /// Use the given pin-initializer to write a value into `self`. - /// - /// Does not drop the current value and considers it as uninitialized memory. - fn write_pin_init<E>(self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E>; -} - -impl<T> InPlaceWrite<T> for UniqueArc<MaybeUninit<T>> { - type Initialized = UniqueArc<T>; - - fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { - let slot = self.as_mut_ptr(); - // SAFETY: When init errors/panics, slot will get deallocated but not dropped, - // slot is valid. - unsafe { init.__init(slot)? }; - // SAFETY: All fields have been initialized. - Ok(unsafe { self.assume_init() }) - } - - fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { - let slot = self.as_mut_ptr(); - // SAFETY: When init errors/panics, slot will get deallocated but not dropped, - // slot is valid and will not be moved, because we pin it later. - unsafe { init.__pinned_init(slot)? }; - // SAFETY: All fields have been initialized. - Ok(unsafe { self.assume_init() }.into()) - } -} - -/// Trait facilitating pinned destruction. -/// -/// Use [`pinned_drop`] to implement this trait safely: -/// -/// ```rust -/// # use kernel::sync::Mutex; -/// use kernel::macros::pinned_drop; -/// use core::pin::Pin; -/// #[pin_data(PinnedDrop)] -/// struct Foo { -/// #[pin] -/// mtx: Mutex<usize>, -/// } -/// -/// #[pinned_drop] -/// impl PinnedDrop for Foo { -/// fn drop(self: Pin<&mut Self>) { -/// pr_info!("Foo is being dropped!"); -/// } -/// } -/// ``` -/// -/// # Safety -/// -/// This trait must be implemented via the [`pinned_drop`] proc-macro attribute on the impl. -/// -/// [`pinned_drop`]: kernel::macros::pinned_drop -pub unsafe trait PinnedDrop: __internal::HasPinData { - /// Executes the pinned destructor of this type. - /// - /// While this function is marked safe, it is actually unsafe to call it manually. For this - /// reason it takes an additional parameter. This type can only be constructed by `unsafe` code - /// and thus prevents this function from being called where it should not. - /// - /// This extra parameter will be generated by the `#[pinned_drop]` proc-macro attribute - /// automatically. - fn drop(self: Pin<&mut Self>, only_call_from_drop: __internal::OnlyCallFromDrop); -} - -/// Marker trait for types that can be initialized by writing just zeroes. -/// -/// # Safety -/// -/// The bit pattern consisting of only zeroes is a valid bit pattern for this type. In other words, -/// this is not UB: -/// -/// ```rust,ignore -/// let val: Self = unsafe { core::mem::zeroed() }; -/// ``` -pub unsafe trait Zeroable {} - -/// Create a new zeroed T. -/// -/// The returned initializer will write `0x00` to every byte of the given `slot`. -#[inline] -pub fn zeroed<T: Zeroable>() -> impl Init<T> { - // SAFETY: Because `T: Zeroable`, all bytes zero is a valid bit pattern for `T` - // and because we write all zeroes, the memory is initialized. - unsafe { - init_from_closure(|slot: *mut T| { - slot.write_bytes(0, 1); - Ok(()) - }) - } -} - -macro_rules! impl_zeroable { - ($($({$($generics:tt)*})? $t:ty, )*) => { - // SAFETY: Safety comments written in the macro invocation. - $(unsafe impl$($($generics)*)? Zeroable for $t {})* - }; -} - -impl_zeroable! { - // SAFETY: All primitives that are allowed to be zero. - bool, - char, - u8, u16, u32, u64, u128, usize, - i8, i16, i32, i64, i128, isize, - f32, f64, - - // Note: do not add uninhabited types (such as `!` or `core::convert::Infallible`) to this list; - // creating an instance of an uninhabited type is immediate undefined behavior. For more on - // uninhabited/empty types, consult The Rustonomicon: - // <https://doc.rust-lang.org/stable/nomicon/exotic-sizes.html#empty-types>. The Rust Reference - // also has information on undefined behavior: - // <https://doc.rust-lang.org/stable/reference/behavior-considered-undefined.html>. - // - // SAFETY: These are inhabited ZSTs; there is nothing to zero and a valid value exists. - {<T: ?Sized>} PhantomData<T>, core::marker::PhantomPinned, (), - - // SAFETY: Type is allowed to take any value, including all zeros. - {<T>} MaybeUninit<T>, - // SAFETY: Type is allowed to take any value, including all zeros. - {<T>} Opaque<T>, - - // SAFETY: `T: Zeroable` and `UnsafeCell` is `repr(transparent)`. - {<T: ?Sized + Zeroable>} UnsafeCell<T>, - - // SAFETY: All zeros is equivalent to `None` (option layout optimization guarantee). - Option<NonZeroU8>, Option<NonZeroU16>, Option<NonZeroU32>, Option<NonZeroU64>, - Option<NonZeroU128>, Option<NonZeroUsize>, - Option<NonZeroI8>, Option<NonZeroI16>, Option<NonZeroI32>, Option<NonZeroI64>, - Option<NonZeroI128>, Option<NonZeroIsize>, - - // SAFETY: All zeros is equivalent to `None` (option layout optimization guarantee). - // - // In this case we are allowed to use `T: ?Sized`, since all zeros is the `None` variant. - {<T: ?Sized>} Option<NonNull<T>>, - {<T: ?Sized>} Option<KBox<T>>, - - // SAFETY: `null` pointer is valid. - // - // We cannot use `T: ?Sized`, since the VTABLE pointer part of fat pointers is not allowed to be - // null. - // - // When `Pointee` gets stabilized, we could use - // `T: ?Sized where <T as Pointee>::Metadata: Zeroable` - {<T>} *mut T, {<T>} *const T, - - // SAFETY: `null` pointer is valid and the metadata part of these fat pointers is allowed to be - // zero. - {<T>} *mut [T], {<T>} *const [T], *mut str, *const str, - - // SAFETY: `T` is `Zeroable`. - {<const N: usize, T: Zeroable>} [T; N], {<T: Zeroable>} Wrapping<T>, -} - -macro_rules! impl_tuple_zeroable { - ($(,)?) => {}; - ($first:ident, $($t:ident),* $(,)?) => { - // SAFETY: All elements are zeroable and padding can be zero. - unsafe impl<$first: Zeroable, $($t: Zeroable),*> Zeroable for ($first, $($t),*) {} - impl_tuple_zeroable!($($t),* ,); - } -} - -impl_tuple_zeroable!(A, B, C, D, E, F, G, H, I, J); diff --git a/rust/kernel/io.rs b/rust/kernel/io.rs new file mode 100644 index 000000000000..03b467722b86 --- /dev/null +++ b/rust/kernel/io.rs @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Memory-mapped IO. +//! +//! C header: [`include/asm-generic/io.h`](srctree/include/asm-generic/io.h) + +use crate::error::{code::EINVAL, Result}; +use crate::{bindings, build_assert, ffi::c_void}; + +pub mod mem; +pub mod resource; + +pub use resource::Resource; + +/// Raw representation of an MMIO region. +/// +/// By itself, the existence of an instance of this structure does not provide any guarantees that +/// the represented MMIO region does exist or is properly mapped. +/// +/// Instead, the bus specific MMIO implementation must convert this raw representation into an `Io` +/// instance providing the actual memory accessors. Only by the conversion into an `Io` structure +/// any guarantees are given. +pub struct IoRaw<const SIZE: usize = 0> { + addr: usize, + maxsize: usize, +} + +impl<const SIZE: usize> IoRaw<SIZE> { + /// Returns a new `IoRaw` instance on success, an error otherwise. + pub fn new(addr: usize, maxsize: usize) -> Result<Self> { + if maxsize < SIZE { + return Err(EINVAL); + } + + Ok(Self { addr, maxsize }) + } + + /// Returns the base address of the MMIO region. + #[inline] + pub fn addr(&self) -> usize { + self.addr + } + + /// Returns the maximum size of the MMIO region. + #[inline] + pub fn maxsize(&self) -> usize { + self.maxsize + } +} + +/// IO-mapped memory region. +/// +/// The creator (usually a subsystem / bus such as PCI) is responsible for creating the +/// mapping, performing an additional region request etc. +/// +/// # Invariant +/// +/// `addr` is the start and `maxsize` the length of valid I/O mapped memory region of size +/// `maxsize`. +/// +/// # Examples +/// +/// ```no_run +/// # use kernel::{bindings, ffi::c_void, io::{Io, IoRaw}}; +/// # use core::ops::Deref; +/// +/// // See also [`pci::Bar`] for a real example. +/// struct IoMem<const SIZE: usize>(IoRaw<SIZE>); +/// +/// impl<const SIZE: usize> IoMem<SIZE> { +/// /// # Safety +/// /// +/// /// [`paddr`, `paddr` + `SIZE`) must be a valid MMIO region that is mappable into the CPUs +/// /// virtual address space. +/// unsafe fn new(paddr: usize) -> Result<Self>{ +/// // SAFETY: By the safety requirements of this function [`paddr`, `paddr` + `SIZE`) is +/// // valid for `ioremap`. +/// let addr = unsafe { bindings::ioremap(paddr as bindings::phys_addr_t, SIZE) }; +/// if addr.is_null() { +/// return Err(ENOMEM); +/// } +/// +/// Ok(IoMem(IoRaw::new(addr as usize, SIZE)?)) +/// } +/// } +/// +/// impl<const SIZE: usize> Drop for IoMem<SIZE> { +/// fn drop(&mut self) { +/// // SAFETY: `self.0.addr()` is guaranteed to be properly mapped by `Self::new`. +/// unsafe { bindings::iounmap(self.0.addr() as *mut c_void); }; +/// } +/// } +/// +/// impl<const SIZE: usize> Deref for IoMem<SIZE> { +/// type Target = Io<SIZE>; +/// +/// fn deref(&self) -> &Self::Target { +/// // SAFETY: The memory range stored in `self` has been properly mapped in `Self::new`. +/// unsafe { Io::from_raw(&self.0) } +/// } +/// } +/// +///# fn no_run() -> Result<(), Error> { +/// // SAFETY: Invalid usage for example purposes. +/// let iomem = unsafe { IoMem::<{ core::mem::size_of::<u32>() }>::new(0xBAAAAAAD)? }; +/// iomem.write32(0x42, 0x0); +/// assert!(iomem.try_write32(0x42, 0x0).is_ok()); +/// assert!(iomem.try_write32(0x42, 0x4).is_err()); +/// # Ok(()) +/// # } +/// ``` +#[repr(transparent)] +pub struct Io<const SIZE: usize = 0>(IoRaw<SIZE>); + +macro_rules! define_read { + ($(#[$attr:meta])* $name:ident, $try_name:ident, $c_fn:ident -> $type_name:ty) => { + /// Read IO data from a given offset known at compile time. + /// + /// Bound checks are performed on compile time, hence if the offset is not known at compile + /// time, the build will fail. + $(#[$attr])* + #[inline] + pub fn $name(&self, offset: usize) -> $type_name { + let addr = self.io_addr_assert::<$type_name>(offset); + + // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. + unsafe { bindings::$c_fn(addr as *const c_void) } + } + + /// Read IO data from a given offset. + /// + /// Bound checks are performed on runtime, it fails if the offset (plus the type size) is + /// out of bounds. + $(#[$attr])* + pub fn $try_name(&self, offset: usize) -> Result<$type_name> { + let addr = self.io_addr::<$type_name>(offset)?; + + // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. + Ok(unsafe { bindings::$c_fn(addr as *const c_void) }) + } + }; +} + +macro_rules! define_write { + ($(#[$attr:meta])* $name:ident, $try_name:ident, $c_fn:ident <- $type_name:ty) => { + /// Write IO data from a given offset known at compile time. + /// + /// Bound checks are performed on compile time, hence if the offset is not known at compile + /// time, the build will fail. + $(#[$attr])* + #[inline] + pub fn $name(&self, value: $type_name, offset: usize) { + let addr = self.io_addr_assert::<$type_name>(offset); + + // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. + unsafe { bindings::$c_fn(value, addr as *mut c_void) } + } + + /// Write IO data from a given offset. + /// + /// Bound checks are performed on runtime, it fails if the offset (plus the type size) is + /// out of bounds. + $(#[$attr])* + pub fn $try_name(&self, value: $type_name, offset: usize) -> Result { + let addr = self.io_addr::<$type_name>(offset)?; + + // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. + unsafe { bindings::$c_fn(value, addr as *mut c_void) } + Ok(()) + } + }; +} + +impl<const SIZE: usize> Io<SIZE> { + /// Converts an `IoRaw` into an `Io` instance, providing the accessors to the MMIO mapping. + /// + /// # Safety + /// + /// Callers must ensure that `addr` is the start of a valid I/O mapped memory region of size + /// `maxsize`. + pub unsafe fn from_raw(raw: &IoRaw<SIZE>) -> &Self { + // SAFETY: `Io` is a transparent wrapper around `IoRaw`. + unsafe { &*core::ptr::from_ref(raw).cast() } + } + + /// Returns the base address of this mapping. + #[inline] + pub fn addr(&self) -> usize { + self.0.addr() + } + + /// Returns the maximum size of this mapping. + #[inline] + pub fn maxsize(&self) -> usize { + self.0.maxsize() + } + + #[inline] + const fn offset_valid<U>(offset: usize, size: usize) -> bool { + let type_size = core::mem::size_of::<U>(); + if let Some(end) = offset.checked_add(type_size) { + end <= size && offset % type_size == 0 + } else { + false + } + } + + #[inline] + fn io_addr<U>(&self, offset: usize) -> Result<usize> { + if !Self::offset_valid::<U>(offset, self.maxsize()) { + return Err(EINVAL); + } + + // Probably no need to check, since the safety requirements of `Self::new` guarantee that + // this can't overflow. + self.addr().checked_add(offset).ok_or(EINVAL) + } + + #[inline] + fn io_addr_assert<U>(&self, offset: usize) -> usize { + build_assert!(Self::offset_valid::<U>(offset, SIZE)); + + self.addr() + offset + } + + define_read!(read8, try_read8, readb -> u8); + define_read!(read16, try_read16, readw -> u16); + define_read!(read32, try_read32, readl -> u32); + define_read!( + #[cfg(CONFIG_64BIT)] + read64, + try_read64, + readq -> u64 + ); + + define_read!(read8_relaxed, try_read8_relaxed, readb_relaxed -> u8); + define_read!(read16_relaxed, try_read16_relaxed, readw_relaxed -> u16); + define_read!(read32_relaxed, try_read32_relaxed, readl_relaxed -> u32); + define_read!( + #[cfg(CONFIG_64BIT)] + read64_relaxed, + try_read64_relaxed, + readq_relaxed -> u64 + ); + + define_write!(write8, try_write8, writeb <- u8); + define_write!(write16, try_write16, writew <- u16); + define_write!(write32, try_write32, writel <- u32); + define_write!( + #[cfg(CONFIG_64BIT)] + write64, + try_write64, + writeq <- u64 + ); + + define_write!(write8_relaxed, try_write8_relaxed, writeb_relaxed <- u8); + define_write!(write16_relaxed, try_write16_relaxed, writew_relaxed <- u16); + define_write!(write32_relaxed, try_write32_relaxed, writel_relaxed <- u32); + define_write!( + #[cfg(CONFIG_64BIT)] + write64_relaxed, + try_write64_relaxed, + writeq_relaxed <- u64 + ); +} diff --git a/rust/kernel/io/mem.rs b/rust/kernel/io/mem.rs new file mode 100644 index 000000000000..6f99510bfc3a --- /dev/null +++ b/rust/kernel/io/mem.rs @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Generic memory-mapped IO. + +use core::ops::Deref; + +use crate::c_str; +use crate::device::Bound; +use crate::device::Device; +use crate::devres::Devres; +use crate::io; +use crate::io::resource::Region; +use crate::io::resource::Resource; +use crate::io::Io; +use crate::io::IoRaw; +use crate::prelude::*; + +/// An IO request for a specific device and resource. +pub struct IoRequest<'a> { + device: &'a Device<Bound>, + resource: &'a Resource, +} + +impl<'a> IoRequest<'a> { + /// Creates a new [`IoRequest`] instance. + /// + /// # Safety + /// + /// Callers must ensure that `resource` is valid for `device` during the + /// lifetime `'a`. + pub(crate) unsafe fn new(device: &'a Device<Bound>, resource: &'a Resource) -> Self { + IoRequest { device, resource } + } + + /// Maps an [`IoRequest`] where the size is known at compile time. + /// + /// This uses the [`ioremap()`] C API. + /// + /// [`ioremap()`]: https://docs.kernel.org/driver-api/device-io.html#getting-access-to-the-device + /// + /// # Examples + /// + /// The following example uses a [`kernel::platform::Device`] for + /// illustration purposes. + /// + /// ```no_run + /// use kernel::{bindings, c_str, platform, of, device::Core}; + /// struct SampleDriver; + /// + /// impl platform::Driver for SampleDriver { + /// # type IdInfo = (); + /// + /// fn probe( + /// pdev: &platform::Device<Core>, + /// info: Option<&Self::IdInfo>, + /// ) -> Result<Pin<KBox<Self>>> { + /// let offset = 0; // Some offset. + /// + /// // If the size is known at compile time, use [`Self::iomap_sized`]. + /// // + /// // No runtime checks will apply when reading and writing. + /// let request = pdev.io_request_by_index(0).ok_or(ENODEV)?; + /// let iomem = request.iomap_sized::<42>(); + /// let iomem = KBox::pin_init(iomem, GFP_KERNEL)?; + /// + /// let io = iomem.access(pdev.as_ref())?; + /// + /// // Read and write a 32-bit value at `offset`. + /// let data = io.read32_relaxed(offset); + /// + /// io.write32_relaxed(data, offset); + /// + /// # Ok(KBox::new(SampleDriver, GFP_KERNEL)?.into()) + /// } + /// } + /// ``` + pub fn iomap_sized<const SIZE: usize>(self) -> impl PinInit<Devres<IoMem<SIZE>>, Error> + 'a { + IoMem::new(self) + } + + /// Same as [`Self::iomap_sized`] but with exclusive access to the + /// underlying region. + /// + /// This uses the [`ioremap()`] C API. + /// + /// [`ioremap()`]: https://docs.kernel.org/driver-api/device-io.html#getting-access-to-the-device + pub fn iomap_exclusive_sized<const SIZE: usize>( + self, + ) -> impl PinInit<Devres<ExclusiveIoMem<SIZE>>, Error> + 'a { + ExclusiveIoMem::new(self) + } + + /// Maps an [`IoRequest`] where the size is not known at compile time, + /// + /// This uses the [`ioremap()`] C API. + /// + /// [`ioremap()`]: https://docs.kernel.org/driver-api/device-io.html#getting-access-to-the-device + /// + /// # Examples + /// + /// The following example uses a [`kernel::platform::Device`] for + /// illustration purposes. + /// + /// ```no_run + /// use kernel::{bindings, c_str, platform, of, device::Core}; + /// struct SampleDriver; + /// + /// impl platform::Driver for SampleDriver { + /// # type IdInfo = (); + /// + /// fn probe( + /// pdev: &platform::Device<Core>, + /// info: Option<&Self::IdInfo>, + /// ) -> Result<Pin<KBox<Self>>> { + /// let offset = 0; // Some offset. + /// + /// // Unlike [`Self::iomap_sized`], here the size of the memory region + /// // is not known at compile time, so only the `try_read*` and `try_write*` + /// // family of functions should be used, leading to runtime checks on every + /// // access. + /// let request = pdev.io_request_by_index(0).ok_or(ENODEV)?; + /// let iomem = request.iomap(); + /// let iomem = KBox::pin_init(iomem, GFP_KERNEL)?; + /// + /// let io = iomem.access(pdev.as_ref())?; + /// + /// let data = io.try_read32_relaxed(offset)?; + /// + /// io.try_write32_relaxed(data, offset)?; + /// + /// # Ok(KBox::new(SampleDriver, GFP_KERNEL)?.into()) + /// } + /// } + /// ``` + pub fn iomap(self) -> impl PinInit<Devres<IoMem<0>>, Error> + 'a { + Self::iomap_sized::<0>(self) + } + + /// Same as [`Self::iomap`] but with exclusive access to the underlying + /// region. + pub fn iomap_exclusive(self) -> impl PinInit<Devres<ExclusiveIoMem<0>>, Error> + 'a { + Self::iomap_exclusive_sized::<0>(self) + } +} + +/// An exclusive memory-mapped IO region. +/// +/// # Invariants +/// +/// - [`ExclusiveIoMem`] has exclusive access to the underlying [`IoMem`]. +pub struct ExclusiveIoMem<const SIZE: usize> { + /// The underlying `IoMem` instance. + iomem: IoMem<SIZE>, + + /// The region abstraction. This represents exclusive access to the + /// range represented by the underlying `iomem`. + /// + /// This field is needed for ownership of the region. + _region: Region, +} + +impl<const SIZE: usize> ExclusiveIoMem<SIZE> { + /// Creates a new `ExclusiveIoMem` instance. + fn ioremap(resource: &Resource) -> Result<Self> { + let start = resource.start(); + let size = resource.size(); + let name = resource.name().unwrap_or(c_str!("")); + + let region = resource + .request_region( + start, + size, + name.to_cstring()?, + io::resource::Flags::IORESOURCE_MEM, + ) + .ok_or(EBUSY)?; + + let iomem = IoMem::ioremap(resource)?; + + let iomem = ExclusiveIoMem { + iomem, + _region: region, + }; + + Ok(iomem) + } + + /// Creates a new `ExclusiveIoMem` instance from a previously acquired [`IoRequest`]. + pub fn new<'a>(io_request: IoRequest<'a>) -> impl PinInit<Devres<Self>, Error> + 'a { + let dev = io_request.device; + let res = io_request.resource; + + Devres::new(dev, Self::ioremap(res)) + } +} + +impl<const SIZE: usize> Deref for ExclusiveIoMem<SIZE> { + type Target = Io<SIZE>; + + fn deref(&self) -> &Self::Target { + &self.iomem + } +} + +/// A generic memory-mapped IO region. +/// +/// Accesses to the underlying region is checked either at compile time, if the +/// region's size is known at that point, or at runtime otherwise. +/// +/// # Invariants +/// +/// [`IoMem`] always holds an [`IoRaw`] instance that holds a valid pointer to the +/// start of the I/O memory mapped region. +pub struct IoMem<const SIZE: usize = 0> { + io: IoRaw<SIZE>, +} + +impl<const SIZE: usize> IoMem<SIZE> { + fn ioremap(resource: &Resource) -> Result<Self> { + // Note: Some ioremap() implementations use types that depend on the CPU + // word width rather than the bus address width. + // + // TODO: Properly address this in the C code to avoid this `try_into`. + let size = resource.size().try_into()?; + if size == 0 { + return Err(EINVAL); + } + + let res_start = resource.start(); + + let addr = if resource + .flags() + .contains(io::resource::Flags::IORESOURCE_MEM_NONPOSTED) + { + // SAFETY: + // - `res_start` and `size` are read from a presumably valid `struct resource`. + // - `size` is known not to be zero at this point. + unsafe { bindings::ioremap_np(res_start, size) } + } else { + // SAFETY: + // - `res_start` and `size` are read from a presumably valid `struct resource`. + // - `size` is known not to be zero at this point. + unsafe { bindings::ioremap(res_start, size) } + }; + + if addr.is_null() { + return Err(ENOMEM); + } + + let io = IoRaw::new(addr as usize, size)?; + let io = IoMem { io }; + + Ok(io) + } + + /// Creates a new `IoMem` instance from a previously acquired [`IoRequest`]. + pub fn new<'a>(io_request: IoRequest<'a>) -> impl PinInit<Devres<Self>, Error> + 'a { + let dev = io_request.device; + let res = io_request.resource; + + Devres::new(dev, Self::ioremap(res)) + } +} + +impl<const SIZE: usize> Drop for IoMem<SIZE> { + fn drop(&mut self) { + // SAFETY: Safe as by the invariant of `Io`. + unsafe { bindings::iounmap(self.io.addr() as *mut c_void) } + } +} + +impl<const SIZE: usize> Deref for IoMem<SIZE> { + type Target = Io<SIZE>; + + fn deref(&self) -> &Self::Target { + // SAFETY: Safe as by the invariant of `IoMem`. + unsafe { Io::from_raw(&self.io) } + } +} diff --git a/rust/kernel/io/resource.rs b/rust/kernel/io/resource.rs new file mode 100644 index 000000000000..bea3ee0ed87b --- /dev/null +++ b/rust/kernel/io/resource.rs @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for [system +//! resources](https://docs.kernel.org/core-api/kernel-api.html#resources-management). +//! +//! C header: [`include/linux/ioport.h`](srctree/include/linux/ioport.h) + +use core::ops::Deref; +use core::ptr::NonNull; + +use crate::prelude::*; +use crate::str::{CStr, CString}; +use crate::types::Opaque; + +/// Resource Size type. +/// +/// This is a type alias to either `u32` or `u64` depending on the config option +/// `CONFIG_PHYS_ADDR_T_64BIT`, and it can be a u64 even on 32-bit architectures. +pub type ResourceSize = bindings::phys_addr_t; + +/// A region allocated from a parent [`Resource`]. +/// +/// # Invariants +/// +/// - `self.0` points to a valid `bindings::resource` that was obtained through +/// `bindings::__request_region`. +pub struct Region { + /// The resource returned when the region was requested. + resource: NonNull<bindings::resource>, + /// The name that was passed in when the region was requested. We need to + /// store it for ownership reasons. + _name: CString, +} + +impl Deref for Region { + type Target = Resource; + + fn deref(&self) -> &Self::Target { + // SAFETY: Safe as per the invariant of `Region`. + unsafe { Resource::from_raw(self.resource.as_ptr()) } + } +} + +impl Drop for Region { + fn drop(&mut self) { + let (flags, start, size) = { + let res = &**self; + (res.flags(), res.start(), res.size()) + }; + + let release_fn = if flags.contains(Flags::IORESOURCE_MEM) { + bindings::release_mem_region + } else { + bindings::release_region + }; + + // SAFETY: Safe as per the invariant of `Region`. + unsafe { release_fn(start, size) }; + } +} + +// SAFETY: `Region` only holds a pointer to a C `struct resource`, which is safe to be used from +// any thread. +unsafe impl Send for Region {} + +// SAFETY: `Region` only holds a pointer to a C `struct resource`, references to which are +// safe to be used from any thread. +unsafe impl Sync for Region {} + +/// A resource abstraction. +/// +/// # Invariants +/// +/// [`Resource`] is a transparent wrapper around a valid `bindings::resource`. +#[repr(transparent)] +pub struct Resource(Opaque<bindings::resource>); + +impl Resource { + /// Creates a reference to a [`Resource`] from a valid pointer. + /// + /// # Safety + /// + /// The caller must ensure that for the duration of 'a, the pointer will + /// point at a valid `bindings::resource`. + /// + /// The caller must also ensure that the [`Resource`] is only accessed via the + /// returned reference for the duration of 'a. + pub(crate) const unsafe fn from_raw<'a>(ptr: *mut bindings::resource) -> &'a Self { + // SAFETY: Self is a transparent wrapper around `Opaque<bindings::resource>`. + unsafe { &*ptr.cast() } + } + + /// Requests a resource region. + /// + /// Exclusive access will be given and the region will be marked as busy. + /// Further calls to [`Self::request_region`] will return [`None`] if + /// the region, or a part of it, is already in use. + pub fn request_region( + &self, + start: ResourceSize, + size: ResourceSize, + name: CString, + flags: Flags, + ) -> Option<Region> { + // SAFETY: + // - Safe as per the invariant of `Resource`. + // - `__request_region` will store a reference to the name, but that is + // safe as we own it and it will not be dropped until the `Region` is + // dropped. + let region = unsafe { + bindings::__request_region( + self.0.get(), + start, + size, + name.as_char_ptr(), + flags.0 as c_int, + ) + }; + + Some(Region { + resource: NonNull::new(region)?, + _name: name, + }) + } + + /// Returns the size of the resource. + pub fn size(&self) -> ResourceSize { + let inner = self.0.get(); + // SAFETY: Safe as per the invariants of `Resource`. + unsafe { bindings::resource_size(inner) } + } + + /// Returns the start address of the resource. + pub fn start(&self) -> ResourceSize { + let inner = self.0.get(); + // SAFETY: Safe as per the invariants of `Resource`. + unsafe { (*inner).start } + } + + /// Returns the name of the resource. + pub fn name(&self) -> Option<&CStr> { + let inner = self.0.get(); + + // SAFETY: Safe as per the invariants of `Resource`. + let name = unsafe { (*inner).name }; + + if name.is_null() { + return None; + } + + // SAFETY: In the C code, `resource::name` either contains a null + // pointer or points to a valid NUL-terminated C string, and at this + // point we know it is not null, so we can safely convert it to a + // `CStr`. + Some(unsafe { CStr::from_char_ptr(name) }) + } + + /// Returns the flags associated with the resource. + pub fn flags(&self) -> Flags { + let inner = self.0.get(); + // SAFETY: Safe as per the invariants of `Resource`. + let flags = unsafe { (*inner).flags }; + + Flags(flags) + } +} + +// SAFETY: `Resource` only holds a pointer to a C `struct resource`, which is +// safe to be used from any thread. +unsafe impl Send for Resource {} + +// SAFETY: `Resource` only holds a pointer to a C `struct resource`, references +// to which are safe to be used from any thread. +unsafe impl Sync for Resource {} + +/// Resource flags as stored in the C `struct resource::flags` field. +/// +/// They can be combined with the operators `|`, `&`, and `!`. +/// +/// Values can be used from the associated constants such as +/// [`Flags::IORESOURCE_IO`]. +#[derive(Clone, Copy, PartialEq)] +pub struct Flags(c_ulong); + +impl Flags { + /// Check whether `flags` is contained in `self`. + pub fn contains(self, flags: Flags) -> bool { + (self & flags) == flags + } +} + +impl core::ops::BitOr for Flags { + type Output = Self; + fn bitor(self, rhs: Self) -> Self::Output { + Self(self.0 | rhs.0) + } +} + +impl core::ops::BitAnd for Flags { + type Output = Self; + fn bitand(self, rhs: Self) -> Self::Output { + Self(self.0 & rhs.0) + } +} + +impl core::ops::Not for Flags { + type Output = Self; + fn not(self) -> Self::Output { + Self(!self.0) + } +} + +impl Flags { + /// PCI/ISA I/O ports. + pub const IORESOURCE_IO: Flags = Flags::new(bindings::IORESOURCE_IO); + + /// Resource is software muxed. + pub const IORESOURCE_MUXED: Flags = Flags::new(bindings::IORESOURCE_MUXED); + + /// Resource represents a memory region. + pub const IORESOURCE_MEM: Flags = Flags::new(bindings::IORESOURCE_MEM); + + /// Resource represents a memory region that must be ioremaped using `ioremap_np`. + pub const IORESOURCE_MEM_NONPOSTED: Flags = Flags::new(bindings::IORESOURCE_MEM_NONPOSTED); + + const fn new(value: u32) -> Self { + crate::build_assert!(value as u64 <= c_ulong::MAX as u64); + Flags(value as c_ulong) + } +} diff --git a/rust/kernel/kunit.rs b/rust/kernel/kunit.rs index 824da0e9738a..41efd87595d6 100644 --- a/rust/kernel/kunit.rs +++ b/rust/kernel/kunit.rs @@ -6,7 +6,11 @@ //! //! Reference: <https://docs.kernel.org/dev-tools/kunit/index.html> -use core::{ffi::c_void, fmt}; +use crate::prelude::*; +use core::fmt; + +#[cfg(CONFIG_PRINTK)] +use crate::c_str; /// Prints a KUnit error-level message. /// @@ -18,8 +22,8 @@ pub fn err(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - c"\x013%pA".as_ptr() as _, - &args as *const _ as *const c_void, + c_str!("\x013%pA").as_char_ptr(), + core::ptr::from_ref(&args).cast::<c_void>(), ); } } @@ -34,8 +38,8 @@ pub fn info(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - c"\x016%pA".as_ptr() as _, - &args as *const _ as *const c_void, + c_str!("\x016%pA").as_char_ptr(), + core::ptr::from_ref(&args).cast::<c_void>(), ); } } @@ -57,7 +61,7 @@ macro_rules! kunit_assert { } static FILE: &'static $crate::str::CStr = $crate::c_str!($file); - static LINE: i32 = core::line!() as i32 - $diff; + static LINE: i32 = ::core::line!() as i32 - $diff; static CONDITION: &'static $crate::str::CStr = $crate::c_str!(stringify!($condition)); // SAFETY: FFI call without safety requirements. @@ -128,11 +132,11 @@ macro_rules! kunit_assert { unsafe { $crate::bindings::__kunit_do_failed_assertion( kunit_test, - core::ptr::addr_of!(LOCATION.0), + ::core::ptr::addr_of!(LOCATION.0), $crate::bindings::kunit_assert_type_KUNIT_ASSERTION, - core::ptr::addr_of!(ASSERTION.0.assert), + ::core::ptr::addr_of!(ASSERTION.0.assert), Some($crate::bindings::kunit_unary_assert_format), - core::ptr::null(), + ::core::ptr::null(), ); } @@ -161,3 +165,196 @@ macro_rules! kunit_assert_eq { $crate::kunit_assert!($name, $file, $diff, $left == $right); }}; } + +trait TestResult { + fn is_test_result_ok(&self) -> bool; +} + +impl TestResult for () { + fn is_test_result_ok(&self) -> bool { + true + } +} + +impl<T, E> TestResult for Result<T, E> { + fn is_test_result_ok(&self) -> bool { + self.is_ok() + } +} + +/// Returns whether a test result is to be considered OK. +/// +/// This will be `assert!`ed from the generated tests. +#[doc(hidden)] +#[expect(private_bounds)] +pub fn is_test_result_ok(t: impl TestResult) -> bool { + t.is_test_result_ok() +} + +/// Represents an individual test case. +/// +/// The [`kunit_unsafe_test_suite!`] macro expects a NULL-terminated list of valid test cases. +/// Use [`kunit_case_null`] to generate such a delimiter. +#[doc(hidden)] +pub const fn kunit_case( + name: &'static kernel::str::CStr, + run_case: unsafe extern "C" fn(*mut kernel::bindings::kunit), +) -> kernel::bindings::kunit_case { + kernel::bindings::kunit_case { + run_case: Some(run_case), + name: name.as_char_ptr(), + attr: kernel::bindings::kunit_attributes { + speed: kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL, + }, + generate_params: None, + status: kernel::bindings::kunit_status_KUNIT_SUCCESS, + module_name: core::ptr::null_mut(), + log: core::ptr::null_mut(), + } +} + +/// Represents the NULL test case delimiter. +/// +/// The [`kunit_unsafe_test_suite!`] macro expects a NULL-terminated list of test cases. This +/// function returns such a delimiter. +#[doc(hidden)] +pub const fn kunit_case_null() -> kernel::bindings::kunit_case { + kernel::bindings::kunit_case { + run_case: None, + name: core::ptr::null_mut(), + generate_params: None, + attr: kernel::bindings::kunit_attributes { + speed: kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL, + }, + status: kernel::bindings::kunit_status_KUNIT_SUCCESS, + module_name: core::ptr::null_mut(), + log: core::ptr::null_mut(), + } +} + +/// Registers a KUnit test suite. +/// +/// # Safety +/// +/// `test_cases` must be a NULL terminated array of valid test cases, +/// whose lifetime is at least that of the test suite (i.e., static). +/// +/// # Examples +/// +/// ```ignore +/// extern "C" fn test_fn(_test: *mut kernel::bindings::kunit) { +/// let actual = 1 + 1; +/// let expected = 2; +/// assert_eq!(actual, expected); +/// } +/// +/// static mut KUNIT_TEST_CASES: [kernel::bindings::kunit_case; 2] = [ +/// kernel::kunit::kunit_case(kernel::c_str!("name"), test_fn), +/// kernel::kunit::kunit_case_null(), +/// ]; +/// kernel::kunit_unsafe_test_suite!(suite_name, KUNIT_TEST_CASES); +/// ``` +#[doc(hidden)] +#[macro_export] +macro_rules! kunit_unsafe_test_suite { + ($name:ident, $test_cases:ident) => { + const _: () = { + const KUNIT_TEST_SUITE_NAME: [::kernel::ffi::c_char; 256] = { + let name_u8 = ::core::stringify!($name).as_bytes(); + let mut ret = [0; 256]; + + if name_u8.len() > 255 { + panic!(concat!( + "The test suite name `", + ::core::stringify!($name), + "` exceeds the maximum length of 255 bytes." + )); + } + + let mut i = 0; + while i < name_u8.len() { + ret[i] = name_u8[i] as ::kernel::ffi::c_char; + i += 1; + } + + ret + }; + + static mut KUNIT_TEST_SUITE: ::kernel::bindings::kunit_suite = + ::kernel::bindings::kunit_suite { + name: KUNIT_TEST_SUITE_NAME, + #[allow(unused_unsafe)] + // SAFETY: `$test_cases` is passed in by the user, and + // (as documented) must be valid for the lifetime of + // the suite (i.e., static). + test_cases: unsafe { + ::core::ptr::addr_of_mut!($test_cases) + .cast::<::kernel::bindings::kunit_case>() + }, + suite_init: None, + suite_exit: None, + init: None, + exit: None, + attr: ::kernel::bindings::kunit_attributes { + speed: ::kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL, + }, + status_comment: [0; 256usize], + debugfs: ::core::ptr::null_mut(), + log: ::core::ptr::null_mut(), + suite_init_err: 0, + is_init: false, + }; + + #[used(compiler)] + #[allow(unused_unsafe)] + #[cfg_attr(not(target_os = "macos"), link_section = ".kunit_test_suites")] + static mut KUNIT_TEST_SUITE_ENTRY: *const ::kernel::bindings::kunit_suite = + // SAFETY: `KUNIT_TEST_SUITE` is static. + unsafe { ::core::ptr::addr_of_mut!(KUNIT_TEST_SUITE) }; + }; + }; +} + +/// Returns whether we are currently running a KUnit test. +/// +/// In some cases, you need to call test-only code from outside the test case, for example, to +/// create a function mock. This function allows to change behavior depending on whether we are +/// currently running a KUnit test or not. +/// +/// # Examples +/// +/// This example shows how a function can be mocked to return a well-known value while testing: +/// +/// ``` +/// # use kernel::kunit::in_kunit_test; +/// fn fn_mock_example(n: i32) -> i32 { +/// if in_kunit_test() { +/// return 100; +/// } +/// +/// n + 1 +/// } +/// +/// let mock_res = fn_mock_example(5); +/// assert_eq!(mock_res, 100); +/// ``` +pub fn in_kunit_test() -> bool { + // SAFETY: `kunit_get_current_test()` is always safe to call (it has fallbacks for + // when KUnit is not enabled). + !unsafe { bindings::kunit_get_current_test() }.is_null() +} + +#[kunit_tests(rust_kernel_kunit)] +mod tests { + use super::*; + + #[test] + fn rust_test_kunit_example_test() { + assert_eq!(1 + 1, 2); + } + + #[test] + fn rust_test_kunit_in_kunit_test() { + assert!(in_kunit_test()); + } +} diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index e1065a7551a3..fef97f2a5098 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -6,18 +6,47 @@ //! usage by Rust code in the kernel and is shared by all of them. //! //! In other words, all the rest of the Rust code in the kernel (e.g. kernel -//! modules written in Rust) depends on [`core`], [`alloc`] and this crate. +//! modules written in Rust) depends on [`core`] and this crate. //! //! If you need a kernel C API that is not ported or wrapped yet here, then //! do so first instead of bypassing this crate. #![no_std] -#![feature(arbitrary_self_types)] -#![feature(coerce_unsized)] -#![feature(dispatch_from_dyn)] +// +// Please see https://github.com/Rust-for-Linux/linux/issues/2 for details on +// the unstable features in use. +// +// Stable since Rust 1.79.0. #![feature(inline_const)] +// +// Stable since Rust 1.81.0. #![feature(lint_reasons)] -#![feature(unsize)] +// +// Stable since Rust 1.82.0. +#![feature(raw_ref_op)] +// +// Stable since Rust 1.83.0. +#![feature(const_maybe_uninit_as_mut_ptr)] +#![feature(const_mut_refs)] +#![feature(const_ptr_write)] +#![feature(const_refs_to_cell)] +// +// Expected to become stable. +#![feature(arbitrary_self_types)] +// +// To be determined. +#![feature(used_with_arg)] +// +// `feature(derive_coerce_pointee)` is expected to become stable. Before Rust +// 1.84.0, it did not exist, so enable the predecessor features. +#![cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, feature(derive_coerce_pointee))] +#![cfg_attr(not(CONFIG_RUSTC_HAS_COERCE_POINTEE), feature(coerce_unsized))] +#![cfg_attr(not(CONFIG_RUSTC_HAS_COERCE_POINTEE), feature(dispatch_from_dyn))] +#![cfg_attr(not(CONFIG_RUSTC_HAS_COERCE_POINTEE), feature(unsize))] +// +// `feature(file_with_nul)` is expected to become stable. Before Rust 1.89.0, it did not exist, so +// enable it conditionally. +#![cfg_attr(CONFIG_RUSTC_HAS_FILE_WITH_NUL, feature(file_with_nul))] // Ensure conditional compilation based on the kernel configuration works; // otherwise we may silently break things like initcall handling. @@ -29,30 +58,61 @@ extern crate self as kernel; pub use ffi; +pub mod acpi; pub mod alloc; +#[cfg(CONFIG_AUXILIARY_BUS)] +pub mod auxiliary; +pub mod bits; #[cfg(CONFIG_BLOCK)] pub mod block; -mod build_assert; +pub mod bug; +#[doc(hidden)] +pub mod build_assert; +pub mod clk; +#[cfg(CONFIG_CONFIGFS_FS)] +pub mod configfs; +pub mod cpu; +#[cfg(CONFIG_CPU_FREQ)] +pub mod cpufreq; +pub mod cpumask; pub mod cred; pub mod device; +pub mod device_id; +pub mod devres; +pub mod dma; +pub mod driver; +#[cfg(CONFIG_DRM = "y")] +pub mod drm; pub mod error; +pub mod faux; #[cfg(CONFIG_RUST_FW_LOADER_ABSTRACTIONS)] pub mod firmware; +pub mod fmt; pub mod fs; pub mod init; +pub mod io; pub mod ioctl; pub mod jump_label; #[cfg(CONFIG_KUNIT)] pub mod kunit; pub mod list; pub mod miscdevice; +pub mod mm; #[cfg(CONFIG_NET)] pub mod net; +pub mod of; +#[cfg(CONFIG_PM_OPP)] +pub mod opp; pub mod page; +#[cfg(CONFIG_PCI)] +pub mod pci; pub mod pid_namespace; +pub mod platform; pub mod prelude; pub mod print; pub mod rbtree; +pub mod regulator; +pub mod revocable; pub mod security; pub mod seq_file; pub mod sizes; @@ -68,15 +128,13 @@ pub mod transmute; pub mod types; pub mod uaccess; pub mod workqueue; +pub mod xarray; #[doc(hidden)] pub use bindings; pub use macros; pub use uapi; -#[doc(hidden)] -pub use build_error::build_error; - /// Prefix to appear before log messages printed from within the `kernel` crate. const __LOG_PREFIX: &[u8] = b"rust_kernel\0"; @@ -98,11 +156,11 @@ pub trait InPlaceModule: Sync + Send { /// Creates an initialiser for the module. /// /// It is called when the module is loaded. - fn init(module: &'static ThisModule) -> impl init::PinInit<Self, error::Error>; + fn init(module: &'static ThisModule) -> impl pin_init::PinInit<Self, error::Error>; } impl<T: Module> InPlaceModule for T { - fn init(module: &'static ThisModule) -> impl init::PinInit<Self, error::Error> { + fn init(module: &'static ThisModule) -> impl pin_init::PinInit<Self, error::Error> { let initer = move |slot: *mut Self| { let m = <Self as Module>::init(module)?; @@ -112,10 +170,16 @@ impl<T: Module> InPlaceModule for T { }; // SAFETY: On success, `initer` always fully initialises an instance of `Self`. - unsafe { init::pin_init_from_closure(initer) } + unsafe { pin_init::pin_init_from_closure(initer) } } } +/// Metadata attached to a [`Module`] or [`InPlaceModule`]. +pub trait ModuleMetadata { + /// The name of the module as specified in the `module!` macro. + const NAME: &'static crate::str::CStr; +} + /// Equivalent to `THIS_MODULE` in the C API. /// /// C header: [`include/linux/init.h`](srctree/include/linux/init.h) @@ -152,6 +216,13 @@ fn panic(info: &core::panic::PanicInfo<'_>) -> ! { /// Produces a pointer to an object from a pointer to one of its fields. /// +/// If you encounter a type mismatch due to the [`Opaque`] type, then use [`Opaque::cast_into`] or +/// [`Opaque::cast_from`] to resolve the mismatch. +/// +/// [`Opaque`]: crate::types::Opaque +/// [`Opaque::cast_into`]: crate::types::Opaque::cast_into +/// [`Opaque::cast_from`]: crate::types::Opaque::cast_from +/// /// # Safety /// /// The pointer passed to this macro, and the pointer returned by this macro, must both be in @@ -167,7 +238,7 @@ fn panic(info: &core::panic::PanicInfo<'_>) -> ! { /// } /// /// let test = Test { a: 10, b: 20 }; -/// let b_ptr = &test.b; +/// let b_ptr: *const _ = &test.b; /// // SAFETY: The pointer points at the `b` field of a `Test`, so the resulting pointer will be /// // in-bounds of the same allocation as `b_ptr`. /// let test_alias = unsafe { container_of!(b_ptr, Test, b) }; @@ -175,13 +246,19 @@ fn panic(info: &core::panic::PanicInfo<'_>) -> ! { /// ``` #[macro_export] macro_rules! container_of { - ($ptr:expr, $type:ty, $($f:tt)*) => {{ - let ptr = $ptr as *const _ as *const u8; - let offset: usize = ::core::mem::offset_of!($type, $($f)*); - ptr.sub(offset) as *const $type + ($field_ptr:expr, $Container:ty, $($fields:tt)*) => {{ + let offset: usize = ::core::mem::offset_of!($Container, $($fields)*); + let field_ptr = $field_ptr; + let container_ptr = field_ptr.byte_sub(offset).cast::<$Container>(); + $crate::assert_same_type(field_ptr, (&raw const (*container_ptr).$($fields)*).cast_mut()); + container_ptr }} } +/// Helper for [`container_of!`]. +#[doc(hidden)] +pub fn assert_same_type<T>(_: T, _: T) {} + /// Helper for `.rs.S` files. #[doc(hidden)] #[macro_export] @@ -216,3 +293,52 @@ macro_rules! asm { ::core::arch::asm!( $($asm)*, $($rest)* ) }; } + +/// Gets the C string file name of a [`Location`]. +/// +/// If `Location::file_as_c_str()` is not available, returns a string that warns about it. +/// +/// [`Location`]: core::panic::Location +/// +/// # Examples +/// +/// ``` +/// # use kernel::file_from_location; +/// +/// #[track_caller] +/// fn foo() { +/// let caller = core::panic::Location::caller(); +/// +/// // Output: +/// // - A path like "rust/kernel/example.rs" if `file_as_c_str()` is available. +/// // - "<Location::file_as_c_str() not supported>" otherwise. +/// let caller_file = file_from_location(caller); +/// +/// // Prints out the message with caller's file name. +/// pr_info!("foo() called in file {caller_file:?}\n"); +/// +/// # if cfg!(CONFIG_RUSTC_HAS_FILE_WITH_NUL) { +/// # assert_eq!(Ok(caller.file()), caller_file.to_str()); +/// # } +/// } +/// +/// # foo(); +/// ``` +#[inline] +pub fn file_from_location<'a>(loc: &'a core::panic::Location<'a>) -> &'a core::ffi::CStr { + #[cfg(CONFIG_RUSTC_HAS_FILE_AS_C_STR)] + { + loc.file_as_c_str() + } + + #[cfg(all(CONFIG_RUSTC_HAS_FILE_WITH_NUL, not(CONFIG_RUSTC_HAS_FILE_AS_C_STR)))] + { + loc.file_with_nul() + } + + #[cfg(not(CONFIG_RUSTC_HAS_FILE_WITH_NUL))] + { + let _ = loc; + c"<Location::file_as_c_str() not supported>" + } +} diff --git a/rust/kernel/list.rs b/rust/kernel/list.rs index fb93330f4af4..44e5219cfcbc 100644 --- a/rust/kernel/list.rs +++ b/rust/kernel/list.rs @@ -4,12 +4,12 @@ //! A linked list implementation. -use crate::init::PinInit; use crate::sync::ArcBorrow; use crate::types::Opaque; use core::iter::{DoubleEndedIterator, FusedIterator}; use core::marker::PhantomData; use core::ptr; +use pin_init::PinInit; mod impl_list_item_mod; pub use self::impl_list_item_mod::{ @@ -35,6 +35,111 @@ pub use self::arc_field::{define_list_arc_field_getter, ListArcField}; /// * All prev/next pointers in `ListLinks` fields of items in the list are valid and form a cycle. /// * For every item in the list, the list owns the associated [`ListArc`] reference and has /// exclusive access to the `ListLinks` field. +/// +/// # Examples +/// +/// ``` +/// use kernel::list::*; +/// +/// #[pin_data] +/// struct BasicItem { +/// value: i32, +/// #[pin] +/// links: ListLinks, +/// } +/// +/// impl BasicItem { +/// fn new(value: i32) -> Result<ListArc<Self>> { +/// ListArc::pin_init(try_pin_init!(Self { +/// value, +/// links <- ListLinks::new(), +/// }), GFP_KERNEL) +/// } +/// } +/// +/// impl_list_arc_safe! { +/// impl ListArcSafe<0> for BasicItem { untracked; } +/// } +/// impl_list_item! { +/// impl ListItem<0> for BasicItem { using ListLinks { self.links }; } +/// } +/// +/// // Create a new empty list. +/// let mut list = List::new(); +/// { +/// assert!(list.is_empty()); +/// } +/// +/// // Insert 3 elements using `push_back()`. +/// list.push_back(BasicItem::new(15)?); +/// list.push_back(BasicItem::new(10)?); +/// list.push_back(BasicItem::new(30)?); +/// +/// // Iterate over the list to verify the nodes were inserted correctly. +/// // [15, 10, 30] +/// { +/// let mut iter = list.iter(); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 15); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 10); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 30); +/// assert!(iter.next().is_none()); +/// +/// // Verify the length of the list. +/// assert_eq!(list.iter().count(), 3); +/// } +/// +/// // Pop the items from the list using `pop_back()` and verify the content. +/// { +/// assert_eq!(list.pop_back().ok_or(EINVAL)?.value, 30); +/// assert_eq!(list.pop_back().ok_or(EINVAL)?.value, 10); +/// assert_eq!(list.pop_back().ok_or(EINVAL)?.value, 15); +/// } +/// +/// // Insert 3 elements using `push_front()`. +/// list.push_front(BasicItem::new(15)?); +/// list.push_front(BasicItem::new(10)?); +/// list.push_front(BasicItem::new(30)?); +/// +/// // Iterate over the list to verify the nodes were inserted correctly. +/// // [30, 10, 15] +/// { +/// let mut iter = list.iter(); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 30); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 10); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 15); +/// assert!(iter.next().is_none()); +/// +/// // Verify the length of the list. +/// assert_eq!(list.iter().count(), 3); +/// } +/// +/// // Pop the items from the list using `pop_front()` and verify the content. +/// { +/// assert_eq!(list.pop_front().ok_or(EINVAL)?.value, 30); +/// assert_eq!(list.pop_front().ok_or(EINVAL)?.value, 10); +/// } +/// +/// // Push `list2` to `list` through `push_all_back()`. +/// // list: [15] +/// // list2: [25, 35] +/// { +/// let mut list2 = List::new(); +/// list2.push_back(BasicItem::new(25)?); +/// list2.push_back(BasicItem::new(35)?); +/// +/// list.push_all_back(&mut list2); +/// +/// // list: [15, 25, 35] +/// // list2: [] +/// let mut iter = list.iter(); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 15); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 25); +/// assert_eq!(iter.next().ok_or(EINVAL)?.value, 35); +/// assert!(iter.next().is_none()); +/// assert!(list2.is_empty()); +/// } +/// # Result::<(), Error>::Ok(()) +/// ``` pub struct List<T: ?Sized + ListItem<ID>, const ID: u64 = 0> { first: *mut ListLinksFields, _ty: PhantomData<ListArc<T, ID>>, @@ -176,7 +281,7 @@ impl<const ID: u64> ListLinks<ID> { #[inline] unsafe fn fields(me: *mut Self) -> *mut ListLinksFields { // SAFETY: The caller promises that the pointer is valid. - unsafe { Opaque::raw_get(ptr::addr_of!((*me).inner)) } + unsafe { Opaque::cast_into(ptr::addr_of!((*me).inner)) } } /// # Safety @@ -212,9 +317,6 @@ unsafe impl<T: ?Sized + Send, const ID: u64> Send for ListLinksSelfPtr<T, ID> {} unsafe impl<T: ?Sized + Sync, const ID: u64> Sync for ListLinksSelfPtr<T, ID> {} impl<T: ?Sized, const ID: u64> ListLinksSelfPtr<T, ID> { - /// The offset from the [`ListLinks`] to the self pointer field. - pub const LIST_LINKS_SELF_PTR_OFFSET: usize = core::mem::offset_of!(Self, self_ptr); - /// Creates a new initializer for this type. pub fn new() -> impl PinInit<Self> { // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will @@ -229,6 +331,16 @@ impl<T: ?Sized, const ID: u64> ListLinksSelfPtr<T, ID> { self_ptr: Opaque::uninit(), } } + + /// Returns a pointer to the self pointer. + /// + /// # Safety + /// + /// The provided pointer must point at a valid struct of type `Self`. + pub unsafe fn raw_get_self_ptr(me: *const Self) -> *const Opaque<*const T> { + // SAFETY: The caller promises that the pointer is valid. + unsafe { ptr::addr_of!((*me).self_ptr) } + } } impl<T: ?Sized + ListItem<ID>, const ID: u64> List<T, ID> { @@ -245,8 +357,20 @@ impl<T: ?Sized + ListItem<ID>, const ID: u64> List<T, ID> { self.first.is_null() } - /// Add the provided item to the back of the list. - pub fn push_back(&mut self, item: ListArc<T, ID>) { + /// Inserts `item` before `next` in the cycle. + /// + /// Returns a pointer to the newly inserted element. Never changes `self.first` unless the list + /// is empty. + /// + /// # Safety + /// + /// * `next` must be an element in this list or null. + /// * if `next` is null, then the list must be empty. + unsafe fn insert_inner( + &mut self, + item: ListArc<T, ID>, + next: *mut ListLinksFields, + ) -> *mut ListLinksFields { let raw_item = ListArc::into_raw(item); // SAFETY: // * We just got `raw_item` from a `ListArc`, so it's in an `Arc`. @@ -259,16 +383,16 @@ impl<T: ?Sized + ListItem<ID>, const ID: u64> List<T, ID> { // SAFETY: We have not yet called `post_remove`, so `list_links` is still valid. let item = unsafe { ListLinks::fields(list_links) }; - if self.first.is_null() { - self.first = item; + // Check if the list is empty. + if next.is_null() { // SAFETY: The caller just gave us ownership of these fields. // INVARIANT: A linked list with one item should be cyclic. unsafe { (*item).next = item; (*item).prev = item; } + self.first = item; } else { - let next = self.first; // SAFETY: By the type invariant, this pointer is valid or null. We just checked that // it's not null, so it must be valid. let prev = unsafe { (*next).prev }; @@ -282,50 +406,32 @@ impl<T: ?Sized + ListItem<ID>, const ID: u64> List<T, ID> { (*next).prev = item; } } + + item + } + + /// Add the provided item to the back of the list. + pub fn push_back(&mut self, item: ListArc<T, ID>) { + // SAFETY: + // * `self.first` is null or in the list. + // * `self.first` is only null if the list is empty. + unsafe { self.insert_inner(item, self.first) }; } /// Add the provided item to the front of the list. pub fn push_front(&mut self, item: ListArc<T, ID>) { - let raw_item = ListArc::into_raw(item); // SAFETY: - // * We just got `raw_item` from a `ListArc`, so it's in an `Arc`. - // * If this requirement is violated, then the previous caller of `prepare_to_insert` - // violated the safety requirement that they can't give up ownership of the `ListArc` - // until they call `post_remove`. - // * We own the `ListArc`. - // * Removing items] from this list is always done using `remove_internal_inner`, which - // calls `post_remove` before giving up ownership. - let list_links = unsafe { T::prepare_to_insert(raw_item) }; - // SAFETY: We have not yet called `post_remove`, so `list_links` is still valid. - let item = unsafe { ListLinks::fields(list_links) }; + // * `self.first` is null or in the list. + // * `self.first` is only null if the list is empty. + let new_elem = unsafe { self.insert_inner(item, self.first) }; - if self.first.is_null() { - // SAFETY: The caller just gave us ownership of these fields. - // INVARIANT: A linked list with one item should be cyclic. - unsafe { - (*item).next = item; - (*item).prev = item; - } - } else { - let next = self.first; - // SAFETY: We just checked that `next` is non-null. - let prev = unsafe { (*next).prev }; - // SAFETY: Pointers in a linked list are never dangling, and the caller just gave us - // ownership of the fields on `item`. - // INVARIANT: This correctly inserts `item` between `prev` and `next`. - unsafe { - (*item).next = next; - (*item).prev = prev; - (*prev).next = item; - (*next).prev = item; - } - } - self.first = item; + // INVARIANT: `new_elem` is in the list because we just inserted it. + self.first = new_elem; } /// Removes the last item from this list. pub fn pop_back(&mut self) -> Option<ListArc<T, ID>> { - if self.first.is_null() { + if self.is_empty() { return None; } @@ -337,7 +443,7 @@ impl<T: ?Sized + ListItem<ID>, const ID: u64> List<T, ID> { /// Removes the first item from this list. pub fn pop_front(&mut self) -> Option<ListArc<T, ID>> { - if self.first.is_null() { + if self.is_empty() { return None; } @@ -489,17 +595,21 @@ impl<T: ?Sized + ListItem<ID>, const ID: u64> List<T, ID> { other.first = ptr::null_mut(); } - /// Returns a cursor to the first element of the list. - /// - /// If the list is empty, this returns `None`. - pub fn cursor_front(&mut self) -> Option<Cursor<'_, T, ID>> { - if self.first.is_null() { - None - } else { - Some(Cursor { - current: self.first, - list: self, - }) + /// Returns a cursor that points before the first element of the list. + pub fn cursor_front(&mut self) -> Cursor<'_, T, ID> { + // INVARIANT: `self.first` is in this list. + Cursor { + next: self.first, + list: self, + } + } + + /// Returns a cursor that points after the last element in the list. + pub fn cursor_back(&mut self) -> Cursor<'_, T, ID> { + // INVARIANT: `next` is allowed to be null. + Cursor { + next: core::ptr::null_mut(), + list: self, } } @@ -579,69 +689,355 @@ impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> Iterator for Iter<'a, T, ID> { /// A cursor into a [`List`]. /// +/// A cursor always rests between two elements in the list. This means that a cursor has a previous +/// and next element, but no current element. It also means that it's possible to have a cursor +/// into an empty list. +/// +/// # Examples +/// +/// ``` +/// use kernel::prelude::*; +/// use kernel::list::{List, ListArc, ListLinks}; +/// +/// #[pin_data] +/// struct ListItem { +/// value: u32, +/// #[pin] +/// links: ListLinks, +/// } +/// +/// impl ListItem { +/// fn new(value: u32) -> Result<ListArc<Self>> { +/// ListArc::pin_init(try_pin_init!(Self { +/// value, +/// links <- ListLinks::new(), +/// }), GFP_KERNEL) +/// } +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl ListArcSafe<0> for ListItem { untracked; } +/// } +/// kernel::list::impl_list_item! { +/// impl ListItem<0> for ListItem { using ListLinks { self.links }; } +/// } +/// +/// // Use a cursor to remove the first element with the given value. +/// fn remove_first(list: &mut List<ListItem>, value: u32) -> Option<ListArc<ListItem>> { +/// let mut cursor = list.cursor_front(); +/// while let Some(next) = cursor.peek_next() { +/// if next.value == value { +/// return Some(next.remove()); +/// } +/// cursor.move_next(); +/// } +/// None +/// } +/// +/// // Use a cursor to remove the last element with the given value. +/// fn remove_last(list: &mut List<ListItem>, value: u32) -> Option<ListArc<ListItem>> { +/// let mut cursor = list.cursor_back(); +/// while let Some(prev) = cursor.peek_prev() { +/// if prev.value == value { +/// return Some(prev.remove()); +/// } +/// cursor.move_prev(); +/// } +/// None +/// } +/// +/// // Use a cursor to remove all elements with the given value. The removed elements are moved to +/// // a new list. +/// fn remove_all(list: &mut List<ListItem>, value: u32) -> List<ListItem> { +/// let mut out = List::new(); +/// let mut cursor = list.cursor_front(); +/// while let Some(next) = cursor.peek_next() { +/// if next.value == value { +/// out.push_back(next.remove()); +/// } else { +/// cursor.move_next(); +/// } +/// } +/// out +/// } +/// +/// // Use a cursor to insert a value at a specific index. Returns an error if the index is out of +/// // bounds. +/// fn insert_at(list: &mut List<ListItem>, new: ListArc<ListItem>, idx: usize) -> Result { +/// let mut cursor = list.cursor_front(); +/// for _ in 0..idx { +/// if !cursor.move_next() { +/// return Err(EINVAL); +/// } +/// } +/// cursor.insert_next(new); +/// Ok(()) +/// } +/// +/// // Merge two sorted lists into a single sorted list. +/// fn merge_sorted(list: &mut List<ListItem>, merge: List<ListItem>) { +/// let mut cursor = list.cursor_front(); +/// for to_insert in merge { +/// while let Some(next) = cursor.peek_next() { +/// if to_insert.value < next.value { +/// break; +/// } +/// cursor.move_next(); +/// } +/// cursor.insert_prev(to_insert); +/// } +/// } +/// +/// let mut list = List::new(); +/// list.push_back(ListItem::new(14)?); +/// list.push_back(ListItem::new(12)?); +/// list.push_back(ListItem::new(10)?); +/// list.push_back(ListItem::new(12)?); +/// list.push_back(ListItem::new(15)?); +/// list.push_back(ListItem::new(14)?); +/// assert_eq!(remove_all(&mut list, 12).iter().count(), 2); +/// // [14, 10, 15, 14] +/// assert!(remove_first(&mut list, 14).is_some()); +/// // [10, 15, 14] +/// insert_at(&mut list, ListItem::new(12)?, 2)?; +/// // [10, 15, 12, 14] +/// assert!(remove_last(&mut list, 15).is_some()); +/// // [10, 12, 14] +/// +/// let mut list2 = List::new(); +/// list2.push_back(ListItem::new(11)?); +/// list2.push_back(ListItem::new(13)?); +/// merge_sorted(&mut list, list2); +/// +/// let mut items = list.into_iter(); +/// assert_eq!(items.next().ok_or(EINVAL)?.value, 10); +/// assert_eq!(items.next().ok_or(EINVAL)?.value, 11); +/// assert_eq!(items.next().ok_or(EINVAL)?.value, 12); +/// assert_eq!(items.next().ok_or(EINVAL)?.value, 13); +/// assert_eq!(items.next().ok_or(EINVAL)?.value, 14); +/// assert!(items.next().is_none()); +/// # Result::<(), Error>::Ok(()) +/// ``` +/// /// # Invariants /// -/// The `current` pointer points a value in `list`. +/// The `next` pointer is null or points a value in `list`. pub struct Cursor<'a, T: ?Sized + ListItem<ID>, const ID: u64 = 0> { - current: *mut ListLinksFields, list: &'a mut List<T, ID>, + /// Points at the element after this cursor, or null if the cursor is after the last element. + next: *mut ListLinksFields, } impl<'a, T: ?Sized + ListItem<ID>, const ID: u64> Cursor<'a, T, ID> { - /// Access the current element of this cursor. - pub fn current(&self) -> ArcBorrow<'_, T> { - // SAFETY: The `current` pointer points a value in the list. - let me = unsafe { T::view_value(ListLinks::from_fields(self.current)) }; - // SAFETY: - // * All values in a list are stored in an `Arc`. - // * The value cannot be removed from the list for the duration of the lifetime annotated - // on the returned `ArcBorrow`, because removing it from the list would require mutable - // access to the cursor or the list. However, the `ArcBorrow` holds an immutable borrow - // on the cursor, which in turn holds a mutable borrow on the list, so any such - // mutable access requires first releasing the immutable borrow on the cursor. - // * Values in a list never have a `UniqueArc` reference, because the list has a `ListArc` - // reference, and `UniqueArc` references must be unique. - unsafe { ArcBorrow::from_raw(me) } + /// Returns a pointer to the element before the cursor. + /// + /// Returns null if there is no element before the cursor. + fn prev_ptr(&self) -> *mut ListLinksFields { + let mut next = self.next; + let first = self.list.first; + if next == first { + // We are before the first element. + return core::ptr::null_mut(); + } + + if next.is_null() { + // We are after the last element, so we need a pointer to the last element, which is + // the same as `(*first).prev`. + next = first; + } + + // SAFETY: `next` can't be null, because then `first` must also be null, but in that case + // we would have exited at the `next == first` check. Thus, `next` is an element in the + // list, so we can access its `prev` pointer. + unsafe { (*next).prev } + } + + /// Access the element after this cursor. + pub fn peek_next(&mut self) -> Option<CursorPeek<'_, 'a, T, true, ID>> { + if self.next.is_null() { + return None; + } + + // INVARIANT: + // * We just checked that `self.next` is non-null, so it must be in `self.list`. + // * `ptr` is equal to `self.next`. + Some(CursorPeek { + ptr: self.next, + cursor: self, + }) + } + + /// Access the element before this cursor. + pub fn peek_prev(&mut self) -> Option<CursorPeek<'_, 'a, T, false, ID>> { + let prev = self.prev_ptr(); + + if prev.is_null() { + return None; + } + + // INVARIANT: + // * We just checked that `prev` is non-null, so it must be in `self.list`. + // * `self.prev_ptr()` never returns `self.next`. + Some(CursorPeek { + ptr: prev, + cursor: self, + }) } - /// Move the cursor to the next element. - pub fn next(self) -> Option<Cursor<'a, T, ID>> { - // SAFETY: The `current` field is always in a list. - let next = unsafe { (*self.current).next }; + /// Move the cursor one element forward. + /// + /// If the cursor is after the last element, then this call does nothing. This call returns + /// `true` if the cursor's position was changed. + pub fn move_next(&mut self) -> bool { + if self.next.is_null() { + return false; + } + + // SAFETY: `self.next` is an element in the list and we borrow the list mutably, so we can + // access the `next` field. + let mut next = unsafe { (*self.next).next }; if next == self.list.first { - None - } else { - // INVARIANT: Since `self.current` is in the `list`, its `next` pointer is also in the - // `list`. - Some(Cursor { - current: next, - list: self.list, - }) + next = core::ptr::null_mut(); } + + // INVARIANT: `next` is either null or the next element after an element in the list. + self.next = next; + true } - /// Move the cursor to the previous element. - pub fn prev(self) -> Option<Cursor<'a, T, ID>> { - // SAFETY: The `current` field is always in a list. - let prev = unsafe { (*self.current).prev }; + /// Move the cursor one element backwards. + /// + /// If the cursor is before the first element, then this call does nothing. This call returns + /// `true` if the cursor's position was changed. + pub fn move_prev(&mut self) -> bool { + if self.next == self.list.first { + return false; + } - if self.current == self.list.first { - None + // INVARIANT: `prev_ptr()` always returns a pointer that is null or in the list. + self.next = self.prev_ptr(); + true + } + + /// Inserts an element where the cursor is pointing and get a pointer to the new element. + fn insert_inner(&mut self, item: ListArc<T, ID>) -> *mut ListLinksFields { + let ptr = if self.next.is_null() { + self.list.first } else { - // INVARIANT: Since `self.current` is in the `list`, its `prev` pointer is also in the - // `list`. - Some(Cursor { - current: prev, - list: self.list, - }) + self.next + }; + // SAFETY: + // * `ptr` is an element in the list or null. + // * if `ptr` is null, then `self.list.first` is null so the list is empty. + let item = unsafe { self.list.insert_inner(item, ptr) }; + if self.next == self.list.first { + // INVARIANT: We just inserted `item`, so it's a member of list. + self.list.first = item; } + item + } + + /// Insert an element at this cursor's location. + pub fn insert(mut self, item: ListArc<T, ID>) { + // This is identical to `insert_prev`, but consumes the cursor. This is helpful because it + // reduces confusion when the last operation on the cursor is an insertion; in that case, + // you just want to insert the element at the cursor, and it is confusing that the call + // involves the word prev or next. + self.insert_inner(item); + } + + /// Inserts an element after this cursor. + /// + /// After insertion, the new element will be after the cursor. + pub fn insert_next(&mut self, item: ListArc<T, ID>) { + self.next = self.insert_inner(item); } - /// Remove the current element from the list. + /// Inserts an element before this cursor. + /// + /// After insertion, the new element will be before the cursor. + pub fn insert_prev(&mut self, item: ListArc<T, ID>) { + self.insert_inner(item); + } + + /// Remove the next element from the list. + pub fn remove_next(&mut self) -> Option<ListArc<T, ID>> { + self.peek_next().map(|v| v.remove()) + } + + /// Remove the previous element from the list. + pub fn remove_prev(&mut self) -> Option<ListArc<T, ID>> { + self.peek_prev().map(|v| v.remove()) + } +} + +/// References the element in the list next to the cursor. +/// +/// # Invariants +/// +/// * `ptr` is an element in `self.cursor.list`. +/// * `ISNEXT == (self.ptr == self.cursor.next)`. +pub struct CursorPeek<'a, 'b, T: ?Sized + ListItem<ID>, const ISNEXT: bool, const ID: u64> { + cursor: &'a mut Cursor<'b, T, ID>, + ptr: *mut ListLinksFields, +} + +impl<'a, 'b, T: ?Sized + ListItem<ID>, const ISNEXT: bool, const ID: u64> + CursorPeek<'a, 'b, T, ISNEXT, ID> +{ + /// Remove the element from the list. pub fn remove(self) -> ListArc<T, ID> { - // SAFETY: The `current` pointer always points at a member of the list. - unsafe { self.list.remove_internal(self.current) } + if ISNEXT { + self.cursor.move_next(); + } + + // INVARIANT: `self.ptr` is not equal to `self.cursor.next` due to the above `move_next` + // call. + // SAFETY: By the type invariants of `Self`, `next` is not null, so `next` is an element of + // `self.cursor.list` by the type invariants of `Cursor`. + unsafe { self.cursor.list.remove_internal(self.ptr) } + } + + /// Access this value as an [`ArcBorrow`]. + pub fn arc(&self) -> ArcBorrow<'_, T> { + // SAFETY: `self.ptr` points at an element in `self.cursor.list`. + let me = unsafe { T::view_value(ListLinks::from_fields(self.ptr)) }; + // SAFETY: + // * All values in a list are stored in an `Arc`. + // * The value cannot be removed from the list for the duration of the lifetime annotated + // on the returned `ArcBorrow`, because removing it from the list would require mutable + // access to the `CursorPeek`, the `Cursor` or the `List`. However, the `ArcBorrow` holds + // an immutable borrow on the `CursorPeek`, which in turn holds a mutable borrow on the + // `Cursor`, which in turn holds a mutable borrow on the `List`, so any such mutable + // access requires first releasing the immutable borrow on the `CursorPeek`. + // * Values in a list never have a `UniqueArc` reference, because the list has a `ListArc` + // reference, and `UniqueArc` references must be unique. + unsafe { ArcBorrow::from_raw(me) } + } +} + +impl<'a, 'b, T: ?Sized + ListItem<ID>, const ISNEXT: bool, const ID: u64> core::ops::Deref + for CursorPeek<'a, 'b, T, ISNEXT, ID> +{ + // If you change the `ptr` field to have type `ArcBorrow<'a, T>`, it might seem like you could + // get rid of the `CursorPeek::arc` method and change the deref target to `ArcBorrow<'a, T>`. + // However, that doesn't work because 'a is too long. You could obtain an `ArcBorrow<'a, T>` + // and then call `CursorPeek::remove` without giving up the `ArcBorrow<'a, T>`, which would be + // unsound. + type Target = T; + + fn deref(&self) -> &T { + // SAFETY: `self.ptr` points at an element in `self.cursor.list`. + let me = unsafe { T::view_value(ListLinks::from_fields(self.ptr)) }; + + // SAFETY: The value cannot be removed from the list for the duration of the lifetime + // annotated on the returned `&T`, because removing it from the list would require mutable + // access to the `CursorPeek`, the `Cursor` or the `List`. However, the `&T` holds an + // immutable borrow on the `CursorPeek`, which in turn holds a mutable borrow on the + // `Cursor`, which in turn holds a mutable borrow on the `List`, so any such mutable access + // requires first releasing the immutable borrow on the `CursorPeek`. + unsafe { &*me } } } diff --git a/rust/kernel/list/arc.rs b/rust/kernel/list/arc.rs index 3483d8c232c4..d92bcf665c89 100644 --- a/rust/kernel/list/arc.rs +++ b/rust/kernel/list/arc.rs @@ -7,7 +7,7 @@ use crate::alloc::{AllocError, Flags}; use crate::prelude::*; use crate::sync::{Arc, ArcBorrow, UniqueArc}; -use core::marker::{PhantomPinned, Unsize}; +use core::marker::PhantomPinned; use core::ops::Deref; use core::pin::Pin; use core::sync::atomic::{AtomicBool, Ordering}; @@ -74,7 +74,7 @@ pub unsafe trait TryNewListArc<const ID: u64 = 0>: ListArcSafe<ID> { /// /// * The `untracked` strategy does not actually keep track of whether a [`ListArc`] exists. When /// using this strategy, the only way to create a [`ListArc`] is using a [`UniqueArc`]. -/// * The `tracked_by` strategy defers the tracking to a field of the struct. The user much specify +/// * The `tracked_by` strategy defers the tracking to a field of the struct. The user must specify /// which field to defer the tracking to. The field must implement [`ListArcSafe`]. If the field /// implements [`TryNewListArc`], then the type will also implement [`TryNewListArc`]. /// @@ -96,7 +96,7 @@ macro_rules! impl_list_arc_safe { } $($rest:tt)*) => { impl$(<$($generics)*>)? $crate::list::ListArcSafe<$num> for $t { unsafe fn on_create_list_arc_from_unique(self: ::core::pin::Pin<&mut Self>) { - $crate::assert_pinned!($t, $field, $fty, inline); + ::pin_init::assert_pinned!($t, $field, $fty, inline); // SAFETY: This field is structurally pinned as per the above assertion. let field = unsafe { @@ -159,6 +159,7 @@ pub use impl_list_arc_safe; /// /// [`List`]: crate::list::List #[repr(transparent)] +#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, derive(core::marker::CoercePointee))] pub struct ListArc<T, const ID: u64 = 0> where T: ListArcSafe<ID> + ?Sized, @@ -443,25 +444,27 @@ where // This is to allow coercion from `ListArc<T>` to `ListArc<U>` if `T` can be converted to the // dynamically-sized type (DST) `U`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] impl<T, U, const ID: u64> core::ops::CoerceUnsized<ListArc<U, ID>> for ListArc<T, ID> where - T: ListArcSafe<ID> + Unsize<U> + ?Sized, + T: ListArcSafe<ID> + core::marker::Unsize<U> + ?Sized, U: ListArcSafe<ID> + ?Sized, { } // This is to allow `ListArc<U>` to be dispatched on when `ListArc<T>` can be coerced into // `ListArc<U>`. +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] impl<T, U, const ID: u64> core::ops::DispatchFromDyn<ListArc<U, ID>> for ListArc<T, ID> where - T: ListArcSafe<ID> + Unsize<U> + ?Sized, + T: ListArcSafe<ID> + core::marker::Unsize<U> + ?Sized, U: ListArcSafe<ID> + ?Sized, { } /// A utility for tracking whether a [`ListArc`] exists using an atomic. /// -/// # Invariant +/// # Invariants /// /// If the boolean is `false`, then there is no [`ListArc`] for this value. #[repr(transparent)] diff --git a/rust/kernel/list/impl_list_item_mod.rs b/rust/kernel/list/impl_list_item_mod.rs index a0438537cee1..202bc6f97c13 100644 --- a/rust/kernel/list/impl_list_item_mod.rs +++ b/rust/kernel/list/impl_list_item_mod.rs @@ -4,60 +4,48 @@ //! Helpers for implementing list traits safely. -use crate::list::ListLinks; - -/// Declares that this type has a `ListLinks<ID>` field at a fixed offset. +/// Declares that this type has a [`ListLinks<ID>`] field. /// -/// This trait is only used to help implement `ListItem` safely. If `ListItem` is implemented +/// This trait is only used to help implement [`ListItem`] safely. If [`ListItem`] is implemented /// manually, then this trait is not needed. Use the [`impl_has_list_links!`] macro to implement /// this trait. /// /// # Safety /// -/// All values of this type must have a `ListLinks<ID>` field at the given offset. +/// The methods on this trait must have exactly the behavior that the definitions given below have. /// -/// The behavior of `raw_get_list_links` must not be changed. +/// [`ListLinks<ID>`]: crate::list::ListLinks +/// [`ListItem`]: crate::list::ListItem pub unsafe trait HasListLinks<const ID: u64 = 0> { - /// The offset of the `ListLinks` field. - const OFFSET: usize; - - /// Returns a pointer to the [`ListLinks<T, ID>`] field. + /// Returns a pointer to the [`ListLinks<ID>`] field. /// /// # Safety /// /// The provided pointer must point at a valid struct of type `Self`. /// - /// [`ListLinks<T, ID>`]: ListLinks - // We don't really need this method, but it's necessary for the implementation of - // `impl_has_list_links!` to be correct. - #[inline] - unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut ListLinks<ID> { - // SAFETY: The caller promises that the pointer is valid. The implementer promises that the - // `OFFSET` constant is correct. - unsafe { (ptr as *mut u8).add(Self::OFFSET) as *mut ListLinks<ID> } - } + /// [`ListLinks<ID>`]: crate::list::ListLinks + unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut crate::list::ListLinks<ID>; } /// Implements the [`HasListLinks`] trait for the given type. #[macro_export] macro_rules! impl_has_list_links { - ($(impl$(<$($implarg:ident),*>)? + ($(impl$({$($generics:tt)*})? HasListLinks$(<$id:tt>)? - for $self:ident $(<$($selfarg:ty),*>)? + for $self:ty { self$(.$field:ident)* } )*) => {$( // SAFETY: The implementation of `raw_get_list_links` only compiles if the field has the // right type. - // - // The behavior of `raw_get_list_links` is not changed since the `addr_of_mut!` macro is - // equivalent to the pointer offset operation in the trait definition. - unsafe impl$(<$($implarg),*>)? $crate::list::HasListLinks$(<$id>)? for - $self $(<$($selfarg),*>)? - { - const OFFSET: usize = ::core::mem::offset_of!(Self, $($field).*) as usize; - + unsafe impl$(<$($generics)*>)? $crate::list::HasListLinks$(<$id>)? for $self { #[inline] unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut $crate::list::ListLinks$(<$id>)? { + // Statically ensure that `$(.field)*` doesn't follow any pointers. + // + // Cannot be `const` because `$self` may contain generics and E0401 says constants + // "can't use {`Self`,generic parameters} from outer item". + if false { let _: usize = ::core::mem::offset_of!(Self, $($field).*); } + // SAFETY: The caller promises that the pointer is not dangling. We know that this // expression doesn't follow any pointers, as the `offset_of!` invocation above // would otherwise not compile. @@ -68,12 +56,16 @@ macro_rules! impl_has_list_links { } pub use impl_has_list_links; -/// Declares that the `ListLinks<ID>` field in this struct is inside a `ListLinksSelfPtr<T, ID>`. +/// Declares that the [`ListLinks<ID>`] field in this struct is inside a +/// [`ListLinksSelfPtr<T, ID>`]. /// /// # Safety /// -/// The `ListLinks<ID>` field of this struct at the offset `HasListLinks<ID>::OFFSET` must be -/// inside a `ListLinksSelfPtr<T, ID>`. +/// The [`ListLinks<ID>`] field of this struct at [`HasListLinks<ID>::raw_get_list_links`] must be +/// inside a [`ListLinksSelfPtr<T, ID>`]. +/// +/// [`ListLinks<ID>`]: crate::list::ListLinks +/// [`ListLinksSelfPtr<T, ID>`]: crate::list::ListLinksSelfPtr pub unsafe trait HasSelfPtr<T: ?Sized, const ID: u64 = 0> where Self: HasListLinks<ID>, @@ -83,27 +75,21 @@ where /// Implements the [`HasListLinks`] and [`HasSelfPtr`] traits for the given type. #[macro_export] macro_rules! impl_has_list_links_self_ptr { - ($(impl$({$($implarg:tt)*})? + ($(impl$({$($generics:tt)*})? HasSelfPtr<$item_type:ty $(, $id:tt)?> - for $self:ident $(<$($selfarg:ty),*>)? - { self.$field:ident } + for $self:ty + { self$(.$field:ident)* } )*) => {$( // SAFETY: The implementation of `raw_get_list_links` only compiles if the field has the // right type. - unsafe impl$(<$($implarg)*>)? $crate::list::HasSelfPtr<$item_type $(, $id)?> for - $self $(<$($selfarg),*>)? - {} - - unsafe impl$(<$($implarg)*>)? $crate::list::HasListLinks$(<$id>)? for - $self $(<$($selfarg),*>)? - { - const OFFSET: usize = ::core::mem::offset_of!(Self, $field) as usize; + unsafe impl$(<$($generics)*>)? $crate::list::HasSelfPtr<$item_type $(, $id)?> for $self {} + unsafe impl$(<$($generics)*>)? $crate::list::HasListLinks$(<$id>)? for $self { #[inline] unsafe fn raw_get_list_links(ptr: *mut Self) -> *mut $crate::list::ListLinks$(<$id>)? { // SAFETY: The caller promises that the pointer is not dangling. let ptr: *mut $crate::list::ListLinksSelfPtr<$item_type $(, $id)?> = - unsafe { ::core::ptr::addr_of_mut!((*ptr).$field) }; + unsafe { ::core::ptr::addr_of_mut!((*ptr)$(.$field)*) }; ptr.cast() } } @@ -117,15 +103,95 @@ pub use impl_has_list_links_self_ptr; /// implement that trait. /// /// [`ListItem`]: crate::list::ListItem +/// +/// # Examples +/// +/// ``` +/// #[pin_data] +/// struct SimpleListItem { +/// value: u32, +/// #[pin] +/// links: kernel::list::ListLinks, +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl ListArcSafe<0> for SimpleListItem { untracked; } +/// } +/// +/// kernel::list::impl_list_item! { +/// impl ListItem<0> for SimpleListItem { using ListLinks { self.links }; } +/// } +/// +/// struct ListLinksHolder { +/// inner: kernel::list::ListLinks, +/// } +/// +/// #[pin_data] +/// struct ComplexListItem<T, U> { +/// value: Result<T, U>, +/// #[pin] +/// links: ListLinksHolder, +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl{T, U} ListArcSafe<0> for ComplexListItem<T, U> { untracked; } +/// } +/// +/// kernel::list::impl_list_item! { +/// impl{T, U} ListItem<0> for ComplexListItem<T, U> { using ListLinks { self.links.inner }; } +/// } +/// ``` +/// +/// ``` +/// #[pin_data] +/// struct SimpleListItem { +/// value: u32, +/// #[pin] +/// links: kernel::list::ListLinksSelfPtr<SimpleListItem>, +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl ListArcSafe<0> for SimpleListItem { untracked; } +/// } +/// +/// kernel::list::impl_list_item! { +/// impl ListItem<0> for SimpleListItem { using ListLinksSelfPtr { self.links }; } +/// } +/// +/// struct ListLinksSelfPtrHolder<T, U> { +/// inner: kernel::list::ListLinksSelfPtr<ComplexListItem<T, U>>, +/// } +/// +/// #[pin_data] +/// struct ComplexListItem<T, U> { +/// value: Result<T, U>, +/// #[pin] +/// links: ListLinksSelfPtrHolder<T, U>, +/// } +/// +/// kernel::list::impl_list_arc_safe! { +/// impl{T, U} ListArcSafe<0> for ComplexListItem<T, U> { untracked; } +/// } +/// +/// kernel::list::impl_list_item! { +/// impl{T, U} ListItem<0> for ComplexListItem<T, U> { +/// using ListLinksSelfPtr { self.links.inner }; +/// } +/// } +/// ``` #[macro_export] macro_rules! impl_list_item { ( - $(impl$({$($generics:tt)*})? ListItem<$num:tt> for $t:ty { - using ListLinks; + $(impl$({$($generics:tt)*})? ListItem<$num:tt> for $self:ty { + using ListLinks { self$(.$field:ident)* }; })* ) => {$( + $crate::list::impl_has_list_links! { + impl$({$($generics)*})? HasListLinks<$num> for $self { self$(.$field)* } + } + // SAFETY: See GUARANTEES comment on each method. - unsafe impl$(<$($generics)*>)? $crate::list::ListItem<$num> for $t { + unsafe impl$(<$($generics)*>)? $crate::list::ListItem<$num> for $self { // GUARANTEES: // * This returns the same pointer as `prepare_to_insert` because `prepare_to_insert` // is implemented in terms of `view_links`. @@ -139,20 +205,19 @@ macro_rules! impl_list_item { } // GUARANTEES: - // * `me` originates from the most recent call to `prepare_to_insert`, which just added - // `offset` to the pointer passed to `prepare_to_insert`. This method subtracts - // `offset` from `me` so it returns the pointer originally passed to - // `prepare_to_insert`. + // * `me` originates from the most recent call to `prepare_to_insert`, which calls + // `raw_get_list_link`, which is implemented using `addr_of_mut!((*self)$(.$field)*)`. + // This method uses `container_of` to perform the inverse operation, so it returns the + // pointer originally passed to `prepare_to_insert`. // * The pointer remains valid until the next call to `post_remove` because the caller // of the most recent call to `prepare_to_insert` promised to retain ownership of the // `ListArc` containing `Self` until the next call to `post_remove`. The value cannot // be destroyed while a `ListArc` reference exists. unsafe fn view_value(me: *mut $crate::list::ListLinks<$num>) -> *const Self { - let offset = <Self as $crate::list::HasListLinks<$num>>::OFFSET; // SAFETY: `me` originates from the most recent call to `prepare_to_insert`, so it - // points at the field at offset `offset` in a value of type `Self`. Thus, - // subtracting `offset` from `me` is still in-bounds of the allocation. - unsafe { (me as *const u8).sub(offset) as *const Self } + // points at the field `$field` in a value of type `Self`. Thus, reversing that + // operation is still in-bounds of the allocation. + $crate::container_of!(me, Self, $($field).*) } // GUARANTEES: @@ -169,27 +234,30 @@ macro_rules! impl_list_item { } // GUARANTEES: - // * `me` originates from the most recent call to `prepare_to_insert`, which just added - // `offset` to the pointer passed to `prepare_to_insert`. This method subtracts - // `offset` from `me` so it returns the pointer originally passed to - // `prepare_to_insert`. + // * `me` originates from the most recent call to `prepare_to_insert`, which calls + // `raw_get_list_link`, which is implemented using `addr_of_mut!((*self)$(.$field)*)`. + // This method uses `container_of` to perform the inverse operation, so it returns the + // pointer originally passed to `prepare_to_insert`. unsafe fn post_remove(me: *mut $crate::list::ListLinks<$num>) -> *const Self { - let offset = <Self as $crate::list::HasListLinks<$num>>::OFFSET; // SAFETY: `me` originates from the most recent call to `prepare_to_insert`, so it - // points at the field at offset `offset` in a value of type `Self`. Thus, - // subtracting `offset` from `me` is still in-bounds of the allocation. - unsafe { (me as *const u8).sub(offset) as *const Self } + // points at the field `$field` in a value of type `Self`. Thus, reversing that + // operation is still in-bounds of the allocation. + $crate::container_of!(me, Self, $($field).*) } } )*}; ( - $(impl$({$($generics:tt)*})? ListItem<$num:tt> for $t:ty { - using ListLinksSelfPtr; + $(impl$({$($generics:tt)*})? ListItem<$num:tt> for $self:ty { + using ListLinksSelfPtr { self$(.$field:ident)* }; })* ) => {$( + $crate::list::impl_has_list_links_self_ptr! { + impl$({$($generics)*})? HasSelfPtr<$self> for $self { self$(.$field)* } + } + // SAFETY: See GUARANTEES comment on each method. - unsafe impl$(<$($generics)*>)? $crate::list::ListItem<$num> for $t { + unsafe impl$(<$($generics)*>)? $crate::list::ListItem<$num> for $self { // GUARANTEES: // This implementation of `ListItem` will not give out exclusive access to the same // `ListLinks` several times because calls to `prepare_to_insert` and `post_remove` @@ -202,14 +270,16 @@ macro_rules! impl_list_item { // SAFETY: The caller promises that `me` points at a valid value of type `Self`. let links_field = unsafe { <Self as $crate::list::ListItem<$num>>::view_links(me) }; - let spoff = $crate::list::ListLinksSelfPtr::<Self, $num>::LIST_LINKS_SELF_PTR_OFFSET; - // Goes via the offset as the field is private. - // - // SAFETY: The constant is equal to `offset_of!(ListLinksSelfPtr, self_ptr)`, so - // the pointer stays in bounds of the allocation. - let self_ptr = unsafe { (links_field as *const u8).add(spoff) } - as *const $crate::types::Opaque<*const Self>; - let cell_inner = $crate::types::Opaque::raw_get(self_ptr); + let container = $crate::container_of!( + links_field, $crate::list::ListLinksSelfPtr<Self, $num>, inner + ); + + // SAFETY: By the same reasoning above, `links_field` is a valid pointer. + let self_ptr = unsafe { + $crate::list::ListLinksSelfPtr::raw_get_self_ptr(container) + }; + + let cell_inner = $crate::types::Opaque::cast_into(self_ptr); // SAFETY: This value is not accessed in any other places than `prepare_to_insert`, // `post_remove`, or `view_value`. By the safety requirements of those methods, @@ -228,7 +298,9 @@ macro_rules! impl_list_item { // this value is not in a list. unsafe fn view_links(me: *const Self) -> *mut $crate::list::ListLinks<$num> { // SAFETY: The caller promises that `me` points at a valid value of type `Self`. - unsafe { <Self as HasListLinks<$num>>::raw_get_list_links(me.cast_mut()) } + unsafe { + <Self as $crate::list::HasListLinks<$num>>::raw_get_list_links(me.cast_mut()) + } } // This function is also used as the implementation of `post_remove`, so the caller @@ -247,12 +319,17 @@ macro_rules! impl_list_item { // `ListArc` containing `Self` until the next call to `post_remove`. The value cannot // be destroyed while a `ListArc` reference exists. unsafe fn view_value(links_field: *mut $crate::list::ListLinks<$num>) -> *const Self { - let spoff = $crate::list::ListLinksSelfPtr::<Self, $num>::LIST_LINKS_SELF_PTR_OFFSET; - // SAFETY: The constant is equal to `offset_of!(ListLinksSelfPtr, self_ptr)`, so - // the pointer stays in bounds of the allocation. - let self_ptr = unsafe { (links_field as *const u8).add(spoff) } - as *const ::core::cell::UnsafeCell<*const Self>; - let cell_inner = ::core::cell::UnsafeCell::raw_get(self_ptr); + let container = $crate::container_of!( + links_field, $crate::list::ListLinksSelfPtr<Self, $num>, inner + ); + + // SAFETY: By the same reasoning above, `links_field` is a valid pointer. + let self_ptr = unsafe { + $crate::list::ListLinksSelfPtr::raw_get_self_ptr(container) + }; + + let cell_inner = $crate::types::Opaque::cast_into(self_ptr); + // SAFETY: This is not a data race, because the only function that writes to this // value is `prepare_to_insert`, but by the safety requirements the // `prepare_to_insert` method may not be called in parallel with `view_value` or diff --git a/rust/kernel/miscdevice.rs b/rust/kernel/miscdevice.rs index 8f88891fb1d2..6373fe183b27 100644 --- a/rust/kernel/miscdevice.rs +++ b/rust/kernel/miscdevice.rs @@ -10,10 +10,13 @@ use crate::{ bindings, + device::Device, error::{to_result, Error, Result, VTABLE_DEFAULT_ERROR}, ffi::{c_int, c_long, c_uint, c_ulong}, + fs::File, + mm::virt::VmaNew, prelude::*, - str::CStr, + seq_file::SeqFile, types::{ForeignOwnable, Opaque}, }; use core::{marker::PhantomData, mem::MaybeUninit, pin::Pin}; @@ -30,9 +33,9 @@ impl MiscDeviceOptions { pub const fn into_raw<T: MiscDevice>(self) -> bindings::miscdevice { // SAFETY: All zeros is valid for this C type. let mut result: bindings::miscdevice = unsafe { MaybeUninit::zeroed().assume_init() }; - result.minor = bindings::MISC_DYNAMIC_MINOR as _; + result.minor = bindings::MISC_DYNAMIC_MINOR as ffi::c_int; result.name = self.name.as_char_ptr(); - result.fops = create_vtable::<T>(); + result.fops = MiscdeviceVTable::<T>::build(); result } } @@ -41,7 +44,13 @@ impl MiscDeviceOptions { /// /// # Invariants /// -/// `inner` is a registered misc device. +/// - `inner` contains a `struct miscdevice` that is registered using +/// `misc_register()`. +/// - This registration remains valid for the entire lifetime of the +/// [`MiscDeviceRegistration`] instance. +/// - Deregistration occurs exactly once in [`Drop`] via `misc_deregister()`. +/// - `inner` wraps a valid, pinned `miscdevice` created using +/// [`MiscDeviceOptions::into_raw`]. #[repr(transparent)] #[pin_data(PinnedDrop)] pub struct MiscDeviceRegistration<T> { @@ -80,6 +89,16 @@ impl<T: MiscDevice> MiscDeviceRegistration<T> { pub fn as_raw(&self) -> *mut bindings::miscdevice { self.inner.get() } + + /// Access the `this_device` field. + pub fn device(&self) -> &Device { + // SAFETY: This can only be called after a successful register(), which always + // initialises `this_device` with a valid device. Furthermore, the signature of this + // function tells the borrow-checker that the `&Device` reference must not outlive the + // `&MiscDeviceRegistration<T>` used to obtain it, so the last use of the reference must be + // before the underlying `struct miscdevice` is destroyed. + unsafe { Device::from_raw((*self.as_raw()).this_device) } + } } #[pinned_drop] @@ -92,31 +111,48 @@ impl<T> PinnedDrop for MiscDeviceRegistration<T> { /// Trait implemented by the private data of an open misc device. #[vtable] -pub trait MiscDevice { +pub trait MiscDevice: Sized { /// What kind of pointer should `Self` be wrapped in. type Ptr: ForeignOwnable + Send + Sync; /// Called when the misc device is opened. /// /// The returned pointer will be stored as the private data for the file. - fn open() -> Result<Self::Ptr>; + fn open(_file: &File, _misc: &MiscDeviceRegistration<Self>) -> Result<Self::Ptr>; /// Called when the misc device is released. - fn release(device: Self::Ptr) { + fn release(device: Self::Ptr, _file: &File) { drop(device); } + /// Handle for mmap. + /// + /// This function is invoked when a user space process invokes the `mmap` system call on + /// `file`. The function is a callback that is part of the VMA initializer. The kernel will do + /// initial setup of the VMA before calling this function. The function can then interact with + /// the VMA initialization by calling methods of `vma`. If the function does not return an + /// error, the kernel will complete initialization of the VMA according to the properties of + /// `vma`. + fn mmap( + _device: <Self::Ptr as ForeignOwnable>::Borrowed<'_>, + _file: &File, + _vma: &VmaNew, + ) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + /// Handler for ioctls. /// - /// The `cmd` argument is usually manipulated using the utilties in [`kernel::ioctl`]. + /// The `cmd` argument is usually manipulated using the utilities in [`kernel::ioctl`]. /// /// [`kernel::ioctl`]: mod@crate::ioctl fn ioctl( _device: <Self::Ptr as ForeignOwnable>::Borrowed<'_>, + _file: &File, _cmd: u32, _arg: usize, ) -> Result<isize> { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Handler for ioctls. @@ -129,124 +165,205 @@ pub trait MiscDevice { #[cfg(CONFIG_COMPAT)] fn compat_ioctl( _device: <Self::Ptr as ForeignOwnable>::Borrowed<'_>, + _file: &File, _cmd: u32, _arg: usize, ) -> Result<isize> { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// Show info for this fd. + fn show_fdinfo( + _device: <Self::Ptr as ForeignOwnable>::Borrowed<'_>, + _m: &SeqFile, + _file: &File, + ) { + build_error!(VTABLE_DEFAULT_ERROR) } } -const fn create_vtable<T: MiscDevice>() -> &'static bindings::file_operations { - const fn maybe_fn<T: Copy>(check: bool, func: T) -> Option<T> { - if check { - Some(func) - } else { - None +/// A vtable for the file operations of a Rust miscdevice. +struct MiscdeviceVTable<T: MiscDevice>(PhantomData<T>); + +impl<T: MiscDevice> MiscdeviceVTable<T> { + /// # Safety + /// + /// `file` and `inode` must be the file and inode for a file that is undergoing initialization. + /// The file must be associated with a `MiscDeviceRegistration<T>`. + unsafe extern "C" fn open(inode: *mut bindings::inode, raw_file: *mut bindings::file) -> c_int { + // SAFETY: The pointers are valid and for a file being opened. + let ret = unsafe { bindings::generic_file_open(inode, raw_file) }; + if ret != 0 { + return ret; } - } - struct VtableHelper<T: MiscDevice> { - _t: PhantomData<T>, - } - impl<T: MiscDevice> VtableHelper<T> { - const VTABLE: bindings::file_operations = bindings::file_operations { - open: Some(fops_open::<T>), - release: Some(fops_release::<T>), - unlocked_ioctl: maybe_fn(T::HAS_IOCTL, fops_ioctl::<T>), - #[cfg(CONFIG_COMPAT)] - compat_ioctl: if T::HAS_COMPAT_IOCTL { - Some(fops_compat_ioctl::<T>) - } else if T::HAS_IOCTL { - Some(bindings::compat_ptr_ioctl) - } else { - None - }, - // SAFETY: All zeros is a valid value for `bindings::file_operations`. - ..unsafe { MaybeUninit::zeroed().assume_init() } + // SAFETY: The open call of a file can access the private data. + let misc_ptr = unsafe { (*raw_file).private_data }; + + // SAFETY: This is a miscdevice, so `misc_open()` set the private data to a pointer to the + // associated `struct miscdevice` before calling into this method. Furthermore, + // `misc_open()` ensures that the miscdevice can't be unregistered and freed during this + // call to `fops_open`. + let misc = unsafe { &*misc_ptr.cast::<MiscDeviceRegistration<T>>() }; + + // SAFETY: + // * This underlying file is valid for (much longer than) the duration of `T::open`. + // * There is no active fdget_pos region on the file on this thread. + let file = unsafe { File::from_raw_file(raw_file) }; + + let ptr = match T::open(file, misc) { + Ok(ptr) => ptr, + Err(err) => return err.to_errno(), }; + + // This overwrites the private data with the value specified by the user, changing the type + // of this file's private data. All future accesses to the private data is performed by + // other fops_* methods in this file, which all correctly cast the private data to the new + // type. + // + // SAFETY: The open call of a file can access the private data. + unsafe { (*raw_file).private_data = ptr.into_foreign() }; + + 0 } - &VtableHelper::<T>::VTABLE -} + /// # Safety + /// + /// `file` and `inode` must be the file and inode for a file that is being released. The file + /// must be associated with a `MiscDeviceRegistration<T>`. + unsafe extern "C" fn release(_inode: *mut bindings::inode, file: *mut bindings::file) -> c_int { + // SAFETY: The release call of a file owns the private data. + let private = unsafe { (*file).private_data }; + // SAFETY: The release call of a file owns the private data. + let ptr = unsafe { <T::Ptr as ForeignOwnable>::from_foreign(private) }; -/// # Safety -/// -/// `file` and `inode` must be the file and inode for a file that is undergoing initialization. -/// The file must be associated with a `MiscDeviceRegistration<T>`. -unsafe extern "C" fn fops_open<T: MiscDevice>( - inode: *mut bindings::inode, - file: *mut bindings::file, -) -> c_int { - // SAFETY: The pointers are valid and for a file being opened. - let ret = unsafe { bindings::generic_file_open(inode, file) }; - if ret != 0 { - return ret; - } - - let ptr = match T::open() { - Ok(ptr) => ptr, - Err(err) => return err.to_errno(), - }; + // SAFETY: + // * The file is valid for the duration of this call. + // * There is no active fdget_pos region on the file on this thread. + T::release(ptr, unsafe { File::from_raw_file(file) }); - // SAFETY: The open call of a file owns the private data. - unsafe { (*file).private_data = ptr.into_foreign().cast_mut() }; + 0 + } - 0 -} + /// # Safety + /// + /// `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`. + /// `vma` must be a vma that is currently being mmap'ed with this file. + unsafe extern "C" fn mmap( + file: *mut bindings::file, + vma: *mut bindings::vm_area_struct, + ) -> c_int { + // SAFETY: The mmap call of a file can access the private data. + let private = unsafe { (*file).private_data }; + // SAFETY: This is a Rust Miscdevice, so we call `into_foreign` in `open` and + // `from_foreign` in `release`, and `fops_mmap` is guaranteed to be called between those + // two operations. + let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private.cast()) }; + // SAFETY: The caller provides a vma that is undergoing initial VMA setup. + let area = unsafe { VmaNew::from_raw(vma) }; + // SAFETY: + // * The file is valid for the duration of this call. + // * There is no active fdget_pos region on the file on this thread. + let file = unsafe { File::from_raw_file(file) }; -/// # Safety -/// -/// `file` and `inode` must be the file and inode for a file that is being released. The file must -/// be associated with a `MiscDeviceRegistration<T>`. -unsafe extern "C" fn fops_release<T: MiscDevice>( - _inode: *mut bindings::inode, - file: *mut bindings::file, -) -> c_int { - // SAFETY: The release call of a file owns the private data. - let private = unsafe { (*file).private_data }; - // SAFETY: The release call of a file owns the private data. - let ptr = unsafe { <T::Ptr as ForeignOwnable>::from_foreign(private) }; - - T::release(ptr); - - 0 -} + match T::mmap(device, file, area) { + Ok(()) => 0, + Err(err) => err.to_errno(), + } + } -/// # Safety -/// -/// `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`. -unsafe extern "C" fn fops_ioctl<T: MiscDevice>( - file: *mut bindings::file, - cmd: c_uint, - arg: c_ulong, -) -> c_long { - // SAFETY: The ioctl call of a file can access the private data. - let private = unsafe { (*file).private_data }; - // SAFETY: Ioctl calls can borrow the private data of the file. - let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private) }; - - match T::ioctl(device, cmd, arg) { - Ok(ret) => ret as c_long, - Err(err) => err.to_errno() as c_long, + /// # Safety + /// + /// `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`. + unsafe extern "C" fn ioctl(file: *mut bindings::file, cmd: c_uint, arg: c_ulong) -> c_long { + // SAFETY: The ioctl call of a file can access the private data. + let private = unsafe { (*file).private_data }; + // SAFETY: Ioctl calls can borrow the private data of the file. + let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private) }; + + // SAFETY: + // * The file is valid for the duration of this call. + // * There is no active fdget_pos region on the file on this thread. + let file = unsafe { File::from_raw_file(file) }; + + match T::ioctl(device, file, cmd, arg) { + Ok(ret) => ret as c_long, + Err(err) => err.to_errno() as c_long, + } } -} -/// # Safety -/// -/// `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`. -#[cfg(CONFIG_COMPAT)] -unsafe extern "C" fn fops_compat_ioctl<T: MiscDevice>( - file: *mut bindings::file, - cmd: c_uint, - arg: c_ulong, -) -> c_long { - // SAFETY: The compat ioctl call of a file can access the private data. - let private = unsafe { (*file).private_data }; - // SAFETY: Ioctl calls can borrow the private data of the file. - let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private) }; - - match T::compat_ioctl(device, cmd, arg) { - Ok(ret) => ret as c_long, - Err(err) => err.to_errno() as c_long, + /// # Safety + /// + /// `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`. + #[cfg(CONFIG_COMPAT)] + unsafe extern "C" fn compat_ioctl( + file: *mut bindings::file, + cmd: c_uint, + arg: c_ulong, + ) -> c_long { + // SAFETY: The compat ioctl call of a file can access the private data. + let private = unsafe { (*file).private_data }; + // SAFETY: Ioctl calls can borrow the private data of the file. + let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private) }; + + // SAFETY: + // * The file is valid for the duration of this call. + // * There is no active fdget_pos region on the file on this thread. + let file = unsafe { File::from_raw_file(file) }; + + match T::compat_ioctl(device, file, cmd, arg) { + Ok(ret) => ret as c_long, + Err(err) => err.to_errno() as c_long, + } + } + + /// # Safety + /// + /// - `file` must be a valid file that is associated with a `MiscDeviceRegistration<T>`. + /// - `seq_file` must be a valid `struct seq_file` that we can write to. + unsafe extern "C" fn show_fdinfo(seq_file: *mut bindings::seq_file, file: *mut bindings::file) { + // SAFETY: The release call of a file owns the private data. + let private = unsafe { (*file).private_data }; + // SAFETY: Ioctl calls can borrow the private data of the file. + let device = unsafe { <T::Ptr as ForeignOwnable>::borrow(private) }; + // SAFETY: + // * The file is valid for the duration of this call. + // * There is no active fdget_pos region on the file on this thread. + let file = unsafe { File::from_raw_file(file) }; + // SAFETY: The caller ensures that the pointer is valid and exclusive for the duration in + // which this method is called. + let m = unsafe { SeqFile::from_raw(seq_file) }; + + T::show_fdinfo(device, m, file); + } + + const VTABLE: bindings::file_operations = bindings::file_operations { + open: Some(Self::open), + release: Some(Self::release), + mmap: if T::HAS_MMAP { Some(Self::mmap) } else { None }, + unlocked_ioctl: if T::HAS_IOCTL { + Some(Self::ioctl) + } else { + None + }, + #[cfg(CONFIG_COMPAT)] + compat_ioctl: if T::HAS_COMPAT_IOCTL { + Some(Self::compat_ioctl) + } else if T::HAS_IOCTL { + Some(bindings::compat_ptr_ioctl) + } else { + None + }, + show_fdinfo: if T::HAS_SHOW_FDINFO { + Some(Self::show_fdinfo) + } else { + None + }, + // SAFETY: All zeros is a valid value for `bindings::file_operations`. + ..unsafe { MaybeUninit::zeroed().assume_init() } + }; + + const fn build() -> &'static bindings::file_operations { + &Self::VTABLE } } diff --git a/rust/kernel/mm.rs b/rust/kernel/mm.rs new file mode 100644 index 000000000000..43f525c0d16c --- /dev/null +++ b/rust/kernel/mm.rs @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Memory management. +//! +//! This module deals with managing the address space of userspace processes. Each process has an +//! instance of [`Mm`], which keeps track of multiple VMAs (virtual memory areas). Each VMA +//! corresponds to a region of memory that the userspace process can access, and the VMA lets you +//! control what happens when userspace reads or writes to that region of memory. +//! +//! C header: [`include/linux/mm.h`](srctree/include/linux/mm.h) + +use crate::{ + bindings, + types::{ARef, AlwaysRefCounted, NotThreadSafe, Opaque}, +}; +use core::{ops::Deref, ptr::NonNull}; + +pub mod virt; +use virt::VmaRef; + +#[cfg(CONFIG_MMU)] +pub use mmput_async::MmWithUserAsync; +mod mmput_async; + +/// A wrapper for the kernel's `struct mm_struct`. +/// +/// This represents the address space of a userspace process, so each process has one `Mm` +/// instance. It may hold many VMAs internally. +/// +/// There is a counter called `mm_users` that counts the users of the address space; this includes +/// the userspace process itself, but can also include kernel threads accessing the address space. +/// Once `mm_users` reaches zero, this indicates that the address space can be destroyed. To access +/// the address space, you must prevent `mm_users` from reaching zero while you are accessing it. +/// The [`MmWithUser`] type represents an address space where this is guaranteed, and you can +/// create one using [`mmget_not_zero`]. +/// +/// The `ARef<Mm>` smart pointer holds an `mmgrab` refcount. Its destructor may sleep. +/// +/// # Invariants +/// +/// Values of this type are always refcounted using `mmgrab`. +/// +/// [`mmget_not_zero`]: Mm::mmget_not_zero +#[repr(transparent)] +pub struct Mm { + mm: Opaque<bindings::mm_struct>, +} + +// SAFETY: It is safe to call `mmdrop` on another thread than where `mmgrab` was called. +unsafe impl Send for Mm {} +// SAFETY: All methods on `Mm` can be called in parallel from several threads. +unsafe impl Sync for Mm {} + +// SAFETY: By the type invariants, this type is always refcounted. +unsafe impl AlwaysRefCounted for Mm { + #[inline] + fn inc_ref(&self) { + // SAFETY: The pointer is valid since self is a reference. + unsafe { bindings::mmgrab(self.as_raw()) }; + } + + #[inline] + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The caller is giving up their refcount. + unsafe { bindings::mmdrop(obj.cast().as_ptr()) }; + } +} + +/// A wrapper for the kernel's `struct mm_struct`. +/// +/// This type is like [`Mm`], but with non-zero `mm_users`. It can only be used when `mm_users` can +/// be proven to be non-zero at compile-time, usually because the relevant code holds an `mmget` +/// refcount. It can be used to access the associated address space. +/// +/// The `ARef<MmWithUser>` smart pointer holds an `mmget` refcount. Its destructor may sleep. +/// +/// # Invariants +/// +/// Values of this type are always refcounted using `mmget`. The value of `mm_users` is non-zero. +#[repr(transparent)] +pub struct MmWithUser { + mm: Mm, +} + +// SAFETY: It is safe to call `mmput` on another thread than where `mmget` was called. +unsafe impl Send for MmWithUser {} +// SAFETY: All methods on `MmWithUser` can be called in parallel from several threads. +unsafe impl Sync for MmWithUser {} + +// SAFETY: By the type invariants, this type is always refcounted. +unsafe impl AlwaysRefCounted for MmWithUser { + #[inline] + fn inc_ref(&self) { + // SAFETY: The pointer is valid since self is a reference. + unsafe { bindings::mmget(self.as_raw()) }; + } + + #[inline] + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The caller is giving up their refcount. + unsafe { bindings::mmput(obj.cast().as_ptr()) }; + } +} + +// Make all `Mm` methods available on `MmWithUser`. +impl Deref for MmWithUser { + type Target = Mm; + + #[inline] + fn deref(&self) -> &Mm { + &self.mm + } +} + +// These methods are safe to call even if `mm_users` is zero. +impl Mm { + /// Returns a raw pointer to the inner `mm_struct`. + #[inline] + pub fn as_raw(&self) -> *mut bindings::mm_struct { + self.mm.get() + } + + /// Obtain a reference from a raw pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` points at an `mm_struct`, and that it is not deallocated + /// during the lifetime 'a. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *const bindings::mm_struct) -> &'a Mm { + // SAFETY: Caller promises that the pointer is valid for 'a. Layouts are compatible due to + // repr(transparent). + unsafe { &*ptr.cast() } + } + + /// Calls `mmget_not_zero` and returns a handle if it succeeds. + #[inline] + pub fn mmget_not_zero(&self) -> Option<ARef<MmWithUser>> { + // SAFETY: The pointer is valid since self is a reference. + let success = unsafe { bindings::mmget_not_zero(self.as_raw()) }; + + if success { + // SAFETY: We just created an `mmget` refcount. + Some(unsafe { ARef::from_raw(NonNull::new_unchecked(self.as_raw().cast())) }) + } else { + None + } + } +} + +// These methods require `mm_users` to be non-zero. +impl MmWithUser { + /// Obtain a reference from a raw pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` points at an `mm_struct`, and that `mm_users` remains + /// non-zero for the duration of the lifetime 'a. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *const bindings::mm_struct) -> &'a MmWithUser { + // SAFETY: Caller promises that the pointer is valid for 'a. The layout is compatible due + // to repr(transparent). + unsafe { &*ptr.cast() } + } + + /// Attempt to access a vma using the vma read lock. + /// + /// This is an optimistic trylock operation, so it may fail if there is contention. In that + /// case, you should fall back to taking the mmap read lock. + /// + /// When per-vma locks are disabled, this always returns `None`. + #[inline] + pub fn lock_vma_under_rcu(&self, vma_addr: usize) -> Option<VmaReadGuard<'_>> { + #[cfg(CONFIG_PER_VMA_LOCK)] + { + // SAFETY: Calling `bindings::lock_vma_under_rcu` is always okay given an mm where + // `mm_users` is non-zero. + let vma = unsafe { bindings::lock_vma_under_rcu(self.as_raw(), vma_addr) }; + if !vma.is_null() { + return Some(VmaReadGuard { + // SAFETY: If `lock_vma_under_rcu` returns a non-null ptr, then it points at a + // valid vma. The vma is stable for as long as the vma read lock is held. + vma: unsafe { VmaRef::from_raw(vma) }, + _nts: NotThreadSafe, + }); + } + } + + // Silence warnings about unused variables. + #[cfg(not(CONFIG_PER_VMA_LOCK))] + let _ = vma_addr; + + None + } + + /// Lock the mmap read lock. + #[inline] + pub fn mmap_read_lock(&self) -> MmapReadGuard<'_> { + // SAFETY: The pointer is valid since self is a reference. + unsafe { bindings::mmap_read_lock(self.as_raw()) }; + + // INVARIANT: We just acquired the read lock. + MmapReadGuard { + mm: self, + _nts: NotThreadSafe, + } + } + + /// Try to lock the mmap read lock. + #[inline] + pub fn mmap_read_trylock(&self) -> Option<MmapReadGuard<'_>> { + // SAFETY: The pointer is valid since self is a reference. + let success = unsafe { bindings::mmap_read_trylock(self.as_raw()) }; + + if success { + // INVARIANT: We just acquired the read lock. + Some(MmapReadGuard { + mm: self, + _nts: NotThreadSafe, + }) + } else { + None + } + } +} + +/// A guard for the mmap read lock. +/// +/// # Invariants +/// +/// This `MmapReadGuard` guard owns the mmap read lock. +pub struct MmapReadGuard<'a> { + mm: &'a MmWithUser, + // `mmap_read_lock` and `mmap_read_unlock` must be called on the same thread + _nts: NotThreadSafe, +} + +impl<'a> MmapReadGuard<'a> { + /// Look up a vma at the given address. + #[inline] + pub fn vma_lookup(&self, vma_addr: usize) -> Option<&virt::VmaRef> { + // SAFETY: By the type invariants we hold the mmap read guard, so we can safely call this + // method. Any value is okay for `vma_addr`. + let vma = unsafe { bindings::vma_lookup(self.mm.as_raw(), vma_addr) }; + + if vma.is_null() { + None + } else { + // SAFETY: We just checked that a vma was found, so the pointer references a valid vma. + // + // Furthermore, the returned vma is still under the protection of the read lock guard + // and can be used while the mmap read lock is still held. That the vma is not used + // after the MmapReadGuard gets dropped is enforced by the borrow-checker. + unsafe { Some(virt::VmaRef::from_raw(vma)) } + } + } +} + +impl Drop for MmapReadGuard<'_> { + #[inline] + fn drop(&mut self) { + // SAFETY: We hold the read lock by the type invariants. + unsafe { bindings::mmap_read_unlock(self.mm.as_raw()) }; + } +} + +/// A guard for the vma read lock. +/// +/// # Invariants +/// +/// This `VmaReadGuard` guard owns the vma read lock. +pub struct VmaReadGuard<'a> { + vma: &'a VmaRef, + // `vma_end_read` must be called on the same thread as where the lock was taken + _nts: NotThreadSafe, +} + +// Make all `VmaRef` methods available on `VmaReadGuard`. +impl Deref for VmaReadGuard<'_> { + type Target = VmaRef; + + #[inline] + fn deref(&self) -> &VmaRef { + self.vma + } +} + +impl Drop for VmaReadGuard<'_> { + #[inline] + fn drop(&mut self) { + // SAFETY: We hold the read lock by the type invariants. + unsafe { bindings::vma_end_read(self.vma.as_ptr()) }; + } +} diff --git a/rust/kernel/mm/mmput_async.rs b/rust/kernel/mm/mmput_async.rs new file mode 100644 index 000000000000..9289e05f7a67 --- /dev/null +++ b/rust/kernel/mm/mmput_async.rs @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Version of `MmWithUser` using `mmput_async`. +//! +//! This is a separate file from `mm.rs` due to the dependency on `CONFIG_MMU=y`. +#![cfg(CONFIG_MMU)] + +use crate::{ + bindings, + mm::MmWithUser, + types::{ARef, AlwaysRefCounted}, +}; +use core::{ops::Deref, ptr::NonNull}; + +/// A wrapper for the kernel's `struct mm_struct`. +/// +/// This type is identical to `MmWithUser` except that it uses `mmput_async` when dropping a +/// refcount. This means that the destructor of `ARef<MmWithUserAsync>` is safe to call in atomic +/// context. +/// +/// # Invariants +/// +/// Values of this type are always refcounted using `mmget`. The value of `mm_users` is non-zero. +#[repr(transparent)] +pub struct MmWithUserAsync { + mm: MmWithUser, +} + +// SAFETY: It is safe to call `mmput_async` on another thread than where `mmget` was called. +unsafe impl Send for MmWithUserAsync {} +// SAFETY: All methods on `MmWithUserAsync` can be called in parallel from several threads. +unsafe impl Sync for MmWithUserAsync {} + +// SAFETY: By the type invariants, this type is always refcounted. +unsafe impl AlwaysRefCounted for MmWithUserAsync { + #[inline] + fn inc_ref(&self) { + // SAFETY: The pointer is valid since self is a reference. + unsafe { bindings::mmget(self.as_raw()) }; + } + + #[inline] + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The caller is giving up their refcount. + unsafe { bindings::mmput_async(obj.cast().as_ptr()) }; + } +} + +// Make all `MmWithUser` methods available on `MmWithUserAsync`. +impl Deref for MmWithUserAsync { + type Target = MmWithUser; + + #[inline] + fn deref(&self) -> &MmWithUser { + &self.mm + } +} + +impl MmWithUser { + /// Use `mmput_async` when dropping this refcount. + #[inline] + pub fn into_mmput_async(me: ARef<MmWithUser>) -> ARef<MmWithUserAsync> { + // SAFETY: The layouts and invariants are compatible. + unsafe { ARef::from_raw(ARef::into_raw(me).cast()) } + } +} diff --git a/rust/kernel/mm/virt.rs b/rust/kernel/mm/virt.rs new file mode 100644 index 000000000000..a1bfa4e19293 --- /dev/null +++ b/rust/kernel/mm/virt.rs @@ -0,0 +1,472 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2024 Google LLC. + +//! Virtual memory. +//! +//! This module deals with managing a single VMA in the address space of a userspace process. Each +//! VMA corresponds to a region of memory that the userspace process can access, and the VMA lets +//! you control what happens when userspace reads or writes to that region of memory. +//! +//! The module has several different Rust types that all correspond to the C type called +//! `vm_area_struct`. The different structs represent what kind of access you have to the VMA, e.g. +//! [`VmaRef`] is used when you hold the mmap or vma read lock. Using the appropriate struct +//! ensures that you can't, for example, accidentally call a function that requires holding the +//! write lock when you only hold the read lock. + +use crate::{ + bindings, + error::{code::EINVAL, to_result, Result}, + mm::MmWithUser, + page::Page, + types::Opaque, +}; + +use core::ops::Deref; + +/// A wrapper for the kernel's `struct vm_area_struct` with read access. +/// +/// It represents an area of virtual memory. +/// +/// # Invariants +/// +/// The caller must hold the mmap read lock or the vma read lock. +#[repr(transparent)] +pub struct VmaRef { + vma: Opaque<bindings::vm_area_struct>, +} + +// Methods you can call when holding the mmap or vma read lock (or stronger). They must be usable +// no matter what the vma flags are. +impl VmaRef { + /// Access a virtual memory area given a raw pointer. + /// + /// # Safety + /// + /// Callers must ensure that `vma` is valid for the duration of 'a, and that the mmap or vma + /// read lock (or stronger) is held for at least the duration of 'a. + #[inline] + pub unsafe fn from_raw<'a>(vma: *const bindings::vm_area_struct) -> &'a Self { + // SAFETY: The caller ensures that the invariants are satisfied for the duration of 'a. + unsafe { &*vma.cast() } + } + + /// Returns a raw pointer to this area. + #[inline] + pub fn as_ptr(&self) -> *mut bindings::vm_area_struct { + self.vma.get() + } + + /// Access the underlying `mm_struct`. + #[inline] + pub fn mm(&self) -> &MmWithUser { + // SAFETY: By the type invariants, this `vm_area_struct` is valid and we hold the mmap/vma + // read lock or stronger. This implies that the underlying mm has a non-zero value of + // `mm_users`. + unsafe { MmWithUser::from_raw((*self.as_ptr()).vm_mm) } + } + + /// Returns the flags associated with the virtual memory area. + /// + /// The possible flags are a combination of the constants in [`flags`]. + #[inline] + pub fn flags(&self) -> vm_flags_t { + // SAFETY: By the type invariants, the caller holds at least the mmap read lock, so this + // access is not a data race. + unsafe { (*self.as_ptr()).__bindgen_anon_2.vm_flags } + } + + /// Returns the (inclusive) start address of the virtual memory area. + #[inline] + pub fn start(&self) -> usize { + // SAFETY: By the type invariants, the caller holds at least the mmap read lock, so this + // access is not a data race. + unsafe { (*self.as_ptr()).__bindgen_anon_1.__bindgen_anon_1.vm_start } + } + + /// Returns the (exclusive) end address of the virtual memory area. + #[inline] + pub fn end(&self) -> usize { + // SAFETY: By the type invariants, the caller holds at least the mmap read lock, so this + // access is not a data race. + unsafe { (*self.as_ptr()).__bindgen_anon_1.__bindgen_anon_1.vm_end } + } + + /// Zap pages in the given page range. + /// + /// This clears page table mappings for the range at the leaf level, leaving all other page + /// tables intact, and freeing any memory referenced by the VMA in this range. That is, + /// anonymous memory is completely freed, file-backed memory has its reference count on page + /// cache folio's dropped, any dirty data will still be written back to disk as usual. + /// + /// It may seem odd that we clear at the leaf level, this is however a product of the page + /// table structure used to map physical memory into a virtual address space - each virtual + /// address actually consists of a bitmap of array indices into page tables, which form a + /// hierarchical page table level structure. + /// + /// As a result, each page table level maps a multiple of page table levels below, and thus + /// span ever increasing ranges of pages. At the leaf or PTE level, we map the actual physical + /// memory. + /// + /// It is here where a zap operates, as it the only place we can be certain of clearing without + /// impacting any other virtual mappings. It is an implementation detail as to whether the + /// kernel goes further in freeing unused page tables, but for the purposes of this operation + /// we must only assume that the leaf level is cleared. + #[inline] + pub fn zap_page_range_single(&self, address: usize, size: usize) { + let (end, did_overflow) = address.overflowing_add(size); + if did_overflow || address < self.start() || self.end() < end { + // TODO: call WARN_ONCE once Rust version of it is added + return; + } + + // SAFETY: By the type invariants, the caller has read access to this VMA, which is + // sufficient for this method call. This method has no requirements on the vma flags. The + // address range is checked to be within the vma. + unsafe { + bindings::zap_page_range_single(self.as_ptr(), address, size, core::ptr::null_mut()) + }; + } + + /// If the [`VM_MIXEDMAP`] flag is set, returns a [`VmaMixedMap`] to this VMA, otherwise + /// returns `None`. + /// + /// This can be used to access methods that require [`VM_MIXEDMAP`] to be set. + /// + /// [`VM_MIXEDMAP`]: flags::MIXEDMAP + #[inline] + pub fn as_mixedmap_vma(&self) -> Option<&VmaMixedMap> { + if self.flags() & flags::MIXEDMAP != 0 { + // SAFETY: We just checked that `VM_MIXEDMAP` is set. All other requirements are + // satisfied by the type invariants of `VmaRef`. + Some(unsafe { VmaMixedMap::from_raw(self.as_ptr()) }) + } else { + None + } + } +} + +/// A wrapper for the kernel's `struct vm_area_struct` with read access and [`VM_MIXEDMAP`] set. +/// +/// It represents an area of virtual memory. +/// +/// This struct is identical to [`VmaRef`] except that it must only be used when the +/// [`VM_MIXEDMAP`] flag is set on the vma. +/// +/// # Invariants +/// +/// The caller must hold the mmap read lock or the vma read lock. The `VM_MIXEDMAP` flag must be +/// set. +/// +/// [`VM_MIXEDMAP`]: flags::MIXEDMAP +#[repr(transparent)] +pub struct VmaMixedMap { + vma: VmaRef, +} + +// Make all `VmaRef` methods available on `VmaMixedMap`. +impl Deref for VmaMixedMap { + type Target = VmaRef; + + #[inline] + fn deref(&self) -> &VmaRef { + &self.vma + } +} + +impl VmaMixedMap { + /// Access a virtual memory area given a raw pointer. + /// + /// # Safety + /// + /// Callers must ensure that `vma` is valid for the duration of 'a, and that the mmap read lock + /// (or stronger) is held for at least the duration of 'a. The `VM_MIXEDMAP` flag must be set. + #[inline] + pub unsafe fn from_raw<'a>(vma: *const bindings::vm_area_struct) -> &'a Self { + // SAFETY: The caller ensures that the invariants are satisfied for the duration of 'a. + unsafe { &*vma.cast() } + } + + /// Maps a single page at the given address within the virtual memory area. + /// + /// This operation does not take ownership of the page. + #[inline] + pub fn vm_insert_page(&self, address: usize, page: &Page) -> Result { + // SAFETY: By the type invariant of `Self` caller has read access and has verified that + // `VM_MIXEDMAP` is set. By invariant on `Page` the page has order 0. + to_result(unsafe { bindings::vm_insert_page(self.as_ptr(), address, page.as_ptr()) }) + } +} + +/// A configuration object for setting up a VMA in an `f_ops->mmap()` hook. +/// +/// The `f_ops->mmap()` hook is called when a new VMA is being created, and the hook is able to +/// configure the VMA in various ways to fit the driver that owns it. Using `VmaNew` indicates that +/// you are allowed to perform operations on the VMA that can only be performed before the VMA is +/// fully initialized. +/// +/// # Invariants +/// +/// For the duration of 'a, the referenced vma must be undergoing initialization in an +/// `f_ops->mmap()` hook. +#[repr(transparent)] +pub struct VmaNew { + vma: VmaRef, +} + +// Make all `VmaRef` methods available on `VmaNew`. +impl Deref for VmaNew { + type Target = VmaRef; + + #[inline] + fn deref(&self) -> &VmaRef { + &self.vma + } +} + +impl VmaNew { + /// Access a virtual memory area given a raw pointer. + /// + /// # Safety + /// + /// Callers must ensure that `vma` is undergoing initial vma setup for the duration of 'a. + #[inline] + pub unsafe fn from_raw<'a>(vma: *mut bindings::vm_area_struct) -> &'a Self { + // SAFETY: The caller ensures that the invariants are satisfied for the duration of 'a. + unsafe { &*vma.cast() } + } + + /// Internal method for updating the vma flags. + /// + /// # Safety + /// + /// This must not be used to set the flags to an invalid value. + #[inline] + unsafe fn update_flags(&self, set: vm_flags_t, unset: vm_flags_t) { + let mut flags = self.flags(); + flags |= set; + flags &= !unset; + + // SAFETY: This is not a data race: the vma is undergoing initial setup, so it's not yet + // shared. Additionally, `VmaNew` is `!Sync`, so it cannot be used to write in parallel. + // The caller promises that this does not set the flags to an invalid value. + unsafe { (*self.as_ptr()).__bindgen_anon_2.__vm_flags = flags }; + } + + /// Set the `VM_MIXEDMAP` flag on this vma. + /// + /// This enables the vma to contain both `struct page` and pure PFN pages. Returns a reference + /// that can be used to call `vm_insert_page` on the vma. + #[inline] + pub fn set_mixedmap(&self) -> &VmaMixedMap { + // SAFETY: We don't yet provide a way to set VM_PFNMAP, so this cannot put the flags in an + // invalid state. + unsafe { self.update_flags(flags::MIXEDMAP, 0) }; + + // SAFETY: We just set `VM_MIXEDMAP` on the vma. + unsafe { VmaMixedMap::from_raw(self.vma.as_ptr()) } + } + + /// Set the `VM_IO` flag on this vma. + /// + /// This is used for memory mapped IO and similar. The flag tells other parts of the kernel to + /// avoid looking at the pages. For memory mapped IO this is useful as accesses to the pages + /// could have side effects. + #[inline] + pub fn set_io(&self) { + // SAFETY: Setting the VM_IO flag is always okay. + unsafe { self.update_flags(flags::IO, 0) }; + } + + /// Set the `VM_DONTEXPAND` flag on this vma. + /// + /// This prevents the vma from being expanded with `mremap()`. + #[inline] + pub fn set_dontexpand(&self) { + // SAFETY: Setting the VM_DONTEXPAND flag is always okay. + unsafe { self.update_flags(flags::DONTEXPAND, 0) }; + } + + /// Set the `VM_DONTCOPY` flag on this vma. + /// + /// This prevents the vma from being copied on fork. This option is only permanent if `VM_IO` + /// is set. + #[inline] + pub fn set_dontcopy(&self) { + // SAFETY: Setting the VM_DONTCOPY flag is always okay. + unsafe { self.update_flags(flags::DONTCOPY, 0) }; + } + + /// Set the `VM_DONTDUMP` flag on this vma. + /// + /// This prevents the vma from being included in core dumps. This option is only permanent if + /// `VM_IO` is set. + #[inline] + pub fn set_dontdump(&self) { + // SAFETY: Setting the VM_DONTDUMP flag is always okay. + unsafe { self.update_flags(flags::DONTDUMP, 0) }; + } + + /// Returns whether `VM_READ` is set. + /// + /// This flag indicates whether userspace is mapping this vma as readable. + #[inline] + pub fn readable(&self) -> bool { + (self.flags() & flags::READ) != 0 + } + + /// Try to clear the `VM_MAYREAD` flag, failing if `VM_READ` is set. + /// + /// This flag indicates whether userspace is allowed to make this vma readable with + /// `mprotect()`. + /// + /// Note that this operation is irreversible. Once `VM_MAYREAD` has been cleared, it can never + /// be set again. + #[inline] + pub fn try_clear_mayread(&self) -> Result { + if self.readable() { + return Err(EINVAL); + } + // SAFETY: Clearing `VM_MAYREAD` is okay when `VM_READ` is not set. + unsafe { self.update_flags(0, flags::MAYREAD) }; + Ok(()) + } + + /// Returns whether `VM_WRITE` is set. + /// + /// This flag indicates whether userspace is mapping this vma as writable. + #[inline] + pub fn writable(&self) -> bool { + (self.flags() & flags::WRITE) != 0 + } + + /// Try to clear the `VM_MAYWRITE` flag, failing if `VM_WRITE` is set. + /// + /// This flag indicates whether userspace is allowed to make this vma writable with + /// `mprotect()`. + /// + /// Note that this operation is irreversible. Once `VM_MAYWRITE` has been cleared, it can never + /// be set again. + #[inline] + pub fn try_clear_maywrite(&self) -> Result { + if self.writable() { + return Err(EINVAL); + } + // SAFETY: Clearing `VM_MAYWRITE` is okay when `VM_WRITE` is not set. + unsafe { self.update_flags(0, flags::MAYWRITE) }; + Ok(()) + } + + /// Returns whether `VM_EXEC` is set. + /// + /// This flag indicates whether userspace is mapping this vma as executable. + #[inline] + pub fn executable(&self) -> bool { + (self.flags() & flags::EXEC) != 0 + } + + /// Try to clear the `VM_MAYEXEC` flag, failing if `VM_EXEC` is set. + /// + /// This flag indicates whether userspace is allowed to make this vma executable with + /// `mprotect()`. + /// + /// Note that this operation is irreversible. Once `VM_MAYEXEC` has been cleared, it can never + /// be set again. + #[inline] + pub fn try_clear_mayexec(&self) -> Result { + if self.executable() { + return Err(EINVAL); + } + // SAFETY: Clearing `VM_MAYEXEC` is okay when `VM_EXEC` is not set. + unsafe { self.update_flags(0, flags::MAYEXEC) }; + Ok(()) + } +} + +/// The integer type used for vma flags. +#[doc(inline)] +pub use bindings::vm_flags_t; + +/// All possible flags for [`VmaRef`]. +pub mod flags { + use super::vm_flags_t; + use crate::bindings; + + /// No flags are set. + pub const NONE: vm_flags_t = bindings::VM_NONE as vm_flags_t; + + /// Mapping allows reads. + pub const READ: vm_flags_t = bindings::VM_READ as vm_flags_t; + + /// Mapping allows writes. + pub const WRITE: vm_flags_t = bindings::VM_WRITE as vm_flags_t; + + /// Mapping allows execution. + pub const EXEC: vm_flags_t = bindings::VM_EXEC as vm_flags_t; + + /// Mapping is shared. + pub const SHARED: vm_flags_t = bindings::VM_SHARED as vm_flags_t; + + /// Mapping may be updated to allow reads. + pub const MAYREAD: vm_flags_t = bindings::VM_MAYREAD as vm_flags_t; + + /// Mapping may be updated to allow writes. + pub const MAYWRITE: vm_flags_t = bindings::VM_MAYWRITE as vm_flags_t; + + /// Mapping may be updated to allow execution. + pub const MAYEXEC: vm_flags_t = bindings::VM_MAYEXEC as vm_flags_t; + + /// Mapping may be updated to be shared. + pub const MAYSHARE: vm_flags_t = bindings::VM_MAYSHARE as vm_flags_t; + + /// Page-ranges managed without `struct page`, just pure PFN. + pub const PFNMAP: vm_flags_t = bindings::VM_PFNMAP as vm_flags_t; + + /// Memory mapped I/O or similar. + pub const IO: vm_flags_t = bindings::VM_IO as vm_flags_t; + + /// Do not copy this vma on fork. + pub const DONTCOPY: vm_flags_t = bindings::VM_DONTCOPY as vm_flags_t; + + /// Cannot expand with mremap(). + pub const DONTEXPAND: vm_flags_t = bindings::VM_DONTEXPAND as vm_flags_t; + + /// Lock the pages covered when they are faulted in. + pub const LOCKONFAULT: vm_flags_t = bindings::VM_LOCKONFAULT as vm_flags_t; + + /// Is a VM accounted object. + pub const ACCOUNT: vm_flags_t = bindings::VM_ACCOUNT as vm_flags_t; + + /// Should the VM suppress accounting. + pub const NORESERVE: vm_flags_t = bindings::VM_NORESERVE as vm_flags_t; + + /// Huge TLB Page VM. + pub const HUGETLB: vm_flags_t = bindings::VM_HUGETLB as vm_flags_t; + + /// Synchronous page faults. (DAX-specific) + pub const SYNC: vm_flags_t = bindings::VM_SYNC as vm_flags_t; + + /// Architecture-specific flag. + pub const ARCH_1: vm_flags_t = bindings::VM_ARCH_1 as vm_flags_t; + + /// Wipe VMA contents in child on fork. + pub const WIPEONFORK: vm_flags_t = bindings::VM_WIPEONFORK as vm_flags_t; + + /// Do not include in the core dump. + pub const DONTDUMP: vm_flags_t = bindings::VM_DONTDUMP as vm_flags_t; + + /// Not soft dirty clean area. + pub const SOFTDIRTY: vm_flags_t = bindings::VM_SOFTDIRTY as vm_flags_t; + + /// Can contain `struct page` and pure PFN pages. + pub const MIXEDMAP: vm_flags_t = bindings::VM_MIXEDMAP as vm_flags_t; + + /// MADV_HUGEPAGE marked this vma. + pub const HUGEPAGE: vm_flags_t = bindings::VM_HUGEPAGE as vm_flags_t; + + /// MADV_NOHUGEPAGE marked this vma. + pub const NOHUGEPAGE: vm_flags_t = bindings::VM_NOHUGEPAGE as vm_flags_t; + + /// KSM may merge identical pages. + pub const MERGEABLE: vm_flags_t = bindings::VM_MERGEABLE as vm_flags_t; +} diff --git a/rust/kernel/net/phy.rs b/rust/kernel/net/phy.rs index 2fbfb6a94c11..7de5cc7a0eee 100644 --- a/rust/kernel/net/phy.rs +++ b/rust/kernel/net/phy.rs @@ -6,7 +6,7 @@ //! //! C headers: [`include/linux/phy.h`](srctree/include/linux/phy.h). -use crate::{error::*, prelude::*, types::Opaque}; +use crate::{device_id::RawDeviceId, error::*, prelude::*, types::Opaque}; use core::{marker::PhantomData, ptr::addr_of_mut}; pub mod reg; @@ -142,7 +142,7 @@ impl Device { // SAFETY: The struct invariant ensures that we may access // this field without additional synchronization. let bit_field = unsafe { &(*self.0.get())._bitfield_1 }; - bit_field.get(13, 1) == bindings::AUTONEG_ENABLE as u64 + bit_field.get(13, 1) == u64::from(bindings::AUTONEG_ENABLE) } /// Gets the current auto-negotiation state. @@ -163,20 +163,20 @@ impl Device { let phydev = self.0.get(); // SAFETY: The struct invariant ensures that we may access // this field without additional synchronization. - unsafe { (*phydev).speed = speed as i32 }; + unsafe { (*phydev).speed = speed as c_int }; } /// Sets duplex mode. pub fn set_duplex(&mut self, mode: DuplexMode) { let phydev = self.0.get(); let v = match mode { - DuplexMode::Full => bindings::DUPLEX_FULL as i32, - DuplexMode::Half => bindings::DUPLEX_HALF as i32, - DuplexMode::Unknown => bindings::DUPLEX_UNKNOWN as i32, + DuplexMode::Full => bindings::DUPLEX_FULL, + DuplexMode::Half => bindings::DUPLEX_HALF, + DuplexMode::Unknown => bindings::DUPLEX_UNKNOWN, }; // SAFETY: The struct invariant ensures that we may access // this field without additional synchronization. - unsafe { (*phydev).duplex = v }; + unsafe { (*phydev).duplex = v as c_int }; } /// Reads a PHY register. @@ -285,7 +285,7 @@ impl AsRef<kernel::device::Device> for Device { fn as_ref(&self) -> &kernel::device::Device { let phydev = self.0.get(); // SAFETY: The struct invariant ensures that `mdio.dev` is valid. - unsafe { kernel::device::Device::as_ref(addr_of_mut!((*phydev).mdio.dev)) } + unsafe { kernel::device::Device::from_raw(addr_of_mut!((*phydev).mdio.dev)) } } } @@ -312,9 +312,7 @@ impl<T: Driver> Adapter<T> { /// # Safety /// /// `phydev` must be passed by the corresponding callback in `phy_driver`. - unsafe extern "C" fn soft_reset_callback( - phydev: *mut bindings::phy_device, - ) -> crate::ffi::c_int { + unsafe extern "C" fn soft_reset_callback(phydev: *mut bindings::phy_device) -> c_int { from_result(|| { // SAFETY: This callback is called only in contexts // where we hold `phy_device->lock`, so the accessors on @@ -328,7 +326,7 @@ impl<T: Driver> Adapter<T> { /// # Safety /// /// `phydev` must be passed by the corresponding callback in `phy_driver`. - unsafe extern "C" fn probe_callback(phydev: *mut bindings::phy_device) -> crate::ffi::c_int { + unsafe extern "C" fn probe_callback(phydev: *mut bindings::phy_device) -> c_int { from_result(|| { // SAFETY: This callback is called only in contexts // where we can exclusively access `phy_device` because @@ -343,9 +341,7 @@ impl<T: Driver> Adapter<T> { /// # Safety /// /// `phydev` must be passed by the corresponding callback in `phy_driver`. - unsafe extern "C" fn get_features_callback( - phydev: *mut bindings::phy_device, - ) -> crate::ffi::c_int { + unsafe extern "C" fn get_features_callback(phydev: *mut bindings::phy_device) -> c_int { from_result(|| { // SAFETY: This callback is called only in contexts // where we hold `phy_device->lock`, so the accessors on @@ -359,7 +355,7 @@ impl<T: Driver> Adapter<T> { /// # Safety /// /// `phydev` must be passed by the corresponding callback in `phy_driver`. - unsafe extern "C" fn suspend_callback(phydev: *mut bindings::phy_device) -> crate::ffi::c_int { + unsafe extern "C" fn suspend_callback(phydev: *mut bindings::phy_device) -> c_int { from_result(|| { // SAFETY: The C core code ensures that the accessors on // `Device` are okay to call even though `phy_device->lock` @@ -373,7 +369,7 @@ impl<T: Driver> Adapter<T> { /// # Safety /// /// `phydev` must be passed by the corresponding callback in `phy_driver`. - unsafe extern "C" fn resume_callback(phydev: *mut bindings::phy_device) -> crate::ffi::c_int { + unsafe extern "C" fn resume_callback(phydev: *mut bindings::phy_device) -> c_int { from_result(|| { // SAFETY: The C core code ensures that the accessors on // `Device` are okay to call even though `phy_device->lock` @@ -387,9 +383,7 @@ impl<T: Driver> Adapter<T> { /// # Safety /// /// `phydev` must be passed by the corresponding callback in `phy_driver`. - unsafe extern "C" fn config_aneg_callback( - phydev: *mut bindings::phy_device, - ) -> crate::ffi::c_int { + unsafe extern "C" fn config_aneg_callback(phydev: *mut bindings::phy_device) -> c_int { from_result(|| { // SAFETY: This callback is called only in contexts // where we hold `phy_device->lock`, so the accessors on @@ -403,9 +397,7 @@ impl<T: Driver> Adapter<T> { /// # Safety /// /// `phydev` must be passed by the corresponding callback in `phy_driver`. - unsafe extern "C" fn read_status_callback( - phydev: *mut bindings::phy_device, - ) -> crate::ffi::c_int { + unsafe extern "C" fn read_status_callback(phydev: *mut bindings::phy_device) -> c_int { from_result(|| { // SAFETY: This callback is called only in contexts // where we hold `phy_device->lock`, so the accessors on @@ -421,12 +413,13 @@ impl<T: Driver> Adapter<T> { /// `phydev` must be passed by the corresponding callback in `phy_driver`. unsafe extern "C" fn match_phy_device_callback( phydev: *mut bindings::phy_device, - ) -> crate::ffi::c_int { + _phydrv: *const bindings::phy_driver, + ) -> c_int { // SAFETY: This callback is called only in contexts // where we hold `phy_device->lock`, so the accessors on // `Device` are okay to call. let dev = unsafe { Device::from_raw(phydev) }; - T::match_phy_device(dev) as i32 + T::match_phy_device(dev).into() } /// # Safety @@ -506,7 +499,7 @@ pub const fn create_phy_driver<T: Driver>() -> DriverVTable { DriverVTable(Opaque::new(bindings::phy_driver { name: T::NAME.as_char_ptr().cast_mut(), flags: T::FLAGS, - phy_id: T::PHY_DEVICE_ID.id, + phy_id: T::PHY_DEVICE_ID.id(), phy_id_mask: T::PHY_DEVICE_ID.mask_as_int(), soft_reset: if T::HAS_SOFT_RESET { Some(Adapter::<T>::soft_reset_callback) @@ -587,17 +580,17 @@ pub trait Driver { /// Issues a PHY software reset. fn soft_reset(_dev: &mut Device) -> Result { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Sets up device-specific structures during discovery. fn probe(_dev: &mut Device) -> Result { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Probes the hardware to determine what abilities it has. fn get_features(_dev: &mut Device) -> Result { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Returns true if this is a suitable driver for the given phydev. @@ -609,32 +602,32 @@ pub trait Driver { /// Configures the advertisement and resets auto-negotiation /// if auto-negotiation is enabled. fn config_aneg(_dev: &mut Device) -> Result { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Determines the negotiated speed and duplex. fn read_status(_dev: &mut Device) -> Result<u16> { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Suspends the hardware, saving state if needed. fn suspend(_dev: &mut Device) -> Result { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Resumes the hardware, restoring state if needed. fn resume(_dev: &mut Device) -> Result { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Overrides the default MMD read function for reading a MMD register. fn read_mmd(_dev: &mut Device, _devnum: u8, _regnum: u16) -> Result<u16> { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Overrides the default MMD write function for writing a MMD register. fn write_mmd(_dev: &mut Device, _devnum: u8, _regnum: u16, _val: u16) -> Result { - kernel::build_error(VTABLE_DEFAULT_ERROR) + build_error!(VTABLE_DEFAULT_ERROR) } /// Callback for notification of link change. @@ -690,42 +683,41 @@ impl Drop for Registration { /// /// Represents the kernel's `struct mdio_device_id`. This is used to find an appropriate /// PHY driver. -pub struct DeviceId { - id: u32, - mask: DeviceMask, -} +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::mdio_device_id); impl DeviceId { /// Creates a new instance with the exact match mask. pub const fn new_with_exact_mask(id: u32) -> Self { - DeviceId { - id, - mask: DeviceMask::Exact, - } + Self(bindings::mdio_device_id { + phy_id: id, + phy_id_mask: DeviceMask::Exact.as_int(), + }) } /// Creates a new instance with the model match mask. pub const fn new_with_model_mask(id: u32) -> Self { - DeviceId { - id, - mask: DeviceMask::Model, - } + Self(bindings::mdio_device_id { + phy_id: id, + phy_id_mask: DeviceMask::Model.as_int(), + }) } /// Creates a new instance with the vendor match mask. pub const fn new_with_vendor_mask(id: u32) -> Self { - DeviceId { - id, - mask: DeviceMask::Vendor, - } + Self(bindings::mdio_device_id { + phy_id: id, + phy_id_mask: DeviceMask::Vendor.as_int(), + }) } /// Creates a new instance with a custom match mask. pub const fn new_with_custom_mask(id: u32, mask: u32) -> Self { - DeviceId { - id, - mask: DeviceMask::Custom(mask), - } + Self(bindings::mdio_device_id { + phy_id: id, + phy_id_mask: DeviceMask::Custom(mask).as_int(), + }) } /// Creates a new instance from [`Driver`]. @@ -733,21 +725,29 @@ impl DeviceId { T::PHY_DEVICE_ID } - /// Get a `mask` as u32. + /// Get the MDIO device's PHY ID. + pub const fn id(&self) -> u32 { + self.0.phy_id + } + + /// Get the MDIO device's match mask. pub const fn mask_as_int(&self) -> u32 { - self.mask.as_int() + self.0.phy_id_mask } // macro use only #[doc(hidden)] pub const fn mdio_device_id(&self) -> bindings::mdio_device_id { - bindings::mdio_device_id { - phy_id: self.id, - phy_id_mask: self.mask.as_int(), - } + self.0 } } +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct mdio_device_id` +// and does not add additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::mdio_device_id; +} + enum DeviceMask { Exact, Model, @@ -790,7 +790,7 @@ impl DeviceMask { /// DeviceId::new_with_driver::<PhySample>() /// ], /// name: "rust_sample_phy", -/// author: "Rust for Linux Contributors", +/// authors: ["Rust for Linux Contributors"], /// description: "Rust sample PHYs driver", /// license: "GPL", /// } @@ -819,7 +819,7 @@ impl DeviceMask { /// module! { /// type: Module, /// name: "rust_sample_phy", -/// author: "Rust for Linux Contributors", +/// authors: ["Rust for Linux Contributors"], /// description: "Rust sample PHYs driver", /// license: "GPL", /// } @@ -837,7 +837,7 @@ impl DeviceMask { /// [::kernel::net::phy::create_phy_driver::<PhySample>()]; /// /// impl ::kernel::Module for Module { -/// fn init(module: &'static ThisModule) -> Result<Self> { +/// fn init(module: &'static ::kernel::ThisModule) -> Result<Self> { /// let drivers = unsafe { &mut DRIVERS }; /// let mut reg = ::kernel::net::phy::Registration::register( /// module, @@ -848,19 +848,18 @@ impl DeviceMask { /// } /// }; /// -/// const _DEVICE_TABLE: [::kernel::bindings::mdio_device_id; 2] = [ -/// ::kernel::bindings::mdio_device_id { -/// phy_id: 0x00000001, -/// phy_id_mask: 0xffffffff, -/// }, -/// ::kernel::bindings::mdio_device_id { -/// phy_id: 0, -/// phy_id_mask: 0, -/// }, -/// ]; -/// #[cfg(MODULE)] -/// #[no_mangle] -/// static __mod_device_table__mdio__phydev: [::kernel::bindings::mdio_device_id; 2] = _DEVICE_TABLE; +/// const N: usize = 1; +/// +/// const TABLE: ::kernel::device_id::IdArray<::kernel::net::phy::DeviceId, (), N> = +/// ::kernel::device_id::IdArray::new_without_index([ +/// ::kernel::net::phy::DeviceId( +/// ::kernel::bindings::mdio_device_id { +/// phy_id: 0x00000001, +/// phy_id_mask: 0xffffffff, +/// }), +/// ]); +/// +/// ::kernel::module_device_table!("mdio", phydev, TABLE); /// ``` #[macro_export] macro_rules! module_phy_driver { @@ -871,20 +870,12 @@ macro_rules! module_phy_driver { }; (@device_table [$($dev:expr),+]) => { - // SAFETY: C will not read off the end of this constant since the last element is zero. - const _DEVICE_TABLE: [$crate::bindings::mdio_device_id; - $crate::module_phy_driver!(@count_devices $($dev),+) + 1] = [ - $($dev.mdio_device_id()),+, - $crate::bindings::mdio_device_id { - phy_id: 0, - phy_id_mask: 0 - } - ]; + const N: usize = $crate::module_phy_driver!(@count_devices $($dev),+); + + const TABLE: $crate::device_id::IdArray<$crate::net::phy::DeviceId, (), N> = + $crate::device_id::IdArray::new_without_index([ $(($dev,())),+, ]); - #[cfg(MODULE)] - #[no_mangle] - static __mod_device_table__mdio__phydev: [$crate::bindings::mdio_device_id; - $crate::module_phy_driver!(@count_devices $($dev),+) + 1] = _DEVICE_TABLE; + $crate::module_device_table!("mdio", phydev, TABLE); }; (drivers: [$($driver:ident),+ $(,)?], device_table: [$($dev:expr),+ $(,)?], $($f:tt)*) => { @@ -903,7 +894,7 @@ macro_rules! module_phy_driver { [$($crate::net::phy::create_phy_driver::<$driver>()),+]; impl $crate::Module for Module { - fn init(module: &'static ThisModule) -> Result<Self> { + fn init(module: &'static $crate::ThisModule) -> Result<Self> { // SAFETY: The anonymous constant guarantees that nobody else can access // the `DRIVERS` static. The array is used only in the C side. let drivers = unsafe { &mut DRIVERS }; diff --git a/rust/kernel/of.rs b/rust/kernel/of.rs new file mode 100644 index 000000000000..b76b35265df2 --- /dev/null +++ b/rust/kernel/of.rs @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Device Tree / Open Firmware abstractions. + +use crate::{ + bindings, + device_id::{RawDeviceId, RawDeviceIdIndex}, + prelude::*, +}; + +/// IdTable type for OF drivers. +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// An open firmware device id. +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::of_device_id); + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct of_device_id` and +// does not add additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::of_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `data` field. +unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::of_device_id, data); + + fn index(&self) -> usize { + self.0.data as usize + } +} + +impl DeviceId { + /// Create a new device id from an OF 'compatible' string. + pub const fn new(compatible: &'static CStr) -> Self { + let src = compatible.as_bytes_with_nul(); + // Replace with `bindings::of_device_id::default()` once stabilized for `const`. + // SAFETY: FFI type is valid to be zero-initialized. + let mut of: bindings::of_device_id = unsafe { core::mem::zeroed() }; + + // TODO: Use `copy_from_slice` once stabilized for `const`. + let mut i = 0; + while i < src.len() { + of.compatible[i] = src[i]; + i += 1; + } + + Self(of) + } +} + +/// Create an OF `IdTable` with an "alias" for modpost. +#[macro_export] +macro_rules! of_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::of::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("of", $module_table_name, $table_name); + }; +} diff --git a/rust/kernel/opp.rs b/rust/kernel/opp.rs new file mode 100644 index 000000000000..08126035d2c6 --- /dev/null +++ b/rust/kernel/opp.rs @@ -0,0 +1,1146 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Operating performance points. +//! +//! This module provides rust abstractions for interacting with the OPP subsystem. +//! +//! C header: [`include/linux/pm_opp.h`](srctree/include/linux/pm_opp.h) +//! +//! Reference: <https://docs.kernel.org/power/opp.html> + +use crate::{ + clk::Hertz, + cpumask::{Cpumask, CpumaskVar}, + device::Device, + error::{code::*, from_err_ptr, from_result, to_result, Error, Result, VTABLE_DEFAULT_ERROR}, + ffi::c_ulong, + prelude::*, + str::CString, + types::{ARef, AlwaysRefCounted, Opaque}, +}; + +#[cfg(CONFIG_CPU_FREQ)] +/// Frequency table implementation. +mod freq { + use super::*; + use crate::cpufreq; + use core::ops::Deref; + + /// OPP frequency table. + /// + /// A [`cpufreq::Table`] created from [`Table`]. + pub struct FreqTable { + dev: ARef<Device>, + ptr: *mut bindings::cpufreq_frequency_table, + } + + impl FreqTable { + /// Creates a new instance of [`FreqTable`] from [`Table`]. + pub(crate) fn new(table: &Table) -> Result<Self> { + let mut ptr: *mut bindings::cpufreq_frequency_table = ptr::null_mut(); + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_init_cpufreq_table(table.dev.as_raw(), &mut ptr) + })?; + + Ok(Self { + dev: table.dev.clone(), + ptr, + }) + } + + /// Returns a reference to the underlying [`cpufreq::Table`]. + #[inline] + fn table(&self) -> &cpufreq::Table { + // SAFETY: The `ptr` is guaranteed by the C code to be valid. + unsafe { cpufreq::Table::from_raw(self.ptr) } + } + } + + impl Deref for FreqTable { + type Target = cpufreq::Table; + + #[inline] + fn deref(&self) -> &Self::Target { + self.table() + } + } + + impl Drop for FreqTable { + fn drop(&mut self) { + // SAFETY: The pointer was created via `dev_pm_opp_init_cpufreq_table`, and is only + // freed here. + unsafe { + bindings::dev_pm_opp_free_cpufreq_table(self.dev.as_raw(), &mut self.as_raw()) + }; + } + } +} + +#[cfg(CONFIG_CPU_FREQ)] +pub use freq::FreqTable; + +use core::{marker::PhantomData, ptr}; + +use macros::vtable; + +/// Creates a null-terminated slice of pointers to [`Cstring`]s. +fn to_c_str_array(names: &[CString]) -> Result<KVec<*const u8>> { + // Allocated a null-terminated vector of pointers. + let mut list = KVec::with_capacity(names.len() + 1, GFP_KERNEL)?; + + for name in names.iter() { + list.push(name.as_ptr().cast(), GFP_KERNEL)?; + } + + list.push(ptr::null(), GFP_KERNEL)?; + Ok(list) +} + +/// The voltage unit. +/// +/// Represents voltage in microvolts, wrapping a [`c_ulong`] value. +/// +/// # Examples +/// +/// ``` +/// use kernel::opp::MicroVolt; +/// +/// let raw = 90500; +/// let volt = MicroVolt(raw); +/// +/// assert_eq!(usize::from(volt), raw); +/// assert_eq!(volt, MicroVolt(raw)); +/// ``` +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct MicroVolt(pub c_ulong); + +impl From<MicroVolt> for c_ulong { + #[inline] + fn from(volt: MicroVolt) -> Self { + volt.0 + } +} + +/// The power unit. +/// +/// Represents power in microwatts, wrapping a [`c_ulong`] value. +/// +/// # Examples +/// +/// ``` +/// use kernel::opp::MicroWatt; +/// +/// let raw = 1000000; +/// let power = MicroWatt(raw); +/// +/// assert_eq!(usize::from(power), raw); +/// assert_eq!(power, MicroWatt(raw)); +/// ``` +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct MicroWatt(pub c_ulong); + +impl From<MicroWatt> for c_ulong { + #[inline] + fn from(power: MicroWatt) -> Self { + power.0 + } +} + +/// Handle for a dynamically created [`OPP`]. +/// +/// The associated [`OPP`] is automatically removed when the [`Token`] is dropped. +/// +/// # Examples +/// +/// The following example demonstrates how to create an [`OPP`] dynamically. +/// +/// ``` +/// use kernel::clk::Hertz; +/// use kernel::device::Device; +/// use kernel::error::Result; +/// use kernel::opp::{Data, MicroVolt, Token}; +/// use kernel::types::ARef; +/// +/// fn create_opp(dev: &ARef<Device>, freq: Hertz, volt: MicroVolt, level: u32) -> Result<Token> { +/// let data = Data::new(freq, volt, level, false); +/// +/// // OPP is removed once token goes out of scope. +/// data.add_opp(dev) +/// } +/// ``` +pub struct Token { + dev: ARef<Device>, + freq: Hertz, +} + +impl Token { + /// Dynamically adds an [`OPP`] and returns a [`Token`] that removes it on drop. + fn new(dev: &ARef<Device>, mut data: Data) -> Result<Self> { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_add_dynamic(dev.as_raw(), &mut data.0) })?; + Ok(Self { + dev: dev.clone(), + freq: data.freq(), + }) + } +} + +impl Drop for Token { + fn drop(&mut self) { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + unsafe { bindings::dev_pm_opp_remove(self.dev.as_raw(), self.freq.into()) }; + } +} + +/// OPP data. +/// +/// Rust abstraction for the C `struct dev_pm_opp_data`, used to define operating performance +/// points (OPPs) dynamically. +/// +/// # Examples +/// +/// The following example demonstrates how to create an [`OPP`] with [`Data`]. +/// +/// ``` +/// use kernel::clk::Hertz; +/// use kernel::device::Device; +/// use kernel::error::Result; +/// use kernel::opp::{Data, MicroVolt, Token}; +/// use kernel::types::ARef; +/// +/// fn create_opp(dev: &ARef<Device>, freq: Hertz, volt: MicroVolt, level: u32) -> Result<Token> { +/// let data = Data::new(freq, volt, level, false); +/// +/// // OPP is removed once token goes out of scope. +/// data.add_opp(dev) +/// } +/// ``` +#[repr(transparent)] +pub struct Data(bindings::dev_pm_opp_data); + +impl Data { + /// Creates a new instance of [`Data`]. + /// + /// This can be used to define a dynamic OPP to be added to a device. + pub fn new(freq: Hertz, volt: MicroVolt, level: u32, turbo: bool) -> Self { + Self(bindings::dev_pm_opp_data { + turbo, + freq: freq.into(), + u_volt: volt.into(), + level, + }) + } + + /// Adds an [`OPP`] dynamically. + /// + /// Returns a [`Token`] that ensures the OPP is automatically removed + /// when it goes out of scope. + #[inline] + pub fn add_opp(self, dev: &ARef<Device>) -> Result<Token> { + Token::new(dev, self) + } + + /// Returns the frequency associated with this OPP data. + #[inline] + fn freq(&self) -> Hertz { + Hertz(self.0.freq) + } +} + +/// [`OPP`] search options. +/// +/// # Examples +/// +/// Defines how to search for an [`OPP`] in a [`Table`] relative to a frequency. +/// +/// ``` +/// use kernel::clk::Hertz; +/// use kernel::error::Result; +/// use kernel::opp::{OPP, SearchType, Table}; +/// use kernel::types::ARef; +/// +/// fn find_opp(table: &Table, freq: Hertz) -> Result<ARef<OPP>> { +/// let opp = table.opp_from_freq(freq, Some(true), None, SearchType::Exact)?; +/// +/// pr_info!("OPP frequency is: {:?}\n", opp.freq(None)); +/// pr_info!("OPP voltage is: {:?}\n", opp.voltage()); +/// pr_info!("OPP level is: {}\n", opp.level()); +/// pr_info!("OPP power is: {:?}\n", opp.power()); +/// +/// Ok(opp) +/// } +/// ``` +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum SearchType { + /// Match the exact frequency. + Exact, + /// Find the highest frequency less than or equal to the given value. + Floor, + /// Find the lowest frequency greater than or equal to the given value. + Ceil, +} + +/// OPP configuration callbacks. +/// +/// Implement this trait to customize OPP clock and regulator setup for your device. +#[vtable] +pub trait ConfigOps { + /// This is typically used to scale clocks when transitioning between OPPs. + #[inline] + fn config_clks(_dev: &Device, _table: &Table, _opp: &OPP, _scaling_down: bool) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } + + /// This provides access to the old and new OPPs, allowing for safe regulator adjustments. + #[inline] + fn config_regulators( + _dev: &Device, + _opp_old: &OPP, + _opp_new: &OPP, + _data: *mut *mut bindings::regulator, + _count: u32, + ) -> Result { + build_error!(VTABLE_DEFAULT_ERROR) + } +} + +/// OPP configuration token. +/// +/// Returned by the OPP core when configuration is applied to a [`Device`]. The associated +/// configuration is automatically cleared when the token is dropped. +pub struct ConfigToken(i32); + +impl Drop for ConfigToken { + fn drop(&mut self) { + // SAFETY: This is the same token value returned by the C code via `dev_pm_opp_set_config`. + unsafe { bindings::dev_pm_opp_clear_config(self.0) }; + } +} + +/// OPP configurations. +/// +/// Rust abstraction for the C `struct dev_pm_opp_config`. +/// +/// # Examples +/// +/// The following example demonstrates how to set OPP property-name configuration for a [`Device`]. +/// +/// ``` +/// use kernel::device::Device; +/// use kernel::error::Result; +/// use kernel::opp::{Config, ConfigOps, ConfigToken}; +/// use kernel::str::CString; +/// use kernel::types::ARef; +/// use kernel::macros::vtable; +/// +/// #[derive(Default)] +/// struct Driver; +/// +/// #[vtable] +/// impl ConfigOps for Driver {} +/// +/// fn configure(dev: &ARef<Device>) -> Result<ConfigToken> { +/// let name = CString::try_from_fmt(fmt!("slow"))?; +/// +/// // The OPP configuration is cleared once the [`ConfigToken`] goes out of scope. +/// Config::<Driver>::new() +/// .set_prop_name(name)? +/// .set(dev) +/// } +/// ``` +#[derive(Default)] +pub struct Config<T: ConfigOps> +where + T: Default, +{ + clk_names: Option<KVec<CString>>, + prop_name: Option<CString>, + regulator_names: Option<KVec<CString>>, + supported_hw: Option<KVec<u32>>, + + // Tuple containing (required device, index) + required_dev: Option<(ARef<Device>, u32)>, + _data: PhantomData<T>, +} + +impl<T: ConfigOps + Default> Config<T> { + /// Creates a new instance of [`Config`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Initializes clock names. + pub fn set_clk_names(mut self, names: KVec<CString>) -> Result<Self> { + if self.clk_names.is_some() { + return Err(EBUSY); + } + + if names.is_empty() { + return Err(EINVAL); + } + + self.clk_names = Some(names); + Ok(self) + } + + /// Initializes property name. + pub fn set_prop_name(mut self, name: CString) -> Result<Self> { + if self.prop_name.is_some() { + return Err(EBUSY); + } + + self.prop_name = Some(name); + Ok(self) + } + + /// Initializes regulator names. + pub fn set_regulator_names(mut self, names: KVec<CString>) -> Result<Self> { + if self.regulator_names.is_some() { + return Err(EBUSY); + } + + if names.is_empty() { + return Err(EINVAL); + } + + self.regulator_names = Some(names); + + Ok(self) + } + + /// Initializes required devices. + pub fn set_required_dev(mut self, dev: ARef<Device>, index: u32) -> Result<Self> { + if self.required_dev.is_some() { + return Err(EBUSY); + } + + self.required_dev = Some((dev, index)); + Ok(self) + } + + /// Initializes supported hardware. + pub fn set_supported_hw(mut self, hw: KVec<u32>) -> Result<Self> { + if self.supported_hw.is_some() { + return Err(EBUSY); + } + + if hw.is_empty() { + return Err(EINVAL); + } + + self.supported_hw = Some(hw); + Ok(self) + } + + /// Sets the configuration with the OPP core. + /// + /// The returned [`ConfigToken`] will remove the configuration when dropped. + pub fn set(self, dev: &Device) -> Result<ConfigToken> { + let (_clk_list, clk_names) = match &self.clk_names { + Some(x) => { + let list = to_c_str_array(x)?; + let ptr = list.as_ptr(); + (Some(list), ptr) + } + None => (None, ptr::null()), + }; + + let (_regulator_list, regulator_names) = match &self.regulator_names { + Some(x) => { + let list = to_c_str_array(x)?; + let ptr = list.as_ptr(); + (Some(list), ptr) + } + None => (None, ptr::null()), + }; + + let prop_name = self + .prop_name + .as_ref() + .map_or(ptr::null(), |p| p.as_char_ptr()); + + let (supported_hw, supported_hw_count) = self + .supported_hw + .as_ref() + .map_or((ptr::null(), 0), |hw| (hw.as_ptr(), hw.len() as u32)); + + let (required_dev, required_dev_index) = self + .required_dev + .as_ref() + .map_or((ptr::null_mut(), 0), |(dev, idx)| (dev.as_raw(), *idx)); + + let mut config = bindings::dev_pm_opp_config { + clk_names, + config_clks: if T::HAS_CONFIG_CLKS { + Some(Self::config_clks) + } else { + None + }, + prop_name, + regulator_names, + config_regulators: if T::HAS_CONFIG_REGULATORS { + Some(Self::config_regulators) + } else { + None + }, + supported_hw, + supported_hw_count, + + required_dev, + required_dev_index, + }; + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The OPP core guarantees not to access fields of [`Config`] after this call + // and so we don't need to save a copy of them for future use. + let ret = unsafe { bindings::dev_pm_opp_set_config(dev.as_raw(), &mut config) }; + if ret < 0 { + Err(Error::from_errno(ret)) + } else { + Ok(ConfigToken(ret)) + } + } + + /// Config's clk callback. + /// + /// SAFETY: Called from C. Inputs must be valid pointers. + extern "C" fn config_clks( + dev: *mut bindings::device, + opp_table: *mut bindings::opp_table, + opp: *mut bindings::dev_pm_opp, + _data: *mut c_void, + scaling_down: bool, + ) -> c_int { + from_result(|| { + // SAFETY: 'dev' is guaranteed by the C code to be valid. + let dev = unsafe { Device::get_device(dev) }; + T::config_clks( + &dev, + // SAFETY: 'opp_table' is guaranteed by the C code to be valid. + &unsafe { Table::from_raw_table(opp_table, &dev) }, + // SAFETY: 'opp' is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp(opp)? }, + scaling_down, + ) + .map(|()| 0) + }) + } + + /// Config's regulator callback. + /// + /// SAFETY: Called from C. Inputs must be valid pointers. + extern "C" fn config_regulators( + dev: *mut bindings::device, + old_opp: *mut bindings::dev_pm_opp, + new_opp: *mut bindings::dev_pm_opp, + regulators: *mut *mut bindings::regulator, + count: c_uint, + ) -> c_int { + from_result(|| { + // SAFETY: 'dev' is guaranteed by the C code to be valid. + let dev = unsafe { Device::get_device(dev) }; + T::config_regulators( + &dev, + // SAFETY: 'old_opp' is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp(old_opp)? }, + // SAFETY: 'new_opp' is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp(new_opp)? }, + regulators, + count, + ) + .map(|()| 0) + }) + } +} + +/// A reference-counted OPP table. +/// +/// Rust abstraction for the C `struct opp_table`. +/// +/// # Invariants +/// +/// The pointer stored in `Self` is non-null and valid for the lifetime of the [`Table`]. +/// +/// Instances of this type are reference-counted. +/// +/// # Examples +/// +/// The following example demonstrates how to get OPP [`Table`] for a [`Cpumask`] and set its +/// frequency. +/// +/// ``` +/// # #![cfg(CONFIG_OF)] +/// use kernel::clk::Hertz; +/// use kernel::cpumask::Cpumask; +/// use kernel::device::Device; +/// use kernel::error::Result; +/// use kernel::opp::Table; +/// use kernel::types::ARef; +/// +/// fn get_table(dev: &ARef<Device>, mask: &mut Cpumask, freq: Hertz) -> Result<Table> { +/// let mut opp_table = Table::from_of_cpumask(dev, mask)?; +/// +/// if opp_table.opp_count()? == 0 { +/// return Err(EINVAL); +/// } +/// +/// pr_info!("Max transition latency is: {} ns\n", opp_table.max_transition_latency_ns()); +/// pr_info!("Suspend frequency is: {:?}\n", opp_table.suspend_freq()); +/// +/// opp_table.set_rate(freq)?; +/// Ok(opp_table) +/// } +/// ``` +pub struct Table { + ptr: *mut bindings::opp_table, + dev: ARef<Device>, + #[allow(dead_code)] + em: bool, + #[allow(dead_code)] + of: bool, + cpus: Option<CpumaskVar>, +} + +/// SAFETY: It is okay to send ownership of [`Table`] across thread boundaries. +unsafe impl Send for Table {} + +/// SAFETY: It is okay to access [`Table`] through shared references from other threads because +/// we're either accessing properties that don't change or that are properly synchronised by C code. +unsafe impl Sync for Table {} + +impl Table { + /// Creates a new reference-counted [`Table`] from a raw pointer. + /// + /// # Safety + /// + /// Callers must ensure that `ptr` is valid and non-null. + unsafe fn from_raw_table(ptr: *mut bindings::opp_table, dev: &ARef<Device>) -> Self { + // SAFETY: By the safety requirements, ptr is valid and its refcount will be incremented. + // + // INVARIANT: The reference-count is decremented when [`Table`] goes out of scope. + unsafe { bindings::dev_pm_opp_get_opp_table_ref(ptr) }; + + Self { + ptr, + dev: dev.clone(), + em: false, + of: false, + cpus: None, + } + } + + /// Creates a new reference-counted [`Table`] instance for a [`Device`]. + pub fn from_dev(dev: &Device) -> Result<Self> { + // SAFETY: The requirements are satisfied by the existence of the [`Device`] and its safety + // requirements. + // + // INVARIANT: The reference-count is incremented by the C code and is decremented when + // [`Table`] goes out of scope. + let ptr = from_err_ptr(unsafe { bindings::dev_pm_opp_get_opp_table(dev.as_raw()) })?; + + Ok(Self { + ptr, + dev: dev.into(), + em: false, + of: false, + cpus: None, + }) + } + + /// Creates a new reference-counted [`Table`] instance for a [`Device`] based on device tree + /// entries. + #[cfg(CONFIG_OF)] + pub fn from_of(dev: &ARef<Device>, index: i32) -> Result<Self> { + // SAFETY: The requirements are satisfied by the existence of the [`Device`] and its safety + // requirements. + // + // INVARIANT: The reference-count is incremented by the C code and is decremented when + // [`Table`] goes out of scope. + to_result(unsafe { bindings::dev_pm_opp_of_add_table_indexed(dev.as_raw(), index) })?; + + // Get the newly created [`Table`]. + let mut table = Self::from_dev(dev)?; + table.of = true; + + Ok(table) + } + + /// Remove device tree based [`Table`]. + #[cfg(CONFIG_OF)] + #[inline] + fn remove_of(&self) { + // SAFETY: The requirements are satisfied by the existence of the [`Device`] and its safety + // requirements. We took the reference from [`from_of`] earlier, it is safe to drop the + // same now. + unsafe { bindings::dev_pm_opp_of_remove_table(self.dev.as_raw()) }; + } + + /// Creates a new reference-counted [`Table`] instance for a [`Cpumask`] based on device tree + /// entries. + #[cfg(CONFIG_OF)] + pub fn from_of_cpumask(dev: &Device, cpumask: &mut Cpumask) -> Result<Self> { + // SAFETY: The cpumask is valid and the returned pointer will be owned by the [`Table`] + // instance. + // + // INVARIANT: The reference-count is incremented by the C code and is decremented when + // [`Table`] goes out of scope. + to_result(unsafe { bindings::dev_pm_opp_of_cpumask_add_table(cpumask.as_raw()) })?; + + // Fetch the newly created table. + let mut table = Self::from_dev(dev)?; + table.cpus = Some(CpumaskVar::try_clone(cpumask)?); + + Ok(table) + } + + /// Remove device tree based [`Table`] for a [`Cpumask`]. + #[cfg(CONFIG_OF)] + #[inline] + fn remove_of_cpumask(&self, cpumask: &Cpumask) { + // SAFETY: The cpumask is valid and we took the reference from [`from_of_cpumask`] earlier, + // it is safe to drop the same now. + unsafe { bindings::dev_pm_opp_of_cpumask_remove_table(cpumask.as_raw()) }; + } + + /// Returns the number of [`OPP`]s in the [`Table`]. + pub fn opp_count(&self) -> Result<u32> { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + let ret = unsafe { bindings::dev_pm_opp_get_opp_count(self.dev.as_raw()) }; + if ret < 0 { + Err(Error::from_errno(ret)) + } else { + Ok(ret as u32) + } + } + + /// Returns max clock latency (in nanoseconds) of the [`OPP`]s in the [`Table`]. + #[inline] + pub fn max_clock_latency_ns(&self) -> usize { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + unsafe { bindings::dev_pm_opp_get_max_clock_latency(self.dev.as_raw()) } + } + + /// Returns max volt latency (in nanoseconds) of the [`OPP`]s in the [`Table`]. + #[inline] + pub fn max_volt_latency_ns(&self) -> usize { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + unsafe { bindings::dev_pm_opp_get_max_volt_latency(self.dev.as_raw()) } + } + + /// Returns max transition latency (in nanoseconds) of the [`OPP`]s in the [`Table`]. + #[inline] + pub fn max_transition_latency_ns(&self) -> usize { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + unsafe { bindings::dev_pm_opp_get_max_transition_latency(self.dev.as_raw()) } + } + + /// Returns the suspend [`OPP`]'s frequency. + #[inline] + pub fn suspend_freq(&self) -> Hertz { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + Hertz(unsafe { bindings::dev_pm_opp_get_suspend_opp_freq(self.dev.as_raw()) }) + } + + /// Synchronizes regulators used by the [`Table`]. + #[inline] + pub fn sync_regulators(&self) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_sync_regulators(self.dev.as_raw()) }) + } + + /// Gets sharing CPUs. + #[inline] + pub fn sharing_cpus(dev: &Device, cpumask: &mut Cpumask) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_get_sharing_cpus(dev.as_raw(), cpumask.as_raw()) }) + } + + /// Sets sharing CPUs. + pub fn set_sharing_cpus(&mut self, cpumask: &mut Cpumask) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_set_sharing_cpus(self.dev.as_raw(), cpumask.as_raw()) + })?; + + if let Some(mask) = self.cpus.as_mut() { + // Update the cpumask as this will be used while removing the table. + cpumask.copy(mask); + } + + Ok(()) + } + + /// Gets sharing CPUs from device tree. + #[cfg(CONFIG_OF)] + #[inline] + pub fn of_sharing_cpus(dev: &Device, cpumask: &mut Cpumask) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_of_get_sharing_cpus(dev.as_raw(), cpumask.as_raw()) + }) + } + + /// Updates the voltage value for an [`OPP`]. + #[inline] + pub fn adjust_voltage( + &self, + freq: Hertz, + volt: MicroVolt, + volt_min: MicroVolt, + volt_max: MicroVolt, + ) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_adjust_voltage( + self.dev.as_raw(), + freq.into(), + volt.into(), + volt_min.into(), + volt_max.into(), + ) + }) + } + + /// Creates [`FreqTable`] from [`Table`]. + #[cfg(CONFIG_CPU_FREQ)] + #[inline] + pub fn cpufreq_table(&mut self) -> Result<FreqTable> { + FreqTable::new(self) + } + + /// Configures device with [`OPP`] matching the frequency value. + #[inline] + pub fn set_rate(&self, freq: Hertz) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_set_rate(self.dev.as_raw(), freq.into()) }) + } + + /// Configures device with [`OPP`]. + #[inline] + pub fn set_opp(&self, opp: &OPP) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_set_opp(self.dev.as_raw(), opp.as_raw()) }) + } + + /// Finds [`OPP`] based on frequency. + pub fn opp_from_freq( + &self, + freq: Hertz, + available: Option<bool>, + index: Option<u32>, + stype: SearchType, + ) -> Result<ARef<OPP>> { + let raw_dev = self.dev.as_raw(); + let index = index.unwrap_or(0); + let mut rate = freq.into(); + + let ptr = from_err_ptr(match stype { + SearchType::Exact => { + if let Some(available) = available { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and + // its safety requirements. The returned pointer will be owned by the new + // [`OPP`] instance. + unsafe { + bindings::dev_pm_opp_find_freq_exact_indexed( + raw_dev, rate, index, available, + ) + } + } else { + return Err(EINVAL); + } + } + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Ceil => unsafe { + bindings::dev_pm_opp_find_freq_ceil_indexed(raw_dev, &mut rate, index) + }, + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Floor => unsafe { + bindings::dev_pm_opp_find_freq_floor_indexed(raw_dev, &mut rate, index) + }, + })?; + + // SAFETY: The `ptr` is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp_owned(ptr) } + } + + /// Finds [`OPP`] based on level. + pub fn opp_from_level(&self, mut level: u32, stype: SearchType) -> Result<ARef<OPP>> { + let raw_dev = self.dev.as_raw(); + + let ptr = from_err_ptr(match stype { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Exact => unsafe { bindings::dev_pm_opp_find_level_exact(raw_dev, level) }, + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Ceil => unsafe { + bindings::dev_pm_opp_find_level_ceil(raw_dev, &mut level) + }, + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Floor => unsafe { + bindings::dev_pm_opp_find_level_floor(raw_dev, &mut level) + }, + })?; + + // SAFETY: The `ptr` is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp_owned(ptr) } + } + + /// Finds [`OPP`] based on bandwidth. + pub fn opp_from_bw(&self, mut bw: u32, index: i32, stype: SearchType) -> Result<ARef<OPP>> { + let raw_dev = self.dev.as_raw(); + + let ptr = from_err_ptr(match stype { + // The OPP core doesn't support this yet. + SearchType::Exact => return Err(EINVAL), + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Ceil => unsafe { + bindings::dev_pm_opp_find_bw_ceil(raw_dev, &mut bw, index) + }, + + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The returned pointer will be owned by the new [`OPP`] instance. + SearchType::Floor => unsafe { + bindings::dev_pm_opp_find_bw_floor(raw_dev, &mut bw, index) + }, + })?; + + // SAFETY: The `ptr` is guaranteed by the C code to be valid. + unsafe { OPP::from_raw_opp_owned(ptr) } + } + + /// Enables the [`OPP`]. + #[inline] + pub fn enable_opp(&self, freq: Hertz) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_enable(self.dev.as_raw(), freq.into()) }) + } + + /// Disables the [`OPP`]. + #[inline] + pub fn disable_opp(&self, freq: Hertz) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { bindings::dev_pm_opp_disable(self.dev.as_raw(), freq.into()) }) + } + + /// Registers with the Energy model. + #[cfg(CONFIG_OF)] + pub fn of_register_em(&mut self, cpumask: &mut Cpumask) -> Result { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. + to_result(unsafe { + bindings::dev_pm_opp_of_register_em(self.dev.as_raw(), cpumask.as_raw()) + })?; + + self.em = true; + Ok(()) + } + + /// Unregisters with the Energy model. + #[cfg(all(CONFIG_OF, CONFIG_ENERGY_MODEL))] + #[inline] + fn of_unregister_em(&self) { + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. We registered with the EM framework earlier, it is safe to unregister now. + unsafe { bindings::em_dev_unregister_perf_domain(self.dev.as_raw()) }; + } +} + +impl Drop for Table { + fn drop(&mut self) { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe + // to relinquish it now. + unsafe { bindings::dev_pm_opp_put_opp_table(self.ptr) }; + + #[cfg(CONFIG_OF)] + { + #[cfg(CONFIG_ENERGY_MODEL)] + if self.em { + self.of_unregister_em(); + } + + if self.of { + self.remove_of(); + } else if let Some(cpumask) = self.cpus.take() { + self.remove_of_cpumask(&cpumask); + } + } + } +} + +/// A reference-counted Operating performance point (OPP). +/// +/// Rust abstraction for the C `struct dev_pm_opp`. +/// +/// # Invariants +/// +/// The pointer stored in `Self` is non-null and valid for the lifetime of the [`OPP`]. +/// +/// Instances of this type are reference-counted. The reference count is incremented by the +/// `dev_pm_opp_get` function and decremented by `dev_pm_opp_put`. The Rust type `ARef<OPP>` +/// represents a pointer that owns a reference count on the [`OPP`]. +/// +/// A reference to the [`OPP`], &[`OPP`], isn't refcounted by the Rust code. +/// +/// # Examples +/// +/// The following example demonstrates how to get [`OPP`] corresponding to a frequency value and +/// configure the device with it. +/// +/// ``` +/// use kernel::clk::Hertz; +/// use kernel::error::Result; +/// use kernel::opp::{SearchType, Table}; +/// +/// fn configure_opp(table: &Table, freq: Hertz) -> Result { +/// let opp = table.opp_from_freq(freq, Some(true), None, SearchType::Exact)?; +/// +/// if opp.freq(None) != freq { +/// return Err(EINVAL); +/// } +/// +/// table.set_opp(&opp) +/// } +/// ``` +#[repr(transparent)] +pub struct OPP(Opaque<bindings::dev_pm_opp>); + +/// SAFETY: It is okay to send the ownership of [`OPP`] across thread boundaries. +unsafe impl Send for OPP {} + +/// SAFETY: It is okay to access [`OPP`] through shared references from other threads because we're +/// either accessing properties that don't change or that are properly synchronised by C code. +unsafe impl Sync for OPP {} + +/// SAFETY: The type invariants guarantee that [`OPP`] is always refcounted. +unsafe impl AlwaysRefCounted for OPP { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference means that the refcount is nonzero. + unsafe { bindings::dev_pm_opp_get(self.0.get()) }; + } + + unsafe fn dec_ref(obj: ptr::NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is nonzero. + unsafe { bindings::dev_pm_opp_put(obj.cast().as_ptr()) } + } +} + +impl OPP { + /// Creates an owned reference to a [`OPP`] from a valid pointer. + /// + /// The refcount is incremented by the C code and will be decremented by `dec_ref` when the + /// [`ARef`] object is dropped. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and the refcount of the [`OPP`] is incremented. + /// The caller must also ensure that it doesn't explicitly drop the refcount of the [`OPP`], as + /// the returned [`ARef`] object takes over the refcount increment on the underlying object and + /// the same will be dropped along with it. + pub unsafe fn from_raw_opp_owned(ptr: *mut bindings::dev_pm_opp) -> Result<ARef<Self>> { + let ptr = ptr::NonNull::new(ptr).ok_or(ENODEV)?; + + // SAFETY: The safety requirements guarantee the validity of the pointer. + // + // INVARIANT: The reference-count is decremented when [`OPP`] goes out of scope. + Ok(unsafe { ARef::from_raw(ptr.cast()) }) + } + + /// Creates a reference to a [`OPP`] from a valid pointer. + /// + /// The refcount is not updated by the Rust API unless the returned reference is converted to + /// an [`ARef`] object. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and remains valid for the duration of `'a`. + #[inline] + pub unsafe fn from_raw_opp<'a>(ptr: *mut bindings::dev_pm_opp) -> Result<&'a Self> { + // SAFETY: The caller guarantees that the pointer is not dangling and stays valid for the + // duration of 'a. The cast is okay because [`OPP`] is `repr(transparent)`. + Ok(unsafe { &*ptr.cast() }) + } + + #[inline] + fn as_raw(&self) -> *mut bindings::dev_pm_opp { + self.0.get() + } + + /// Returns the frequency of an [`OPP`]. + pub fn freq(&self, index: Option<u32>) -> Hertz { + let index = index.unwrap_or(0); + + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + Hertz(unsafe { bindings::dev_pm_opp_get_freq_indexed(self.as_raw(), index) }) + } + + /// Returns the voltage of an [`OPP`]. + #[inline] + pub fn voltage(&self) -> MicroVolt { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + MicroVolt(unsafe { bindings::dev_pm_opp_get_voltage(self.as_raw()) }) + } + + /// Returns the level of an [`OPP`]. + #[inline] + pub fn level(&self) -> u32 { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + unsafe { bindings::dev_pm_opp_get_level(self.as_raw()) } + } + + /// Returns the power of an [`OPP`]. + #[inline] + pub fn power(&self) -> MicroWatt { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + MicroWatt(unsafe { bindings::dev_pm_opp_get_power(self.as_raw()) }) + } + + /// Returns the required pstate of an [`OPP`]. + #[inline] + pub fn required_pstate(&self, index: u32) -> u32 { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + unsafe { bindings::dev_pm_opp_get_required_pstate(self.as_raw(), index) } + } + + /// Returns true if the [`OPP`] is turbo. + #[inline] + pub fn is_turbo(&self) -> bool { + // SAFETY: By the type invariants, we know that `self` owns a reference, so it is safe to + // use it. + unsafe { bindings::dev_pm_opp_is_turbo(self.as_raw()) } + } +} diff --git a/rust/kernel/page.rs b/rust/kernel/page.rs index fdac6c375fe4..7c1b17246ed5 100644 --- a/rust/kernel/page.rs +++ b/rust/kernel/page.rs @@ -57,9 +57,8 @@ impl Page { /// ``` /// use kernel::page::Page; /// - /// # fn dox() -> Result<(), kernel::alloc::AllocError> { /// let page = Page::alloc_page(GFP_KERNEL)?; - /// # Ok(()) } + /// # Ok::<(), kernel::alloc::AllocError>(()) /// ``` /// /// Allocate memory for a page and zero its contents. @@ -67,10 +66,10 @@ impl Page { /// ``` /// use kernel::page::Page; /// - /// # fn dox() -> Result<(), kernel::alloc::AllocError> { /// let page = Page::alloc_page(GFP_KERNEL | __GFP_ZERO)?; - /// # Ok(()) } + /// # Ok::<(), kernel::alloc::AllocError>(()) /// ``` + #[inline] pub fn alloc_page(flags: Flags) -> Result<Self, AllocError> { // SAFETY: Depending on the value of `gfp_flags`, this call may sleep. Other than that, it // is always safe to call this method. @@ -253,6 +252,7 @@ impl Page { } impl Drop for Page { + #[inline] fn drop(&mut self) { // SAFETY: By the type invariants, we have ownership of the page and can free it. unsafe { bindings::__free_pages(self.page.as_ptr(), 0) }; diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs new file mode 100644 index 000000000000..658e806a5da7 --- /dev/null +++ b/rust/kernel/pci.rs @@ -0,0 +1,506 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for the PCI bus. +//! +//! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h) + +use crate::{ + bindings, container_of, device, + device_id::{RawDeviceId, RawDeviceIdIndex}, + devres::Devres, + driver, + error::{from_result, to_result, Result}, + io::Io, + io::IoRaw, + str::CStr, + types::{ARef, Opaque}, + ThisModule, +}; +use core::{ + marker::PhantomData, + ops::Deref, + ptr::{addr_of_mut, NonNull}, +}; +use kernel::prelude::*; + +/// An adapter for the registration of PCI drivers. +pub struct Adapter<T: Driver>(T); + +// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// a preceding call to `register` has been successful. +unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { + type RegType = bindings::pci_driver; + + unsafe fn register( + pdrv: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + // SAFETY: It's safe to set the fields of `struct pci_driver` on initialization. + unsafe { + (*pdrv.get()).name = name.as_char_ptr(); + (*pdrv.get()).probe = Some(Self::probe_callback); + (*pdrv.get()).remove = Some(Self::remove_callback); + (*pdrv.get()).id_table = T::ID_TABLE.as_ptr(); + } + + // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + to_result(unsafe { + bindings::__pci_register_driver(pdrv.get(), module.0, name.as_char_ptr()) + }) + } + + unsafe fn unregister(pdrv: &Opaque<Self::RegType>) { + // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + unsafe { bindings::pci_unregister_driver(pdrv.get()) } + } +} + +impl<T: Driver + 'static> Adapter<T> { + extern "C" fn probe_callback( + pdev: *mut bindings::pci_dev, + id: *const bindings::pci_device_id, + ) -> kernel::ffi::c_int { + // SAFETY: The PCI bus only ever calls the probe callback with a valid pointer to a + // `struct pci_dev`. + // + // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + + // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct pci_device_id` and + // does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*id.cast::<DeviceId>() }; + let info = T::ID_TABLE.info(id.index()); + + from_result(|| { + let data = T::probe(pdev, info)?; + + pdev.as_ref().set_drvdata(data); + Ok(0) + }) + } + + extern "C" fn remove_callback(pdev: *mut bindings::pci_dev) { + // SAFETY: The PCI bus only ever calls the remove callback with a valid pointer to a + // `struct pci_dev`. + // + // INVARIANT: `pdev` is valid for the duration of `remove_callback()`. + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + + // SAFETY: `remove_callback` is only ever called after a successful call to + // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called + // and stored a `Pin<KBox<T>>`. + let data = unsafe { pdev.as_ref().drvdata_obtain::<Pin<KBox<T>>>() }; + + T::unbind(pdev, data.as_ref()); + } +} + +/// Declares a kernel module that exposes a single PCI driver. +/// +/// # Examples +/// +///```ignore +/// kernel::module_pci_driver! { +/// type: MyDriver, +/// name: "Module name", +/// authors: ["Author name"], +/// description: "Description", +/// license: "GPL v2", +/// } +///``` +#[macro_export] +macro_rules! module_pci_driver { +($($f:tt)*) => { + $crate::module_driver!(<T>, $crate::pci::Adapter<T>, { $($f)* }); +}; +} + +/// Abstraction for the PCI device ID structure ([`struct pci_device_id`]). +/// +/// [`struct pci_device_id`]: https://docs.kernel.org/PCI/pci.html#c.pci_device_id +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::pci_device_id); + +impl DeviceId { + const PCI_ANY_ID: u32 = !0; + + /// Equivalent to C's `PCI_DEVICE` macro. + /// + /// Create a new `pci::DeviceId` from a vendor and device ID number. + pub const fn from_id(vendor: u32, device: u32) -> Self { + Self(bindings::pci_device_id { + vendor, + device, + subvendor: DeviceId::PCI_ANY_ID, + subdevice: DeviceId::PCI_ANY_ID, + class: 0, + class_mask: 0, + driver_data: 0, + override_only: 0, + }) + } + + /// Equivalent to C's `PCI_DEVICE_CLASS` macro. + /// + /// Create a new `pci::DeviceId` from a class number and mask. + pub const fn from_class(class: u32, class_mask: u32) -> Self { + Self(bindings::pci_device_id { + vendor: DeviceId::PCI_ANY_ID, + device: DeviceId::PCI_ANY_ID, + subvendor: DeviceId::PCI_ANY_ID, + subdevice: DeviceId::PCI_ANY_ID, + class, + class_mask, + driver_data: 0, + override_only: 0, + }) + } +} + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `pci_device_id` and does not add +// additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::pci_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. +unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::pci_device_id, driver_data); + + fn index(&self) -> usize { + self.0.driver_data + } +} + +/// `IdTable` type for PCI. +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// Create a PCI `IdTable` with its alias for modpost. +#[macro_export] +macro_rules! pci_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::pci::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("pci", $module_table_name, $table_name); + }; +} + +/// The PCI driver trait. +/// +/// # Examples +/// +///``` +/// # use kernel::{bindings, device::Core, pci}; +/// +/// struct MyDriver; +/// +/// kernel::pci_device_table!( +/// PCI_TABLE, +/// MODULE_PCI_TABLE, +/// <MyDriver as pci::Driver>::IdInfo, +/// [ +/// ( +/// pci::DeviceId::from_id(bindings::PCI_VENDOR_ID_REDHAT, bindings::PCI_ANY_ID as u32), +/// (), +/// ) +/// ] +/// ); +/// +/// impl pci::Driver for MyDriver { +/// type IdInfo = (); +/// const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE; +/// +/// fn probe( +/// _pdev: &pci::Device<Core>, +/// _id_info: &Self::IdInfo, +/// ) -> Result<Pin<KBox<Self>>> { +/// Err(ENODEV) +/// } +/// } +///``` +/// Drivers must implement this trait in order to get a PCI driver registered. Please refer to the +/// `Adapter` documentation for an example. +pub trait Driver: Send { + /// The type holding information about each device id supported by the driver. + // TODO: Use `associated_type_defaults` once stabilized: + // + // ``` + // type IdInfo: 'static = (); + // ``` + type IdInfo: 'static; + + /// The table of device ids supported by the driver. + const ID_TABLE: IdTable<Self::IdInfo>; + + /// PCI driver probe. + /// + /// Called when a new pci device is added or discovered. Implementers should + /// attempt to initialize the device here. + fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>; + + /// PCI driver unbind. + /// + /// Called when a [`Device`] is unbound from its bound [`Driver`]. Implementing this callback + /// is optional. + /// + /// This callback serves as a place for drivers to perform teardown operations that require a + /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O + /// operations to gracefully tear down the device. + /// + /// Otherwise, release operations for driver resources should be performed in `Self::drop`. + fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + let _ = (dev, this); + } +} + +/// The PCI device representation. +/// +/// This structure represents the Rust abstraction for a C `struct pci_dev`. The implementation +/// abstracts the usage of an already existing C `struct pci_dev` within Rust code that we get +/// passed from the C side. +/// +/// # Invariants +/// +/// A [`Device`] instance represents a valid `struct pci_dev` created by the C portion of the +/// kernel. +#[repr(transparent)] +pub struct Device<Ctx: device::DeviceContext = device::Normal>( + Opaque<bindings::pci_dev>, + PhantomData<Ctx>, +); + +/// A PCI BAR to perform I/O-Operations on. +/// +/// # Invariants +/// +/// `Bar` always holds an `IoRaw` inststance that holds a valid pointer to the start of the I/O +/// memory mapped PCI bar and its size. +pub struct Bar<const SIZE: usize = 0> { + pdev: ARef<Device>, + io: IoRaw<SIZE>, + num: i32, +} + +impl<const SIZE: usize> Bar<SIZE> { + fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> { + let len = pdev.resource_len(num)?; + if len == 0 { + return Err(ENOMEM); + } + + // Convert to `i32`, since that's what all the C bindings use. + let num = i32::try_from(num)?; + + // SAFETY: + // `pdev` is valid by the invariants of `Device`. + // `num` is checked for validity by a previous call to `Device::resource_len`. + // `name` is always valid. + let ret = unsafe { bindings::pci_request_region(pdev.as_raw(), num, name.as_char_ptr()) }; + if ret != 0 { + return Err(EBUSY); + } + + // SAFETY: + // `pdev` is valid by the invariants of `Device`. + // `num` is checked for validity by a previous call to `Device::resource_len`. + // `name` is always valid. + let ioptr: usize = unsafe { bindings::pci_iomap(pdev.as_raw(), num, 0) } as usize; + if ioptr == 0 { + // SAFETY: + // `pdev` valid by the invariants of `Device`. + // `num` is checked for validity by a previous call to `Device::resource_len`. + unsafe { bindings::pci_release_region(pdev.as_raw(), num) }; + return Err(ENOMEM); + } + + let io = match IoRaw::new(ioptr, len as usize) { + Ok(io) => io, + Err(err) => { + // SAFETY: + // `pdev` is valid by the invariants of `Device`. + // `ioptr` is guaranteed to be the start of a valid I/O mapped memory region. + // `num` is checked for validity by a previous call to `Device::resource_len`. + unsafe { Self::do_release(pdev, ioptr, num) }; + return Err(err); + } + }; + + Ok(Bar { + pdev: pdev.into(), + io, + num, + }) + } + + /// # Safety + /// + /// `ioptr` must be a valid pointer to the memory mapped PCI bar number `num`. + unsafe fn do_release(pdev: &Device, ioptr: usize, num: i32) { + // SAFETY: + // `pdev` is valid by the invariants of `Device`. + // `ioptr` is valid by the safety requirements. + // `num` is valid by the safety requirements. + unsafe { + bindings::pci_iounmap(pdev.as_raw(), ioptr as *mut kernel::ffi::c_void); + bindings::pci_release_region(pdev.as_raw(), num); + } + } + + fn release(&self) { + // SAFETY: The safety requirements are guaranteed by the type invariant of `self.pdev`. + unsafe { Self::do_release(&self.pdev, self.io.addr(), self.num) }; + } +} + +impl Bar { + fn index_is_valid(index: u32) -> bool { + // A `struct pci_dev` owns an array of resources with at most `PCI_NUM_RESOURCES` entries. + index < bindings::PCI_NUM_RESOURCES + } +} + +impl<const SIZE: usize> Drop for Bar<SIZE> { + fn drop(&mut self) { + self.release(); + } +} + +impl<const SIZE: usize> Deref for Bar<SIZE> { + type Target = Io<SIZE>; + + fn deref(&self) -> &Self::Target { + // SAFETY: By the type invariant of `Self`, the MMIO range in `self.io` is properly mapped. + unsafe { Io::from_raw(&self.io) } + } +} + +impl<Ctx: device::DeviceContext> Device<Ctx> { + fn as_raw(&self) -> *mut bindings::pci_dev { + self.0.get() + } +} + +impl Device { + /// Returns the PCI vendor ID. + pub fn vendor_id(&self) -> u16 { + // SAFETY: `self.as_raw` is a valid pointer to a `struct pci_dev`. + unsafe { (*self.as_raw()).vendor } + } + + /// Returns the PCI device ID. + pub fn device_id(&self) -> u16 { + // SAFETY: `self.as_raw` is a valid pointer to a `struct pci_dev`. + unsafe { (*self.as_raw()).device } + } + + /// Returns the size of the given PCI bar resource. + pub fn resource_len(&self, bar: u32) -> Result<bindings::resource_size_t> { + if !Bar::index_is_valid(bar) { + return Err(EINVAL); + } + + // SAFETY: + // - `bar` is a valid bar number, as guaranteed by the above call to `Bar::index_is_valid`, + // - by its type invariant `self.as_raw` is always a valid pointer to a `struct pci_dev`. + Ok(unsafe { bindings::pci_resource_len(self.as_raw(), bar.try_into()?) }) + } +} + +impl Device<device::Bound> { + /// Mapps an entire PCI-BAR after performing a region-request on it. I/O operation bound checks + /// can be performed on compile time for offsets (plus the requested type size) < SIZE. + pub fn iomap_region_sized<'a, const SIZE: usize>( + &'a self, + bar: u32, + name: &'a CStr, + ) -> impl PinInit<Devres<Bar<SIZE>>, Error> + 'a { + Devres::new(self.as_ref(), Bar::<SIZE>::new(self, bar, name)) + } + + /// Mapps an entire PCI-BAR after performing a region-request on it. + pub fn iomap_region<'a>( + &'a self, + bar: u32, + name: &'a CStr, + ) -> impl PinInit<Devres<Bar>, Error> + 'a { + self.iomap_region_sized::<0>(bar, name) + } +} + +impl Device<device::Core> { + /// Enable memory resources for this device. + pub fn enable_device_mem(&self) -> Result { + // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. + to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) }) + } + + /// Enable bus-mastering for this device. + pub fn set_master(&self) { + // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. + unsafe { bindings::pci_set_master(self.as_raw()) }; + } +} + +// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic +// argument. +kernel::impl_device_context_deref!(unsafe { Device }); +kernel::impl_device_context_into_aref!(Device); + +impl crate::dma::Device for Device<device::Core> {} + +// SAFETY: Instances of `Device` are always reference-counted. +unsafe impl crate::types::AlwaysRefCounted for Device { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::pci_dev_get(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::pci_dev_put(obj.cast().as_ptr()) } + } +} + +impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> { + fn as_ref(&self) -> &device::Device<Ctx> { + // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid + // `struct pci_dev`. + let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) }; + + // SAFETY: `dev` points to a valid `struct device`. + unsafe { device::Device::from_raw(dev) } + } +} + +impl<Ctx: device::DeviceContext> TryFrom<&device::Device<Ctx>> for &Device<Ctx> { + type Error = kernel::error::Error; + + fn try_from(dev: &device::Device<Ctx>) -> Result<Self, Self::Error> { + // SAFETY: By the type invariant of `Device`, `dev.as_raw()` is a valid pointer to a + // `struct device`. + if !unsafe { bindings::dev_is_pci(dev.as_raw()) } { + return Err(EINVAL); + } + + // SAFETY: We've just verified that the bus type of `dev` equals `bindings::pci_bus_type`, + // hence `dev` must be embedded in a valid `struct pci_dev` as guaranteed by the + // corresponding C code. + let pdev = unsafe { container_of!(dev.as_raw(), bindings::pci_dev, dev) }; + + // SAFETY: `pdev` is a valid pointer to a `struct pci_dev`. + Ok(unsafe { &*pdev.cast() }) + } +} + +// SAFETY: A `Device` is always reference-counted and can be released from any thread. +unsafe impl Send for Device {} + +// SAFETY: `Device` can be shared among threads because all methods of `Device` +// (i.e. `Device<Normal>) are thread safe. +unsafe impl Sync for Device {} diff --git a/rust/kernel/platform.rs b/rust/kernel/platform.rs new file mode 100644 index 000000000000..8f028c76f9fa --- /dev/null +++ b/rust/kernel/platform.rs @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Abstractions for the platform bus. +//! +//! C header: [`include/linux/platform_device.h`](srctree/include/linux/platform_device.h) + +use crate::{ + acpi, bindings, container_of, + device::{self, Bound}, + driver, + error::{from_result, to_result, Result}, + io::{mem::IoRequest, Resource}, + of, + prelude::*, + types::Opaque, + ThisModule, +}; + +use core::{ + marker::PhantomData, + ptr::{addr_of_mut, NonNull}, +}; + +/// An adapter for the registration of platform drivers. +pub struct Adapter<T: Driver>(T); + +// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// a preceding call to `register` has been successful. +unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { + type RegType = bindings::platform_driver; + + unsafe fn register( + pdrv: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + let of_table = match T::OF_ID_TABLE { + Some(table) => table.as_ptr(), + None => core::ptr::null(), + }; + + let acpi_table = match T::ACPI_ID_TABLE { + Some(table) => table.as_ptr(), + None => core::ptr::null(), + }; + + // SAFETY: It's safe to set the fields of `struct platform_driver` on initialization. + unsafe { + (*pdrv.get()).driver.name = name.as_char_ptr(); + (*pdrv.get()).probe = Some(Self::probe_callback); + (*pdrv.get()).remove = Some(Self::remove_callback); + (*pdrv.get()).driver.of_match_table = of_table; + (*pdrv.get()).driver.acpi_match_table = acpi_table; + } + + // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + to_result(unsafe { bindings::__platform_driver_register(pdrv.get(), module.0) }) + } + + unsafe fn unregister(pdrv: &Opaque<Self::RegType>) { + // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + unsafe { bindings::platform_driver_unregister(pdrv.get()) }; + } +} + +impl<T: Driver + 'static> Adapter<T> { + extern "C" fn probe_callback(pdev: *mut bindings::platform_device) -> kernel::ffi::c_int { + // SAFETY: The platform bus only ever calls the probe callback with a valid pointer to a + // `struct platform_device`. + // + // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let info = <Self as driver::Adapter>::id_info(pdev.as_ref()); + + from_result(|| { + let data = T::probe(pdev, info)?; + + pdev.as_ref().set_drvdata(data); + Ok(0) + }) + } + + extern "C" fn remove_callback(pdev: *mut bindings::platform_device) { + // SAFETY: The platform bus only ever calls the remove callback with a valid pointer to a + // `struct platform_device`. + // + // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + + // SAFETY: `remove_callback` is only ever called after a successful call to + // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called + // and stored a `Pin<KBox<T>>`. + let data = unsafe { pdev.as_ref().drvdata_obtain::<Pin<KBox<T>>>() }; + + T::unbind(pdev, data.as_ref()); + } +} + +impl<T: Driver + 'static> driver::Adapter for Adapter<T> { + type IdInfo = T::IdInfo; + + fn of_id_table() -> Option<of::IdTable<Self::IdInfo>> { + T::OF_ID_TABLE + } + + fn acpi_id_table() -> Option<acpi::IdTable<Self::IdInfo>> { + T::ACPI_ID_TABLE + } +} + +/// Declares a kernel module that exposes a single platform driver. +/// +/// # Examples +/// +/// ```ignore +/// kernel::module_platform_driver! { +/// type: MyDriver, +/// name: "Module name", +/// authors: ["Author name"], +/// description: "Description", +/// license: "GPL v2", +/// } +/// ``` +#[macro_export] +macro_rules! module_platform_driver { + ($($f:tt)*) => { + $crate::module_driver!(<T>, $crate::platform::Adapter<T>, { $($f)* }); + }; +} + +/// The platform driver trait. +/// +/// Drivers must implement this trait in order to get a platform driver registered. +/// +/// # Examples +/// +///``` +/// # use kernel::{acpi, bindings, c_str, device::Core, of, platform}; +/// +/// struct MyDriver; +/// +/// kernel::of_device_table!( +/// OF_TABLE, +/// MODULE_OF_TABLE, +/// <MyDriver as platform::Driver>::IdInfo, +/// [ +/// (of::DeviceId::new(c_str!("test,device")), ()) +/// ] +/// ); +/// +/// kernel::acpi_device_table!( +/// ACPI_TABLE, +/// MODULE_ACPI_TABLE, +/// <MyDriver as platform::Driver>::IdInfo, +/// [ +/// (acpi::DeviceId::new(c_str!("LNUXBEEF")), ()) +/// ] +/// ); +/// +/// impl platform::Driver for MyDriver { +/// type IdInfo = (); +/// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = Some(&OF_TABLE); +/// const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = Some(&ACPI_TABLE); +/// +/// fn probe( +/// _pdev: &platform::Device<Core>, +/// _id_info: Option<&Self::IdInfo>, +/// ) -> Result<Pin<KBox<Self>>> { +/// Err(ENODEV) +/// } +/// } +///``` +pub trait Driver: Send { + /// The type holding driver private data about each device id supported by the driver. + // TODO: Use associated_type_defaults once stabilized: + // + // ``` + // type IdInfo: 'static = (); + // ``` + type IdInfo: 'static; + + /// The table of OF device ids supported by the driver. + const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; + + /// The table of ACPI device ids supported by the driver. + const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None; + + /// Platform driver probe. + /// + /// Called when a new platform device is added or discovered. + /// Implementers should attempt to initialize the device here. + fn probe(dev: &Device<device::Core>, id_info: Option<&Self::IdInfo>) + -> Result<Pin<KBox<Self>>>; + + /// Platform driver unbind. + /// + /// Called when a [`Device`] is unbound from its bound [`Driver`]. Implementing this callback + /// is optional. + /// + /// This callback serves as a place for drivers to perform teardown operations that require a + /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O + /// operations to gracefully tear down the device. + /// + /// Otherwise, release operations for driver resources should be performed in `Self::drop`. + fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + let _ = (dev, this); + } +} + +/// The platform device representation. +/// +/// This structure represents the Rust abstraction for a C `struct platform_device`. The +/// implementation abstracts the usage of an already existing C `struct platform_device` within Rust +/// code that we get passed from the C side. +/// +/// # Invariants +/// +/// A [`Device`] instance represents a valid `struct platform_device` created by the C portion of +/// the kernel. +#[repr(transparent)] +pub struct Device<Ctx: device::DeviceContext = device::Normal>( + Opaque<bindings::platform_device>, + PhantomData<Ctx>, +); + +impl<Ctx: device::DeviceContext> Device<Ctx> { + fn as_raw(&self) -> *mut bindings::platform_device { + self.0.get() + } + + /// Returns the resource at `index`, if any. + pub fn resource_by_index(&self, index: u32) -> Option<&Resource> { + // SAFETY: `self.as_raw()` returns a valid pointer to a `struct platform_device`. + let resource = unsafe { + bindings::platform_get_resource(self.as_raw(), bindings::IORESOURCE_MEM, index) + }; + + if resource.is_null() { + return None; + } + + // SAFETY: `resource` is a valid pointer to a `struct resource` as + // returned by `platform_get_resource`. + Some(unsafe { Resource::from_raw(resource) }) + } + + /// Returns the resource with a given `name`, if any. + pub fn resource_by_name(&self, name: &CStr) -> Option<&Resource> { + // SAFETY: `self.as_raw()` returns a valid pointer to a `struct + // platform_device` and `name` points to a valid C string. + let resource = unsafe { + bindings::platform_get_resource_byname( + self.as_raw(), + bindings::IORESOURCE_MEM, + name.as_char_ptr(), + ) + }; + + if resource.is_null() { + return None; + } + + // SAFETY: `resource` is a valid pointer to a `struct resource` as + // returned by `platform_get_resource`. + Some(unsafe { Resource::from_raw(resource) }) + } +} + +impl Device<Bound> { + /// Returns an `IoRequest` for the resource at `index`, if any. + pub fn io_request_by_index(&self, index: u32) -> Option<IoRequest<'_>> { + self.resource_by_index(index) + // SAFETY: `resource` is a valid resource for `&self` during the + // lifetime of the `IoRequest`. + .map(|resource| unsafe { IoRequest::new(self.as_ref(), resource) }) + } + + /// Returns an `IoRequest` for the resource with a given `name`, if any. + pub fn io_request_by_name(&self, name: &CStr) -> Option<IoRequest<'_>> { + self.resource_by_name(name) + // SAFETY: `resource` is a valid resource for `&self` during the + // lifetime of the `IoRequest`. + .map(|resource| unsafe { IoRequest::new(self.as_ref(), resource) }) + } +} + +// SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic +// argument. +kernel::impl_device_context_deref!(unsafe { Device }); +kernel::impl_device_context_into_aref!(Device); + +impl crate::dma::Device for Device<device::Core> {} + +// SAFETY: Instances of `Device` are always reference-counted. +unsafe impl crate::types::AlwaysRefCounted for Device { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::get_device(self.as_ref().as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::platform_device_put(obj.cast().as_ptr()) } + } +} + +impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for Device<Ctx> { + fn as_ref(&self) -> &device::Device<Ctx> { + // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid + // `struct platform_device`. + let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) }; + + // SAFETY: `dev` points to a valid `struct device`. + unsafe { device::Device::from_raw(dev) } + } +} + +impl<Ctx: device::DeviceContext> TryFrom<&device::Device<Ctx>> for &Device<Ctx> { + type Error = kernel::error::Error; + + fn try_from(dev: &device::Device<Ctx>) -> Result<Self, Self::Error> { + // SAFETY: By the type invariant of `Device`, `dev.as_raw()` is a valid pointer to a + // `struct device`. + if !unsafe { bindings::dev_is_platform(dev.as_raw()) } { + return Err(EINVAL); + } + + // SAFETY: We've just verified that the bus type of `dev` equals + // `bindings::platform_bus_type`, hence `dev` must be embedded in a valid + // `struct platform_device` as guaranteed by the corresponding C code. + let pdev = unsafe { container_of!(dev.as_raw(), bindings::platform_device, dev) }; + + // SAFETY: `pdev` is a valid pointer to a `struct platform_device`. + Ok(unsafe { &*pdev.cast() }) + } +} + +// SAFETY: A `Device` is always reference-counted and can be released from any thread. +unsafe impl Send for Device {} + +// SAFETY: `Device` can be shared among threads because all methods of `Device` +// (i.e. `Device<Normal>) are thread safe. +unsafe impl Sync for Device {} diff --git a/rust/kernel/prelude.rs b/rust/kernel/prelude.rs index 9ab4e0b6cbc9..25fe97aafd02 100644 --- a/rust/kernel/prelude.rs +++ b/rust/kernel/prelude.rs @@ -14,21 +14,28 @@ #[doc(no_inline)] pub use core::pin::Pin; +pub use ::ffi::{ + c_char, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, c_ulonglong, + c_ushort, c_void, +}; + pub use crate::alloc::{flags::*, Box, KBox, KVBox, KVVec, KVec, VBox, VVec, Vec}; #[doc(no_inline)] -pub use macros::{module, pin_data, pinned_drop, vtable, Zeroable}; +pub use macros::{export, kunit_tests, module, vtable}; + +pub use pin_init::{init, pin_data, pin_init, pinned_drop, InPlaceWrite, Init, PinInit, Zeroable}; -pub use super::build_assert; +pub use super::{build_assert, build_error}; // `super::std_vendor` is hidden, which makes the macro inline for some reason. #[doc(no_inline)] pub use super::dbg; -pub use super::fmt; pub use super::{dev_alert, dev_crit, dev_dbg, dev_emerg, dev_err, dev_info, dev_notice, dev_warn}; pub use super::{pr_alert, pr_crit, pr_debug, pr_emerg, pr_err, pr_info, pr_notice, pr_warn}; +pub use core::format_args as fmt; -pub use super::{init, pin_init, try_init, try_pin_init}; +pub use super::{try_init, try_pin_init}; pub use super::static_assert; @@ -36,6 +43,8 @@ pub use super::error::{code::*, Error, Result}; pub use super::{str::CStr, ThisModule}; -pub use super::init::{InPlaceInit, InPlaceWrite, Init, PinInit}; +pub use super::init::InPlaceInit; pub use super::current; + +pub use super::uaccess::UserPtr; diff --git a/rust/kernel/print.rs b/rust/kernel/print.rs index b19ee490be58..2d743d78d220 100644 --- a/rust/kernel/print.rs +++ b/rust/kernel/print.rs @@ -6,16 +6,16 @@ //! //! Reference: <https://docs.kernel.org/core-api/printk-basics.html> -use core::{ +use crate::{ ffi::{c_char, c_void}, fmt, + prelude::*, + str::RawFormatter, }; -use crate::str::RawFormatter; - // Called from `vsprintf` with format specifier `%pA`. #[expect(clippy::missing_safety_doc)] -#[no_mangle] +#[export] unsafe extern "C" fn rust_fmt_argument( buf: *mut c_char, end: *mut c_char, @@ -25,7 +25,7 @@ unsafe extern "C" fn rust_fmt_argument( // SAFETY: The C contract guarantees that `buf` is valid if it's less than `end`. let mut w = unsafe { RawFormatter::from_ptrs(buf.cast(), end.cast()) }; // SAFETY: TODO. - let _ = w.write_fmt(unsafe { *(ptr as *const fmt::Arguments<'_>) }); + let _ = w.write_fmt(unsafe { *ptr.cast::<fmt::Arguments<'_>>() }); w.pos().cast() } @@ -109,7 +109,7 @@ pub unsafe fn call_printk( bindings::_printk( format_string.as_ptr(), module_name.as_ptr(), - &args as *const _ as *const c_void, + core::ptr::from_ref(&args).cast::<c_void>(), ); } } @@ -129,7 +129,7 @@ pub fn call_printk_cont(args: fmt::Arguments<'_>) { unsafe { bindings::_printk( format_strings::CONT.as_ptr(), - &args as *const _ as *const c_void, + core::ptr::from_ref(&args).cast::<c_void>(), ); } } @@ -149,7 +149,7 @@ macro_rules! print_macro ( // takes borrows on the arguments, but does not extend the scope of temporaries. // Therefore, a `match` expression is used to keep them around, since // the scrutinee is kept until the end of the `match`. - match format_args!($($arg)+) { + match $crate::prelude::fmt!($($arg)+) { // SAFETY: This hidden macro should only be called by the documented // printing macros which ensure the format string is one of the fixed // ones. All `__LOG_PREFIX`s are null-terminated as they are generated @@ -168,7 +168,7 @@ macro_rules! print_macro ( // The `CONT` case. ($format_string:path, true, $($arg:tt)+) => ( $crate::print::call_printk_cont( - format_args!($($arg)+), + $crate::prelude::fmt!($($arg)+), ); ); ); @@ -198,10 +198,11 @@ macro_rules! print_macro ( /// Equivalent to the kernel's [`pr_emerg`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_emerg`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_emerg /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -222,10 +223,11 @@ macro_rules! pr_emerg ( /// Equivalent to the kernel's [`pr_alert`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_alert`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_alert /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -246,10 +248,11 @@ macro_rules! pr_alert ( /// Equivalent to the kernel's [`pr_crit`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_crit`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_crit /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -270,10 +273,11 @@ macro_rules! pr_crit ( /// Equivalent to the kernel's [`pr_err`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_err`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_err /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -294,10 +298,11 @@ macro_rules! pr_err ( /// Equivalent to the kernel's [`pr_warn`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_warn`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_warn /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -318,10 +323,11 @@ macro_rules! pr_warn ( /// Equivalent to the kernel's [`pr_notice`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_notice`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_notice /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -342,10 +348,11 @@ macro_rules! pr_notice ( /// Equivalent to the kernel's [`pr_info`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_info`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_info /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -368,10 +375,11 @@ macro_rules! pr_info ( /// yet. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_debug`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_debug /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// @@ -395,11 +403,12 @@ macro_rules! pr_debug ( /// Equivalent to the kernel's [`pr_cont`] macro. /// /// Mimics the interface of [`std::print!`]. See [`core::fmt`] and -/// `alloc::format!` for information about the formatting syntax. +/// [`std::format!`] for information about the formatting syntax. /// /// [`pr_info!`]: crate::pr_info! /// [`pr_cont`]: https://docs.kernel.org/core-api/printk-basics.html#c.pr_cont /// [`std::print!`]: https://doc.rust-lang.org/std/macro.print.html +/// [`std::format!`]: https://doc.rust-lang.org/std/macro.format.html /// /// # Examples /// diff --git a/rust/kernel/rbtree.rs b/rust/kernel/rbtree.rs index 571e27efe544..b8fe6be6fcc4 100644 --- a/rust/kernel/rbtree.rs +++ b/rust/kernel/rbtree.rs @@ -36,17 +36,17 @@ use core::{ /// /// // Check the nodes we just inserted. /// { -/// assert_eq!(tree.get(&10).unwrap(), &100); -/// assert_eq!(tree.get(&20).unwrap(), &200); -/// assert_eq!(tree.get(&30).unwrap(), &300); +/// assert_eq!(tree.get(&10), Some(&100)); +/// assert_eq!(tree.get(&20), Some(&200)); +/// assert_eq!(tree.get(&30), Some(&300)); /// } /// /// // Iterate over the nodes we just inserted. /// { /// let mut iter = tree.iter(); -/// assert_eq!(iter.next().unwrap(), (&10, &100)); -/// assert_eq!(iter.next().unwrap(), (&20, &200)); -/// assert_eq!(iter.next().unwrap(), (&30, &300)); +/// assert_eq!(iter.next(), Some((&10, &100))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &300))); /// assert!(iter.next().is_none()); /// } /// @@ -61,9 +61,9 @@ use core::{ /// // Check that the tree reflects the replacement. /// { /// let mut iter = tree.iter(); -/// assert_eq!(iter.next().unwrap(), (&10, &1000)); -/// assert_eq!(iter.next().unwrap(), (&20, &200)); -/// assert_eq!(iter.next().unwrap(), (&30, &300)); +/// assert_eq!(iter.next(), Some((&10, &1000))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &300))); /// assert!(iter.next().is_none()); /// } /// @@ -73,9 +73,9 @@ use core::{ /// // Check that the tree reflects the update. /// { /// let mut iter = tree.iter(); -/// assert_eq!(iter.next().unwrap(), (&10, &1000)); -/// assert_eq!(iter.next().unwrap(), (&20, &200)); -/// assert_eq!(iter.next().unwrap(), (&30, &3000)); +/// assert_eq!(iter.next(), Some((&10, &1000))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &3000))); /// assert!(iter.next().is_none()); /// } /// @@ -85,8 +85,8 @@ use core::{ /// // Check that the tree reflects the removal. /// { /// let mut iter = tree.iter(); -/// assert_eq!(iter.next().unwrap(), (&20, &200)); -/// assert_eq!(iter.next().unwrap(), (&30, &3000)); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &3000))); /// assert!(iter.next().is_none()); /// } /// @@ -128,20 +128,20 @@ use core::{ /// // Check the nodes we just inserted. /// { /// let mut iter = tree.iter(); -/// assert_eq!(iter.next().unwrap(), (&10, &100)); -/// assert_eq!(iter.next().unwrap(), (&20, &200)); -/// assert_eq!(iter.next().unwrap(), (&30, &300)); +/// assert_eq!(iter.next(), Some((&10, &100))); +/// assert_eq!(iter.next(), Some((&20, &200))); +/// assert_eq!(iter.next(), Some((&30, &300))); /// assert!(iter.next().is_none()); /// } /// /// // Remove a node, getting back ownership of it. -/// let existing = tree.remove(&30).unwrap(); +/// let existing = tree.remove(&30); /// /// // Check that the tree reflects the removal. /// { /// let mut iter = tree.iter(); -/// assert_eq!(iter.next().unwrap(), (&10, &100)); -/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert_eq!(iter.next(), Some((&10, &100))); +/// assert_eq!(iter.next(), Some((&20, &200))); /// assert!(iter.next().is_none()); /// } /// @@ -155,9 +155,9 @@ use core::{ /// // Check that the tree reflect the new insertion. /// { /// let mut iter = tree.iter(); -/// assert_eq!(iter.next().unwrap(), (&10, &100)); -/// assert_eq!(iter.next().unwrap(), (&15, &150)); -/// assert_eq!(iter.next().unwrap(), (&20, &200)); +/// assert_eq!(iter.next(), Some((&10, &100))); +/// assert_eq!(iter.next(), Some((&15, &150))); +/// assert_eq!(iter.next(), Some((&20, &200))); /// assert!(iter.next().is_none()); /// } /// @@ -191,6 +191,12 @@ impl<K, V> RBTree<K, V> { } } + /// Returns true if this tree is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.root.rb_node.is_null() + } + /// Returns an iterator over the tree nodes, sorted by key. pub fn iter(&self) -> Iter<'_, K, V> { Iter { @@ -424,7 +430,7 @@ where while !node.is_null() { // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` // point to the links field of `Node<K, V>` objects. - let this = unsafe { container_of!(node, Node<K, V>, links) }.cast_mut(); + let this = unsafe { container_of!(node, Node<K, V>, links) }; // SAFETY: `this` is a non-null node so it is valid by the type invariants. let this_key = unsafe { &(*this).key }; // SAFETY: `node` is a non-null node so it is valid by the type invariants. @@ -496,7 +502,7 @@ impl<K, V> Drop for RBTree<K, V> { // but it is not observable. The loop invariant is still maintained. // SAFETY: `this` is valid per the loop invariant. - unsafe { drop(KBox::from_raw(this.cast_mut())) }; + unsafe { drop(KBox::from_raw(this)) }; } } } @@ -761,7 +767,7 @@ impl<'a, K, V> Cursor<'a, K, V> { let next = self.get_neighbor_raw(Direction::Next); // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` // point to the links field of `Node<K, V>` objects. - let this = unsafe { container_of!(self.current.as_ptr(), Node<K, V>, links) }.cast_mut(); + let this = unsafe { container_of!(self.current.as_ptr(), Node<K, V>, links) }; // SAFETY: `this` is valid by the type invariants as described above. let node = unsafe { KBox::from_raw(this) }; let node = RBTreeNode { node }; @@ -769,23 +775,14 @@ impl<'a, K, V> Cursor<'a, K, V> { // the tree cannot change. By the tree invariant, all nodes are valid. unsafe { bindings::rb_erase(&mut (*this).links, addr_of_mut!(self.tree.root)) }; - let current = match (prev, next) { - (_, Some(next)) => next, - (Some(prev), None) => prev, - (None, None) => { - return (None, node); - } - }; + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self.tree`. + let cursor = next.or(prev).map(|current| Self { + current, + tree: self.tree, + }); - ( - // INVARIANT: - // - `current` is a valid node in the [`RBTree`] pointed to by `self.tree`. - Some(Self { - current, - tree: self.tree, - }), - node, - ) + (cursor, node) } /// Remove the previous node, returning it if it exists. @@ -806,7 +803,7 @@ impl<'a, K, V> Cursor<'a, K, V> { unsafe { bindings::rb_erase(neighbor, addr_of_mut!(self.tree.root)) }; // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` // point to the links field of `Node<K, V>` objects. - let this = unsafe { container_of!(neighbor, Node<K, V>, links) }.cast_mut(); + let this = unsafe { container_of!(neighbor, Node<K, V>, links) }; // SAFETY: `this` is valid by the type invariants as described above. let node = unsafe { KBox::from_raw(this) }; return Some(RBTreeNode { node }); @@ -886,7 +883,7 @@ impl<'a, K, V> Cursor<'a, K, V> { /// # Safety /// /// - `node` must be a valid pointer to a node in an [`RBTree`]. - /// - The caller has immutable access to `node` for the duration of 'b. + /// - The caller has immutable access to `node` for the duration of `'b`. unsafe fn to_key_value<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b V) { // SAFETY: the caller guarantees that `node` is a valid pointer in an `RBTree`. let (k, v) = unsafe { Self::to_key_value_raw(node) }; @@ -897,7 +894,7 @@ impl<'a, K, V> Cursor<'a, K, V> { /// # Safety /// /// - `node` must be a valid pointer to a node in an [`RBTree`]. - /// - The caller has mutable access to `node` for the duration of 'b. + /// - The caller has mutable access to `node` for the duration of `'b`. unsafe fn to_key_value_mut<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b mut V) { // SAFETY: the caller guarantees that `node` is a valid pointer in an `RBTree`. let (k, v) = unsafe { Self::to_key_value_raw(node) }; @@ -908,11 +905,11 @@ impl<'a, K, V> Cursor<'a, K, V> { /// # Safety /// /// - `node` must be a valid pointer to a node in an [`RBTree`]. - /// - The caller has immutable access to the key for the duration of 'b. + /// - The caller has immutable access to the key for the duration of `'b`. unsafe fn to_key_value_raw<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, *mut V) { // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` // point to the links field of `Node<K, V>` objects. - let this = unsafe { container_of!(node.as_ptr(), Node<K, V>, links) }.cast_mut(); + let this = unsafe { container_of!(node.as_ptr(), Node<K, V>, links) }; // SAFETY: The passed `node` is the current node or a non-null neighbor, // thus `this` is valid by the type invariants. let k = unsafe { &(*this).key }; @@ -1021,7 +1018,7 @@ impl<K, V> Iterator for IterRaw<K, V> { // SAFETY: By the type invariant of `IterRaw`, `self.next` is a valid node in an `RBTree`, // and by the type invariant of `RBTree`, all nodes point to the links field of `Node<K, V>` objects. - let cur = unsafe { container_of!(self.next, Node<K, V>, links) }.cast_mut(); + let cur = unsafe { container_of!(self.next, Node<K, V>, links) }; // SAFETY: `self.next` is a valid tree node by the type invariants. self.next = unsafe { bindings::rb_next(self.next) }; @@ -1168,12 +1165,12 @@ impl<'a, K, V> RawVacantEntry<'a, K, V> { fn insert(self, node: RBTreeNode<K, V>) -> &'a mut V { let node = KBox::into_raw(node.node); - // SAFETY: `node` is valid at least until we call `Box::from_raw`, which only happens when + // SAFETY: `node` is valid at least until we call `KBox::from_raw`, which only happens when // the node is removed or replaced. let node_links = unsafe { addr_of_mut!((*node).links) }; // INVARIANT: We are linking in a new node, which is valid. It remains valid because we - // "forgot" it with `Box::into_raw`. + // "forgot" it with `KBox::into_raw`. // SAFETY: The type invariants of `RawVacantEntry` are exactly the safety requirements of `rb_link_node`. unsafe { bindings::rb_link_node(node_links, self.parent, self.child_field_of_parent) }; @@ -1216,7 +1213,7 @@ impl<'a, K, V> OccupiedEntry<'a, K, V> { // SAFETY: // - `self.node_links` is a valid pointer to a node in the tree. // - We have exclusive access to the underlying tree, and can thus give out a mutable reference. - unsafe { &mut (*(container_of!(self.node_links, Node<K, V>, links).cast_mut())).value } + unsafe { &mut (*(container_of!(self.node_links, Node<K, V>, links))).value } } /// Converts the entry into a mutable reference to its value. @@ -1226,7 +1223,7 @@ impl<'a, K, V> OccupiedEntry<'a, K, V> { // SAFETY: // - `self.node_links` is a valid pointer to a node in the tree. // - This consumes the `&'a mut RBTree<K, V>`, therefore it can give out a mutable reference that lives for `'a`. - unsafe { &mut (*(container_of!(self.node_links, Node<K, V>, links).cast_mut())).value } + unsafe { &mut (*(container_of!(self.node_links, Node<K, V>, links))).value } } /// Remove this entry from the [`RBTree`]. @@ -1239,9 +1236,7 @@ impl<'a, K, V> OccupiedEntry<'a, K, V> { RBTreeNode { // SAFETY: The node was a node in the tree, but we removed it, so we can convert it // back into a box. - node: unsafe { - KBox::from_raw(container_of!(self.node_links, Node<K, V>, links).cast_mut()) - }, + node: unsafe { KBox::from_raw(container_of!(self.node_links, Node<K, V>, links)) }, } } @@ -1259,7 +1254,7 @@ impl<'a, K, V> OccupiedEntry<'a, K, V> { fn replace(self, node: RBTreeNode<K, V>) -> RBTreeNode<K, V> { let node = KBox::into_raw(node.node); - // SAFETY: `node` is valid at least until we call `Box::from_raw`, which only happens when + // SAFETY: `node` is valid at least until we call `KBox::from_raw`, which only happens when // the node is removed or replaced. let new_node_links = unsafe { addr_of_mut!((*node).links) }; @@ -1272,8 +1267,7 @@ impl<'a, K, V> OccupiedEntry<'a, K, V> { // SAFETY: // - `self.node_ptr` produces a valid pointer to a node in the tree. // - Now that we removed this entry from the tree, we can convert the node to a box. - let old_node = - unsafe { KBox::from_raw(container_of!(self.node_links, Node<K, V>, links).cast_mut()) }; + let old_node = unsafe { KBox::from_raw(container_of!(self.node_links, Node<K, V>, links)) }; RBTreeNode { node: old_node } } diff --git a/rust/kernel/regulator.rs b/rust/kernel/regulator.rs new file mode 100644 index 000000000000..65f3a125348f --- /dev/null +++ b/rust/kernel/regulator.rs @@ -0,0 +1,418 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Regulator abstractions, providing a standard kernel interface to control +//! voltage and current regulators. +//! +//! The intention is to allow systems to dynamically control regulator power +//! output in order to save power and prolong battery life. This applies to both +//! voltage regulators (where voltage output is controllable) and current sinks +//! (where current limit is controllable). +//! +//! C header: [`include/linux/regulator/consumer.h`](srctree/include/linux/regulator/consumer.h) +//! +//! Regulators are modeled in Rust with a collection of states. Each state may +//! enforce a given invariant, and they may convert between each other where applicable. +//! +//! See [Voltage and current regulator API](https://docs.kernel.org/driver-api/regulator.html) +//! for more information. + +use crate::{ + bindings, + device::Device, + error::{from_err_ptr, to_result, Result}, + prelude::*, +}; + +use core::{marker::PhantomData, mem::ManuallyDrop, ptr::NonNull}; + +mod private { + pub trait Sealed {} + + impl Sealed for super::Enabled {} + impl Sealed for super::Disabled {} + impl Sealed for super::Dynamic {} +} + +/// A trait representing the different states a [`Regulator`] can be in. +pub trait RegulatorState: private::Sealed + 'static { + /// Whether the regulator should be disabled when dropped. + const DISABLE_ON_DROP: bool; +} + +/// A state where the [`Regulator`] is known to be enabled. +/// +/// The `enable` reference count held by this state is decremented when it is +/// dropped. +pub struct Enabled; + +/// A state where this [`Regulator`] handle has not specifically asked for the +/// underlying regulator to be enabled. This means that this reference does not +/// own an `enable` reference count, but the regulator may still be on. +pub struct Disabled; + +/// A state that models the C API. The [`Regulator`] can be either enabled or +/// disabled, and the user is in control of the reference count. This is also +/// the default state. +/// +/// Use [`Regulator::is_enabled`] to check the regulator's current state. +pub struct Dynamic; + +impl RegulatorState for Enabled { + const DISABLE_ON_DROP: bool = true; +} + +impl RegulatorState for Disabled { + const DISABLE_ON_DROP: bool = false; +} + +impl RegulatorState for Dynamic { + const DISABLE_ON_DROP: bool = false; +} + +/// A trait that abstracts the ability to check if a [`Regulator`] is enabled. +pub trait IsEnabled: RegulatorState {} +impl IsEnabled for Disabled {} +impl IsEnabled for Dynamic {} + +/// An error that can occur when trying to convert a [`Regulator`] between states. +pub struct Error<State: RegulatorState> { + /// The error that occurred. + pub error: kernel::error::Error, + + /// The regulator that caused the error, so that the operation may be retried. + pub regulator: Regulator<State>, +} + +/// A `struct regulator` abstraction. +/// +/// # Examples +/// +/// ## Enabling a regulator +/// +/// This example uses [`Regulator<Enabled>`], which is suitable for drivers that +/// enable a regulator at probe time and leave them on until the device is +/// removed or otherwise shutdown. +/// +/// These users can store [`Regulator<Enabled>`] directly in their driver's +/// private data struct. +/// +/// ``` +/// # use kernel::prelude::*; +/// # use kernel::c_str; +/// # use kernel::device::Device; +/// # use kernel::regulator::{Voltage, Regulator, Disabled, Enabled}; +/// fn enable(dev: &Device, min_voltage: Voltage, max_voltage: Voltage) -> Result { +/// // Obtain a reference to a (fictitious) regulator. +/// let regulator: Regulator<Disabled> = Regulator::<Disabled>::get(dev, c_str!("vcc"))?; +/// +/// // The voltage can be set before enabling the regulator if needed, e.g.: +/// regulator.set_voltage(min_voltage, max_voltage)?; +/// +/// // The same applies for `get_voltage()`, i.e.: +/// let voltage: Voltage = regulator.get_voltage()?; +/// +/// // Enables the regulator, consuming the previous value. +/// // +/// // From now on, the regulator is known to be enabled because of the type +/// // `Enabled`. +/// // +/// // If this operation fails, the `Error` will contain the regulator +/// // reference, so that the operation may be retried. +/// let regulator: Regulator<Enabled> = +/// regulator.try_into_enabled().map_err(|error| error.error)?; +/// +/// // The voltage can also be set after enabling the regulator, e.g.: +/// regulator.set_voltage(min_voltage, max_voltage)?; +/// +/// // The same applies for `get_voltage()`, i.e.: +/// let voltage: Voltage = regulator.get_voltage()?; +/// +/// // Dropping an enabled regulator will disable it. The refcount will be +/// // decremented. +/// drop(regulator); +/// +/// // ... +/// +/// Ok(()) +/// } +/// ``` +/// +/// A more concise shortcut is available for enabling a regulator. This is +/// equivalent to `regulator_get_enable()`: +/// +/// ``` +/// # use kernel::prelude::*; +/// # use kernel::c_str; +/// # use kernel::device::Device; +/// # use kernel::regulator::{Voltage, Regulator, Enabled}; +/// fn enable(dev: &Device) -> Result { +/// // Obtain a reference to a (fictitious) regulator and enable it. +/// let regulator: Regulator<Enabled> = Regulator::<Enabled>::get(dev, c_str!("vcc"))?; +/// +/// // Dropping an enabled regulator will disable it. The refcount will be +/// // decremented. +/// drop(regulator); +/// +/// // ... +/// +/// Ok(()) +/// } +/// ``` +/// +/// ## Disabling a regulator +/// +/// ``` +/// # use kernel::prelude::*; +/// # use kernel::device::Device; +/// # use kernel::regulator::{Regulator, Enabled, Disabled}; +/// fn disable(dev: &Device, regulator: Regulator<Enabled>) -> Result { +/// // We can also disable an enabled regulator without reliquinshing our +/// // refcount: +/// // +/// // If this operation fails, the `Error` will contain the regulator +/// // reference, so that the operation may be retried. +/// let regulator: Regulator<Disabled> = +/// regulator.try_into_disabled().map_err(|error| error.error)?; +/// +/// // The refcount will be decremented when `regulator` is dropped. +/// drop(regulator); +/// +/// // ... +/// +/// Ok(()) +/// } +/// ``` +/// +/// ## Using [`Regulator<Dynamic>`] +/// +/// This example mimics the behavior of the C API, where the user is in +/// control of the enabled reference count. This is useful for drivers that +/// might call enable and disable to manage the `enable` reference count at +/// runtime, perhaps as a result of `open()` and `close()` calls or whatever +/// other driver-specific or subsystem-specific hooks. +/// +/// ``` +/// # use kernel::prelude::*; +/// # use kernel::c_str; +/// # use kernel::device::Device; +/// # use kernel::regulator::{Regulator, Dynamic}; +/// struct PrivateData { +/// regulator: Regulator<Dynamic>, +/// } +/// +/// // A fictictious probe function that obtains a regulator and sets it up. +/// fn probe(dev: &Device) -> Result<PrivateData> { +/// // Obtain a reference to a (fictitious) regulator. +/// let mut regulator = Regulator::<Dynamic>::get(dev, c_str!("vcc"))?; +/// +/// Ok(PrivateData { regulator }) +/// } +/// +/// // A fictictious function that indicates that the device is going to be used. +/// fn open(dev: &Device, data: &mut PrivateData) -> Result { +/// // Increase the `enabled` reference count. +/// data.regulator.enable()?; +/// +/// Ok(()) +/// } +/// +/// fn close(dev: &Device, data: &mut PrivateData) -> Result { +/// // Decrease the `enabled` reference count. +/// data.regulator.disable()?; +/// +/// Ok(()) +/// } +/// +/// fn remove(dev: &Device, data: PrivateData) -> Result { +/// // `PrivateData` is dropped here, which will drop the +/// // `Regulator<Dynamic>` in turn. +/// // +/// // The reference that was obtained by `regulator_get()` will be +/// // released, but it is up to the user to make sure that the number of calls +/// // to `enable()` and `disabled()` are balanced before this point. +/// Ok(()) +/// } +/// ``` +/// +/// # Invariants +/// +/// - `inner` is a non-null wrapper over a pointer to a `struct +/// regulator` obtained from [`regulator_get()`]. +/// +/// [`regulator_get()`]: https://docs.kernel.org/driver-api/regulator.html#c.regulator_get +pub struct Regulator<State = Dynamic> +where + State: RegulatorState, +{ + inner: NonNull<bindings::regulator>, + _phantom: PhantomData<State>, +} + +impl<T: RegulatorState> Regulator<T> { + /// Sets the voltage for the regulator. + /// + /// This can be used to ensure that the device powers up cleanly. + pub fn set_voltage(&self, min_voltage: Voltage, max_voltage: Voltage) -> Result { + // SAFETY: Safe as per the type invariants of `Regulator`. + to_result(unsafe { + bindings::regulator_set_voltage( + self.inner.as_ptr(), + min_voltage.as_microvolts(), + max_voltage.as_microvolts(), + ) + }) + } + + /// Gets the current voltage of the regulator. + pub fn get_voltage(&self) -> Result<Voltage> { + // SAFETY: Safe as per the type invariants of `Regulator`. + let voltage = unsafe { bindings::regulator_get_voltage(self.inner.as_ptr()) }; + if voltage < 0 { + Err(kernel::error::Error::from_errno(voltage)) + } else { + Ok(Voltage::from_microvolts(voltage)) + } + } + + fn get_internal(dev: &Device, name: &CStr) -> Result<Regulator<T>> { + // SAFETY: It is safe to call `regulator_get()`, on a device pointer + // received from the C code. + let inner = from_err_ptr(unsafe { bindings::regulator_get(dev.as_raw(), name.as_ptr()) })?; + + // SAFETY: We can safely trust `inner` to be a pointer to a valid + // regulator if `ERR_PTR` was not returned. + let inner = unsafe { NonNull::new_unchecked(inner) }; + + Ok(Self { + inner, + _phantom: PhantomData, + }) + } + + fn enable_internal(&mut self) -> Result { + // SAFETY: Safe as per the type invariants of `Regulator`. + to_result(unsafe { bindings::regulator_enable(self.inner.as_ptr()) }) + } + + fn disable_internal(&mut self) -> Result { + // SAFETY: Safe as per the type invariants of `Regulator`. + to_result(unsafe { bindings::regulator_disable(self.inner.as_ptr()) }) + } +} + +impl Regulator<Disabled> { + /// Obtains a [`Regulator`] instance from the system. + pub fn get(dev: &Device, name: &CStr) -> Result<Self> { + Regulator::get_internal(dev, name) + } + + /// Attempts to convert the regulator to an enabled state. + pub fn try_into_enabled(self) -> Result<Regulator<Enabled>, Error<Disabled>> { + // We will be transferring the ownership of our `regulator_get()` count to + // `Regulator<Enabled>`. + let mut regulator = ManuallyDrop::new(self); + + regulator + .enable_internal() + .map(|()| Regulator { + inner: regulator.inner, + _phantom: PhantomData, + }) + .map_err(|error| Error { + error, + regulator: ManuallyDrop::into_inner(regulator), + }) + } +} + +impl Regulator<Enabled> { + /// Obtains a [`Regulator`] instance from the system and enables it. + /// + /// This is equivalent to calling `regulator_get_enable()` in the C API. + pub fn get(dev: &Device, name: &CStr) -> Result<Self> { + Regulator::<Disabled>::get_internal(dev, name)? + .try_into_enabled() + .map_err(|error| error.error) + } + + /// Attempts to convert the regulator to a disabled state. + pub fn try_into_disabled(self) -> Result<Regulator<Disabled>, Error<Enabled>> { + // We will be transferring the ownership of our `regulator_get()` count + // to `Regulator<Disabled>`. + let mut regulator = ManuallyDrop::new(self); + + regulator + .disable_internal() + .map(|()| Regulator { + inner: regulator.inner, + _phantom: PhantomData, + }) + .map_err(|error| Error { + error, + regulator: ManuallyDrop::into_inner(regulator), + }) + } +} + +impl Regulator<Dynamic> { + /// Obtains a [`Regulator`] instance from the system. The current state of + /// the regulator is unknown and it is up to the user to manage the enabled + /// reference count. + /// + /// This closely mimics the behavior of the C API and can be used to + /// dynamically manage the enabled reference count at runtime. + pub fn get(dev: &Device, name: &CStr) -> Result<Self> { + Regulator::get_internal(dev, name) + } + + /// Increases the `enabled` reference count. + pub fn enable(&mut self) -> Result { + self.enable_internal() + } + + /// Decreases the `enabled` reference count. + pub fn disable(&mut self) -> Result { + self.disable_internal() + } +} + +impl<T: IsEnabled> Regulator<T> { + /// Checks if the regulator is enabled. + pub fn is_enabled(&self) -> bool { + // SAFETY: Safe as per the type invariants of `Regulator`. + unsafe { bindings::regulator_is_enabled(self.inner.as_ptr()) != 0 } + } +} + +impl<T: RegulatorState> Drop for Regulator<T> { + fn drop(&mut self) { + if T::DISABLE_ON_DROP { + // SAFETY: By the type invariants, we know that `self` owns a + // reference on the enabled refcount, so it is safe to relinquish it + // now. + unsafe { bindings::regulator_disable(self.inner.as_ptr()) }; + } + // SAFETY: By the type invariants, we know that `self` owns a reference, + // so it is safe to relinquish it now. + unsafe { bindings::regulator_put(self.inner.as_ptr()) }; + } +} + +/// A voltage. +/// +/// This type represents a voltage value in microvolts. +#[repr(transparent)] +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct Voltage(i32); + +impl Voltage { + /// Creates a new `Voltage` from a value in microvolts. + pub fn from_microvolts(uv: i32) -> Self { + Self(uv) + } + + /// Returns the value of the voltage in microvolts as an [`i32`]. + pub fn as_microvolts(self) -> i32 { + self.0 + } +} diff --git a/rust/kernel/revocable.rs b/rust/kernel/revocable.rs new file mode 100644 index 000000000000..0f4ae673256d --- /dev/null +++ b/rust/kernel/revocable.rs @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Revocable objects. +//! +//! The [`Revocable`] type wraps other types and allows access to them to be revoked. The existence +//! of a [`RevocableGuard`] ensures that objects remain valid. + +use pin_init::Wrapper; + +use crate::{bindings, prelude::*, sync::rcu, types::Opaque}; +use core::{ + marker::PhantomData, + ops::Deref, + ptr::drop_in_place, + sync::atomic::{AtomicBool, Ordering}, +}; + +/// An object that can become inaccessible at runtime. +/// +/// Once access is revoked and all concurrent users complete (i.e., all existing instances of +/// [`RevocableGuard`] are dropped), the wrapped object is also dropped. +/// +/// # Examples +/// +/// ``` +/// # use kernel::revocable::Revocable; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// fn add_two(v: &Revocable<Example>) -> Option<u32> { +/// let guard = v.try_access()?; +/// Some(guard.a + guard.b) +/// } +/// +/// let v = KBox::pin_init(Revocable::new(Example { a: 10, b: 20 }), GFP_KERNEL).unwrap(); +/// assert_eq!(add_two(&v), Some(30)); +/// v.revoke(); +/// assert_eq!(add_two(&v), None); +/// ``` +/// +/// Sample example as above, but explicitly using the rcu read side lock. +/// +/// ``` +/// # use kernel::revocable::Revocable; +/// use kernel::sync::rcu; +/// +/// struct Example { +/// a: u32, +/// b: u32, +/// } +/// +/// fn add_two(v: &Revocable<Example>) -> Option<u32> { +/// let guard = rcu::read_lock(); +/// let e = v.try_access_with_guard(&guard)?; +/// Some(e.a + e.b) +/// } +/// +/// let v = KBox::pin_init(Revocable::new(Example { a: 10, b: 20 }), GFP_KERNEL).unwrap(); +/// assert_eq!(add_two(&v), Some(30)); +/// v.revoke(); +/// assert_eq!(add_two(&v), None); +/// ``` +#[pin_data(PinnedDrop)] +pub struct Revocable<T> { + is_available: AtomicBool, + #[pin] + data: Opaque<T>, +} + +// SAFETY: `Revocable` is `Send` if the wrapped object is also `Send`. This is because while the +// functionality exposed by `Revocable` can be accessed from any thread/CPU, it is possible that +// this isn't supported by the wrapped object. +unsafe impl<T: Send> Send for Revocable<T> {} + +// SAFETY: `Revocable` is `Sync` if the wrapped object is both `Send` and `Sync`. We require `Send` +// from the wrapped object as well because of `Revocable::revoke`, which can trigger the `Drop` +// implementation of the wrapped object from an arbitrary thread. +unsafe impl<T: Sync + Send> Sync for Revocable<T> {} + +impl<T> Revocable<T> { + /// Creates a new revocable instance of the given data. + pub fn new<E>(data: impl PinInit<T, E>) -> impl PinInit<Self, E> { + try_pin_init!(Self { + is_available: AtomicBool::new(true), + data <- Opaque::pin_init(data), + }? E) + } + + /// Tries to access the revocable wrapped object. + /// + /// Returns `None` if the object has been revoked and is therefore no longer accessible. + /// + /// Returns a guard that gives access to the object otherwise; the object is guaranteed to + /// remain accessible while the guard is alive. In such cases, callers are not allowed to sleep + /// because another CPU may be waiting to complete the revocation of this object. + pub fn try_access(&self) -> Option<RevocableGuard<'_, T>> { + let guard = rcu::read_lock(); + if self.is_available.load(Ordering::Relaxed) { + // Since `self.is_available` is true, data is initialised and has to remain valid + // because the RCU read side lock prevents it from being dropped. + Some(RevocableGuard::new(self.data.get(), guard)) + } else { + None + } + } + + /// Tries to access the revocable wrapped object. + /// + /// Returns `None` if the object has been revoked and is therefore no longer accessible. + /// + /// Returns a shared reference to the object otherwise; the object is guaranteed to + /// remain accessible while the rcu read side guard is alive. In such cases, callers are not + /// allowed to sleep because another CPU may be waiting to complete the revocation of this + /// object. + pub fn try_access_with_guard<'a>(&'a self, _guard: &'a rcu::Guard) -> Option<&'a T> { + if self.is_available.load(Ordering::Relaxed) { + // SAFETY: Since `self.is_available` is true, data is initialised and has to remain + // valid because the RCU read side lock prevents it from being dropped. + Some(unsafe { &*self.data.get() }) + } else { + None + } + } + + /// Tries to access the wrapped object and run a closure on it while the guard is held. + /// + /// This is a convenience method to run short non-sleepable code blocks while ensuring the + /// guard is dropped afterwards. [`Self::try_access`] carries the risk that the caller will + /// forget to explicitly drop that returned guard before calling sleepable code; this method + /// adds an extra safety to make sure it doesn't happen. + /// + /// Returns [`None`] if the object has been revoked and is therefore no longer accessible, or + /// the result of the closure wrapped in [`Some`]. If the closure returns a [`Result`] then the + /// return type becomes `Option<Result<>>`, which can be inconvenient. Users are encouraged to + /// define their own macro that turns the [`Option`] into a proper error code and flattens the + /// inner result into it if it makes sense within their subsystem. + pub fn try_access_with<R, F: FnOnce(&T) -> R>(&self, f: F) -> Option<R> { + self.try_access().map(|t| f(&*t)) + } + + /// Directly access the revocable wrapped object. + /// + /// # Safety + /// + /// The caller must ensure this [`Revocable`] instance hasn't been revoked and won't be revoked + /// as long as the returned `&T` lives. + pub unsafe fn access(&self) -> &T { + // SAFETY: By the safety requirement of this function it is guaranteed that + // `self.data.get()` is a valid pointer to an instance of `T`. + unsafe { &*self.data.get() } + } + + /// # Safety + /// + /// Callers must ensure that there are no more concurrent users of the revocable object. + unsafe fn revoke_internal<const SYNC: bool>(&self) -> bool { + let revoke = self.is_available.swap(false, Ordering::Relaxed); + + if revoke { + if SYNC { + // SAFETY: Just an FFI call, there are no further requirements. + unsafe { bindings::synchronize_rcu() }; + } + + // SAFETY: We know `self.data` is valid because only one CPU can succeed the + // `compare_exchange` above that takes `is_available` from `true` to `false`. + unsafe { drop_in_place(self.data.get()) }; + } + + revoke + } + + /// Revokes access to and drops the wrapped object. + /// + /// Access to the object is revoked immediately to new callers of [`Revocable::try_access`], + /// expecting that there are no concurrent users of the object. + /// + /// Returns `true` if `&self` has been revoked with this call, `false` if it was revoked + /// already. + /// + /// # Safety + /// + /// Callers must ensure that there are no more concurrent users of the revocable object. + pub unsafe fn revoke_nosync(&self) -> bool { + // SAFETY: By the safety requirement of this function, the caller ensures that nobody is + // accessing the data anymore and hence we don't have to wait for the grace period to + // finish. + unsafe { self.revoke_internal::<false>() } + } + + /// Revokes access to and drops the wrapped object. + /// + /// Access to the object is revoked immediately to new callers of [`Revocable::try_access`]. + /// + /// If there are concurrent users of the object (i.e., ones that called + /// [`Revocable::try_access`] beforehand and still haven't dropped the returned guard), this + /// function waits for the concurrent access to complete before dropping the wrapped object. + /// + /// Returns `true` if `&self` has been revoked with this call, `false` if it was revoked + /// already. + pub fn revoke(&self) -> bool { + // SAFETY: By passing `true` we ask `revoke_internal` to wait for the grace period to + // finish. + unsafe { self.revoke_internal::<true>() } + } +} + +#[pinned_drop] +impl<T> PinnedDrop for Revocable<T> { + fn drop(self: Pin<&mut Self>) { + // Drop only if the data hasn't been revoked yet (in which case it has already been + // dropped). + // SAFETY: We are not moving out of `p`, only dropping in place + let p = unsafe { self.get_unchecked_mut() }; + if *p.is_available.get_mut() { + // SAFETY: We know `self.data` is valid because no other CPU has changed + // `is_available` to `false` yet, and no other CPU can do it anymore because this CPU + // holds the only reference (mutable) to `self` now. + unsafe { drop_in_place(p.data.get()) }; + } + } +} + +/// A guard that allows access to a revocable object and keeps it alive. +/// +/// CPUs may not sleep while holding on to [`RevocableGuard`] because it's in atomic context +/// holding the RCU read-side lock. +/// +/// # Invariants +/// +/// The RCU read-side lock is held while the guard is alive. +pub struct RevocableGuard<'a, T> { + // This can't use the `&'a T` type because references that appear in function arguments must + // not become dangling during the execution of the function, which can happen if the + // `RevocableGuard` is passed as a function argument and then dropped during execution of the + // function. + data_ref: *const T, + _rcu_guard: rcu::Guard, + _p: PhantomData<&'a ()>, +} + +impl<T> RevocableGuard<'_, T> { + fn new(data_ref: *const T, rcu_guard: rcu::Guard) -> Self { + Self { + data_ref, + _rcu_guard: rcu_guard, + _p: PhantomData, + } + } +} + +impl<T> Deref for RevocableGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: By the type invariants, we hold the rcu read-side lock, so the object is + // guaranteed to remain valid. + unsafe { &*self.data_ref } + } +} diff --git a/rust/kernel/security.rs b/rust/kernel/security.rs index ea4c58c81703..0c63e9e7e564 100644 --- a/rust/kernel/security.rs +++ b/rust/kernel/security.rs @@ -15,60 +15,60 @@ use crate::{ /// /// # Invariants /// -/// The `secdata` and `seclen` fields correspond to a valid security context as returned by a -/// successful call to `security_secid_to_secctx`, that has not yet been destroyed by calling -/// `security_release_secctx`. +/// The `ctx` field corresponds to a valid security context as returned by a successful call to +/// `security_secid_to_secctx`, that has not yet been released by `security_release_secctx`. pub struct SecurityCtx { - secdata: *mut crate::ffi::c_char, - seclen: usize, + ctx: bindings::lsm_context, } impl SecurityCtx { /// Get the security context given its id. + #[inline] pub fn from_secid(secid: u32) -> Result<Self> { - let mut secdata = core::ptr::null_mut(); - let mut seclen = 0u32; - // SAFETY: Just a C FFI call. The pointers are valid for writes. - to_result(unsafe { bindings::security_secid_to_secctx(secid, &mut secdata, &mut seclen) })?; + // SAFETY: `struct lsm_context` can be initialized to all zeros. + let mut ctx: bindings::lsm_context = unsafe { core::mem::zeroed() }; + + // SAFETY: Just a C FFI call. The pointer is valid for writes. + to_result(unsafe { bindings::security_secid_to_secctx(secid, &mut ctx) })?; // INVARIANT: If the above call did not fail, then we have a valid security context. - Ok(Self { - secdata, - seclen: seclen as usize, - }) + Ok(Self { ctx }) } /// Returns whether the security context is empty. + #[inline] pub fn is_empty(&self) -> bool { - self.seclen == 0 + self.ctx.len == 0 } /// Returns the length of this security context. + #[inline] pub fn len(&self) -> usize { - self.seclen + self.ctx.len as usize } /// Returns the bytes for this security context. + #[inline] pub fn as_bytes(&self) -> &[u8] { - let ptr = self.secdata; + let ptr = self.ctx.context; if ptr.is_null() { - debug_assert_eq!(self.seclen, 0); + debug_assert_eq!(self.len(), 0); // We can't pass a null pointer to `slice::from_raw_parts` even if the length is zero. return &[]; } // SAFETY: The call to `security_secid_to_secctx` guarantees that the pointer is valid for - // `seclen` bytes. Furthermore, if the length is zero, then we have ensured that the + // `self.len()` bytes. Furthermore, if the length is zero, then we have ensured that the // pointer is not null. - unsafe { core::slice::from_raw_parts(ptr.cast(), self.seclen) } + unsafe { core::slice::from_raw_parts(ptr.cast(), self.len()) } } } impl Drop for SecurityCtx { + #[inline] fn drop(&mut self) { - // SAFETY: By the invariant of `Self`, this frees a pointer that came from a successful - // call to `security_secid_to_secctx` and has not yet been destroyed by - // `security_release_secctx`. - unsafe { bindings::security_release_secctx(self.secdata, self.seclen as u32) }; + // SAFETY: By the invariant of `Self`, this releases an lsm context that came from a + // successful call to `security_secid_to_secctx` and has not yet been released. + unsafe { bindings::security_release_secctx(&mut self.ctx) }; } } diff --git a/rust/kernel/seq_file.rs b/rust/kernel/seq_file.rs index 04947c672979..8f199b1a3bb1 100644 --- a/rust/kernel/seq_file.rs +++ b/rust/kernel/seq_file.rs @@ -18,7 +18,7 @@ impl SeqFile { /// /// # Safety /// - /// The caller must ensure that for the duration of 'a the following is satisfied: + /// The caller must ensure that for the duration of `'a` the following is satisfied: /// * The pointer points at a valid `struct seq_file`. /// * The `struct seq_file` is not accessed from any other thread. pub unsafe fn from_raw<'a>(ptr: *mut bindings::seq_file) -> &'a SeqFile { @@ -30,13 +30,14 @@ impl SeqFile { } /// Used by the [`seq_print`] macro. + #[inline] pub fn call_printf(&self, args: core::fmt::Arguments<'_>) { // SAFETY: Passing a void pointer to `Arguments` is valid for `%pA`. unsafe { bindings::seq_printf( self.inner.get(), c_str!("%pA").as_char_ptr(), - &args as *const _ as *const crate::ffi::c_void, + core::ptr::from_ref(&args).cast::<crate::ffi::c_void>(), ); } } diff --git a/rust/kernel/sizes.rs b/rust/kernel/sizes.rs index 834c343e4170..661e680d9330 100644 --- a/rust/kernel/sizes.rs +++ b/rust/kernel/sizes.rs @@ -24,3 +24,27 @@ pub const SZ_128K: usize = bindings::SZ_128K as usize; pub const SZ_256K: usize = bindings::SZ_256K as usize; /// 0x00080000 pub const SZ_512K: usize = bindings::SZ_512K as usize; +/// 0x00100000 +pub const SZ_1M: usize = bindings::SZ_1M as usize; +/// 0x00200000 +pub const SZ_2M: usize = bindings::SZ_2M as usize; +/// 0x00400000 +pub const SZ_4M: usize = bindings::SZ_4M as usize; +/// 0x00800000 +pub const SZ_8M: usize = bindings::SZ_8M as usize; +/// 0x01000000 +pub const SZ_16M: usize = bindings::SZ_16M as usize; +/// 0x02000000 +pub const SZ_32M: usize = bindings::SZ_32M as usize; +/// 0x04000000 +pub const SZ_64M: usize = bindings::SZ_64M as usize; +/// 0x08000000 +pub const SZ_128M: usize = bindings::SZ_128M as usize; +/// 0x10000000 +pub const SZ_256M: usize = bindings::SZ_256M as usize; +/// 0x20000000 +pub const SZ_512M: usize = bindings::SZ_512M as usize; +/// 0x40000000 +pub const SZ_1G: usize = bindings::SZ_1G as usize; +/// 0x80000000 +pub const SZ_2G: usize = bindings::SZ_2G as usize; diff --git a/rust/kernel/static_assert.rs b/rust/kernel/static_assert.rs index 3115ee0ba8e9..a57ba14315a0 100644 --- a/rust/kernel/static_assert.rs +++ b/rust/kernel/static_assert.rs @@ -6,6 +6,10 @@ /// /// Similar to C11 [`_Static_assert`] and C++11 [`static_assert`]. /// +/// An optional panic message can be supplied after the expression. +/// Currently only a string literal without formatting is supported +/// due to constness limitations of the [`assert!`] macro. +/// /// The feature may be added to Rust in the future: see [RFC 2790]. /// /// [`_Static_assert`]: https://en.cppreference.com/w/c/language/_Static_assert @@ -25,10 +29,11 @@ /// x + 2 /// } /// static_assert!(f(40) == 42); +/// static_assert!(f(40) == 42, "f(x) must add 2 to the given input."); /// ``` #[macro_export] macro_rules! static_assert { - ($condition:expr) => { - const _: () = core::assert!($condition); + ($condition:expr $(,$arg:literal)?) => { + const _: () = ::core::assert!($condition $(,$arg)?); }; } diff --git a/rust/kernel/std_vendor.rs b/rust/kernel/std_vendor.rs index 279bd353687a..abbab5050cc5 100644 --- a/rust/kernel/std_vendor.rs +++ b/rust/kernel/std_vendor.rs @@ -148,7 +148,7 @@ macro_rules! dbg { }; ($val:expr $(,)?) => { // Use of `match` here is intentional because it affects the lifetimes - // of temporaries - https://stackoverflow.com/a/48732525/1063961 + // of temporaries - <https://stackoverflow.com/a/48732525/1063961> match $val { tmp => { $crate::pr_info!("[{}:{}:{}] {} = {:#?}\n", diff --git a/rust/kernel/str.rs b/rust/kernel/str.rs index 0f2765463dc8..6c892550c0ba 100644 --- a/rust/kernel/str.rs +++ b/rust/kernel/str.rs @@ -3,10 +3,10 @@ //! String representations. use crate::alloc::{flags::*, AllocError, KVec}; -use core::fmt::{self, Write}; +use crate::fmt::{self, Write}; use core::ops::{self, Deref, DerefMut, Index}; -use crate::error::{code::*, Error}; +use crate::prelude::*; /// Byte string without UTF-8 validity guarantee. #[repr(transparent)] @@ -29,7 +29,24 @@ impl BStr { #[inline] pub const fn from_bytes(bytes: &[u8]) -> &Self { // SAFETY: `BStr` is transparent to `[u8]`. - unsafe { &*(bytes as *const [u8] as *const BStr) } + unsafe { &*(core::ptr::from_ref(bytes) as *const BStr) } + } + + /// Strip a prefix from `self`. Delegates to [`slice::strip_prefix`]. + /// + /// # Examples + /// + /// ``` + /// # use kernel::b_str; + /// assert_eq!(Some(b_str!("bar")), b_str!("foobar").strip_prefix(b_str!("foo"))); + /// assert_eq!(None, b_str!("foobar").strip_prefix(b_str!("bar"))); + /// assert_eq!(Some(b_str!("foobar")), b_str!("foobar").strip_prefix(b_str!(""))); + /// assert_eq!(Some(b_str!("")), b_str!("foobar").strip_prefix(b_str!("foobar"))); + /// ``` + pub fn strip_prefix(&self, pattern: impl AsRef<Self>) -> Option<&BStr> { + self.deref() + .strip_prefix(pattern.as_ref().deref()) + .map(Self::from_bytes) } } @@ -37,14 +54,15 @@ impl fmt::Display for BStr { /// Formats printable ASCII characters, escaping the rest. /// /// ``` - /// # use kernel::{fmt, b_str, str::{BStr, CString}}; + /// # use kernel::{prelude::fmt, b_str, str::{BStr, CString}}; /// let ascii = b_str!("Hello, BStr!"); - /// let s = CString::try_from_fmt(fmt!("{}", ascii)).unwrap(); - /// assert_eq!(s.as_bytes(), "Hello, BStr!".as_bytes()); + /// let s = CString::try_from_fmt(fmt!("{ascii}"))?; + /// assert_eq!(s.to_bytes(), "Hello, BStr!".as_bytes()); /// /// let non_ascii = b_str!("🦀"); - /// let s = CString::try_from_fmt(fmt!("{}", non_ascii)).unwrap(); - /// assert_eq!(s.as_bytes(), "\\xf0\\x9f\\xa6\\x80".as_bytes()); + /// let s = CString::try_from_fmt(fmt!("{non_ascii}"))?; + /// assert_eq!(s.to_bytes(), "\\xf0\\x9f\\xa6\\x80".as_bytes()); + /// # Ok::<(), kernel::error::Error>(()) /// ``` fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for &b in &self.0 { @@ -55,7 +73,7 @@ impl fmt::Display for BStr { b'\r' => f.write_str("\\r")?, // Printable characters. 0x20..=0x7e => f.write_char(b as char)?, - _ => write!(f, "\\x{:02x}", b)?, + _ => write!(f, "\\x{b:02x}")?, } } Ok(()) @@ -67,15 +85,16 @@ impl fmt::Debug for BStr { /// escaping the rest. /// /// ``` - /// # use kernel::{fmt, b_str, str::{BStr, CString}}; + /// # use kernel::{prelude::fmt, b_str, str::{BStr, CString}}; /// // Embedded double quotes are escaped. /// let ascii = b_str!("Hello, \"BStr\"!"); - /// let s = CString::try_from_fmt(fmt!("{:?}", ascii)).unwrap(); - /// assert_eq!(s.as_bytes(), "\"Hello, \\\"BStr\\\"!\"".as_bytes()); + /// let s = CString::try_from_fmt(fmt!("{ascii:?}"))?; + /// assert_eq!(s.to_bytes(), "\"Hello, \\\"BStr\\\"!\"".as_bytes()); /// /// let non_ascii = b_str!("😺"); - /// let s = CString::try_from_fmt(fmt!("{:?}", non_ascii)).unwrap(); - /// assert_eq!(s.as_bytes(), "\"\\xf0\\x9f\\x98\\xba\"".as_bytes()); + /// let s = CString::try_from_fmt(fmt!("{non_ascii:?}"))?; + /// assert_eq!(s.to_bytes(), "\"\\xf0\\x9f\\x98\\xba\"".as_bytes()); + /// # Ok::<(), kernel::error::Error>(()) /// ``` fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_char('"')?; @@ -90,7 +109,7 @@ impl fmt::Debug for BStr { b'\\' => f.write_str("\\\\")?, // Printable characters. 0x20..=0x7e => f.write_char(b as char)?, - _ => write!(f, "\\x{:02x}", b)?, + _ => write!(f, "\\x{b:02x}")?, } } f.write_char('"') @@ -106,6 +125,35 @@ impl Deref for BStr { } } +impl PartialEq for BStr { + fn eq(&self, other: &Self) -> bool { + self.deref().eq(other.deref()) + } +} + +impl<Idx> Index<Idx> for BStr +where + [u8]: Index<Idx, Output = [u8]>, +{ + type Output = Self; + + fn index(&self, index: Idx) -> &Self::Output { + BStr::from_bytes(&self.0[index]) + } +} + +impl AsRef<BStr> for [u8] { + fn as_ref(&self) -> &BStr { + BStr::from_bytes(self) + } +} + +impl AsRef<BStr> for BStr { + fn as_ref(&self) -> &BStr { + self + } +} + /// Creates a new [`BStr`] from a string literal. /// /// `b_str!` converts the supplied string literal to byte string, so non-ASCII @@ -127,6 +175,15 @@ macro_rules! b_str { }}; } +/// Returns a C pointer to the string. +// It is a free function rather than a method on an extension trait because: +// +// - error[E0379]: functions in trait impls cannot be declared const +#[inline] +pub const fn as_char_ptr_in_const_context(c_str: &CStr) -> *const c_char { + c_str.0.as_ptr() +} + /// Possible errors when using conversion functions in [`CStr`]. #[derive(Debug, Clone, Copy)] pub enum CStrConvertError { @@ -184,12 +241,12 @@ impl CStr { /// last at least `'a`. When `CStr` is alive, the memory pointed by `ptr` /// must not be mutated. #[inline] - pub unsafe fn from_char_ptr<'a>(ptr: *const crate::ffi::c_char) -> &'a Self { + pub unsafe fn from_char_ptr<'a>(ptr: *const c_char) -> &'a Self { // SAFETY: The safety precondition guarantees `ptr` is a valid pointer // to a `NUL`-terminated C string. let len = unsafe { bindings::strlen(ptr) } + 1; // SAFETY: Lifetime guaranteed by the safety precondition. - let bytes = unsafe { core::slice::from_raw_parts(ptr as _, len) }; + let bytes = unsafe { core::slice::from_raw_parts(ptr.cast(), len) }; // SAFETY: As `len` is returned by `strlen`, `bytes` does not contain interior `NUL`. // As we have added 1 to `len`, the last byte is known to be `NUL`. unsafe { Self::from_bytes_with_nul_unchecked(bytes) } @@ -242,27 +299,49 @@ impl CStr { #[inline] pub unsafe fn from_bytes_with_nul_unchecked_mut(bytes: &mut [u8]) -> &mut CStr { // SAFETY: Properties of `bytes` guaranteed by the safety precondition. - unsafe { &mut *(bytes as *mut [u8] as *mut CStr) } + unsafe { &mut *(core::ptr::from_mut(bytes) as *mut CStr) } } /// Returns a C pointer to the string. + /// + /// Using this function in a const context is deprecated in favor of + /// [`as_char_ptr_in_const_context`] in preparation for replacing `CStr` with `core::ffi::CStr` + /// which does not have this method. #[inline] - pub const fn as_char_ptr(&self) -> *const crate::ffi::c_char { - self.0.as_ptr() + pub const fn as_char_ptr(&self) -> *const c_char { + as_char_ptr_in_const_context(self) } /// Convert the string to a byte slice without the trailing `NUL` byte. #[inline] - pub fn as_bytes(&self) -> &[u8] { + pub fn to_bytes(&self) -> &[u8] { &self.0[..self.len()] } + /// Convert the string to a byte slice without the trailing `NUL` byte. + /// + /// This function is deprecated in favor of [`Self::to_bytes`] in preparation for replacing + /// `CStr` with `core::ffi::CStr` which does not have this method. + #[inline] + pub fn as_bytes(&self) -> &[u8] { + self.to_bytes() + } + /// Convert the string to a byte slice containing the trailing `NUL` byte. #[inline] - pub const fn as_bytes_with_nul(&self) -> &[u8] { + pub const fn to_bytes_with_nul(&self) -> &[u8] { &self.0 } + /// Convert the string to a byte slice containing the trailing `NUL` byte. + /// + /// This function is deprecated in favor of [`Self::to_bytes_with_nul`] in preparation for + /// replacing `CStr` with `core::ffi::CStr` which does not have this method. + #[inline] + pub const fn as_bytes_with_nul(&self) -> &[u8] { + self.to_bytes_with_nul() + } + /// Yields a [`&str`] slice if the [`CStr`] contains valid UTF-8. /// /// If the contents of the [`CStr`] are valid UTF-8 data, this @@ -273,8 +352,9 @@ impl CStr { /// /// ``` /// # use kernel::str::CStr; - /// let cstr = CStr::from_bytes_with_nul(b"foo\0").unwrap(); + /// let cstr = CStr::from_bytes_with_nul(b"foo\0")?; /// assert_eq!(cstr.to_str(), Ok("foo")); + /// # Ok::<(), kernel::error::Error>(()) /// ``` #[inline] pub fn to_str(&self) -> Result<&str, core::str::Utf8Error> { @@ -380,24 +460,25 @@ impl fmt::Display for CStr { /// /// ``` /// # use kernel::c_str; - /// # use kernel::fmt; + /// # use kernel::prelude::fmt; /// # use kernel::str::CStr; /// # use kernel::str::CString; /// let penguin = c_str!("🐧"); - /// let s = CString::try_from_fmt(fmt!("{}", penguin)).unwrap(); - /// assert_eq!(s.as_bytes_with_nul(), "\\xf0\\x9f\\x90\\xa7\0".as_bytes()); + /// let s = CString::try_from_fmt(fmt!("{penguin}"))?; + /// assert_eq!(s.to_bytes_with_nul(), "\\xf0\\x9f\\x90\\xa7\0".as_bytes()); /// /// let ascii = c_str!("so \"cool\""); - /// let s = CString::try_from_fmt(fmt!("{}", ascii)).unwrap(); - /// assert_eq!(s.as_bytes_with_nul(), "so \"cool\"\0".as_bytes()); + /// let s = CString::try_from_fmt(fmt!("{ascii}"))?; + /// assert_eq!(s.to_bytes_with_nul(), "so \"cool\"\0".as_bytes()); + /// # Ok::<(), kernel::error::Error>(()) /// ``` fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for &c in self.as_bytes() { + for &c in self.to_bytes() { if (0x20..0x7f).contains(&c) { // Printable character. f.write_char(c as char)?; } else { - write!(f, "\\x{:02x}", c)?; + write!(f, "\\x{c:02x}")?; } } Ok(()) @@ -409,17 +490,18 @@ impl fmt::Debug for CStr { /// /// ``` /// # use kernel::c_str; - /// # use kernel::fmt; + /// # use kernel::prelude::fmt; /// # use kernel::str::CStr; /// # use kernel::str::CString; /// let penguin = c_str!("🐧"); - /// let s = CString::try_from_fmt(fmt!("{:?}", penguin)).unwrap(); + /// let s = CString::try_from_fmt(fmt!("{penguin:?}"))?; /// assert_eq!(s.as_bytes_with_nul(), "\"\\xf0\\x9f\\x90\\xa7\"\0".as_bytes()); /// /// // Embedded double quotes are escaped. /// let ascii = c_str!("so \"cool\""); - /// let s = CString::try_from_fmt(fmt!("{:?}", ascii)).unwrap(); + /// let s = CString::try_from_fmt(fmt!("{ascii:?}"))?; /// assert_eq!(s.as_bytes_with_nul(), "\"so \\\"cool\\\"\"\0".as_bytes()); + /// # Ok::<(), kernel::error::Error>(()) /// ``` fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("\"")?; @@ -428,7 +510,7 @@ impl fmt::Debug for CStr { // Printable characters. b'\"' => f.write_str("\\\"")?, 0x20..=0x7e => f.write_char(c as char)?, - _ => write!(f, "\\x{:02x}", c)?, + _ => write!(f, "\\x{c:02x}")?, } } f.write_str("\"") @@ -521,33 +603,17 @@ macro_rules! c_str { }}; } -#[cfg(test)] +#[kunit_tests(rust_kernel_str)] mod tests { use super::*; - struct String(CString); - - impl String { - fn from_fmt(args: fmt::Arguments<'_>) -> Self { - String(CString::try_from_fmt(args).unwrap()) - } - } - - impl Deref for String { - type Target = str; - - fn deref(&self) -> &str { - self.0.to_str().unwrap() - } - } - macro_rules! format { ($($f:tt)*) => ({ - &*String::from_fmt(kernel::fmt!($($f)*)) + CString::try_from_fmt(fmt!($($f)*))?.to_str()? }) } - const ALL_ASCII_CHARS: &'static str = + const ALL_ASCII_CHARS: &str = "\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d\\x0e\\x0f\ \\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f \ !\"#$%&'()*+,-./0123456789:;<=>?@\ @@ -562,90 +628,98 @@ mod tests { \\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd\\xfe\\xff"; #[test] - fn test_cstr_to_str() { + fn test_cstr_to_str() -> Result { let good_bytes = b"\xf0\x9f\xa6\x80\0"; - let checked_cstr = CStr::from_bytes_with_nul(good_bytes).unwrap(); - let checked_str = checked_cstr.to_str().unwrap(); + let checked_cstr = CStr::from_bytes_with_nul(good_bytes)?; + let checked_str = checked_cstr.to_str()?; assert_eq!(checked_str, "🦀"); + Ok(()) } #[test] - #[should_panic] - fn test_cstr_to_str_panic() { + fn test_cstr_to_str_invalid_utf8() -> Result { let bad_bytes = b"\xc3\x28\0"; - let checked_cstr = CStr::from_bytes_with_nul(bad_bytes).unwrap(); - checked_cstr.to_str().unwrap(); + let checked_cstr = CStr::from_bytes_with_nul(bad_bytes)?; + assert!(checked_cstr.to_str().is_err()); + Ok(()) } #[test] - fn test_cstr_as_str_unchecked() { + fn test_cstr_as_str_unchecked() -> Result { let good_bytes = b"\xf0\x9f\x90\xA7\0"; - let checked_cstr = CStr::from_bytes_with_nul(good_bytes).unwrap(); + let checked_cstr = CStr::from_bytes_with_nul(good_bytes)?; + // SAFETY: The contents come from a string literal which contains valid UTF-8. let unchecked_str = unsafe { checked_cstr.as_str_unchecked() }; assert_eq!(unchecked_str, "🐧"); + Ok(()) } #[test] - fn test_cstr_display() { - let hello_world = CStr::from_bytes_with_nul(b"hello, world!\0").unwrap(); - assert_eq!(format!("{}", hello_world), "hello, world!"); - let non_printables = CStr::from_bytes_with_nul(b"\x01\x09\x0a\0").unwrap(); - assert_eq!(format!("{}", non_printables), "\\x01\\x09\\x0a"); - let non_ascii = CStr::from_bytes_with_nul(b"d\xe9j\xe0 vu\0").unwrap(); - assert_eq!(format!("{}", non_ascii), "d\\xe9j\\xe0 vu"); - let good_bytes = CStr::from_bytes_with_nul(b"\xf0\x9f\xa6\x80\0").unwrap(); - assert_eq!(format!("{}", good_bytes), "\\xf0\\x9f\\xa6\\x80"); + fn test_cstr_display() -> Result { + let hello_world = CStr::from_bytes_with_nul(b"hello, world!\0")?; + assert_eq!(format!("{hello_world}"), "hello, world!"); + let non_printables = CStr::from_bytes_with_nul(b"\x01\x09\x0a\0")?; + assert_eq!(format!("{non_printables}"), "\\x01\\x09\\x0a"); + let non_ascii = CStr::from_bytes_with_nul(b"d\xe9j\xe0 vu\0")?; + assert_eq!(format!("{non_ascii}"), "d\\xe9j\\xe0 vu"); + let good_bytes = CStr::from_bytes_with_nul(b"\xf0\x9f\xa6\x80\0")?; + assert_eq!(format!("{good_bytes}"), "\\xf0\\x9f\\xa6\\x80"); + Ok(()) } #[test] - fn test_cstr_display_all_bytes() { + fn test_cstr_display_all_bytes() -> Result { let mut bytes: [u8; 256] = [0; 256]; // fill `bytes` with [1..=255] + [0] for i in u8::MIN..=u8::MAX { bytes[i as usize] = i.wrapping_add(1); } - let cstr = CStr::from_bytes_with_nul(&bytes).unwrap(); - assert_eq!(format!("{}", cstr), ALL_ASCII_CHARS); + let cstr = CStr::from_bytes_with_nul(&bytes)?; + assert_eq!(format!("{cstr}"), ALL_ASCII_CHARS); + Ok(()) } #[test] - fn test_cstr_debug() { - let hello_world = CStr::from_bytes_with_nul(b"hello, world!\0").unwrap(); - assert_eq!(format!("{:?}", hello_world), "\"hello, world!\""); - let non_printables = CStr::from_bytes_with_nul(b"\x01\x09\x0a\0").unwrap(); - assert_eq!(format!("{:?}", non_printables), "\"\\x01\\x09\\x0a\""); - let non_ascii = CStr::from_bytes_with_nul(b"d\xe9j\xe0 vu\0").unwrap(); - assert_eq!(format!("{:?}", non_ascii), "\"d\\xe9j\\xe0 vu\""); - let good_bytes = CStr::from_bytes_with_nul(b"\xf0\x9f\xa6\x80\0").unwrap(); - assert_eq!(format!("{:?}", good_bytes), "\"\\xf0\\x9f\\xa6\\x80\""); + fn test_cstr_debug() -> Result { + let hello_world = CStr::from_bytes_with_nul(b"hello, world!\0")?; + assert_eq!(format!("{hello_world:?}"), "\"hello, world!\""); + let non_printables = CStr::from_bytes_with_nul(b"\x01\x09\x0a\0")?; + assert_eq!(format!("{non_printables:?}"), "\"\\x01\\x09\\x0a\""); + let non_ascii = CStr::from_bytes_with_nul(b"d\xe9j\xe0 vu\0")?; + assert_eq!(format!("{non_ascii:?}"), "\"d\\xe9j\\xe0 vu\""); + let good_bytes = CStr::from_bytes_with_nul(b"\xf0\x9f\xa6\x80\0")?; + assert_eq!(format!("{good_bytes:?}"), "\"\\xf0\\x9f\\xa6\\x80\""); + Ok(()) } #[test] - fn test_bstr_display() { + fn test_bstr_display() -> Result { let hello_world = BStr::from_bytes(b"hello, world!"); - assert_eq!(format!("{}", hello_world), "hello, world!"); + assert_eq!(format!("{hello_world}"), "hello, world!"); let escapes = BStr::from_bytes(b"_\t_\n_\r_\\_\'_\"_"); - assert_eq!(format!("{}", escapes), "_\\t_\\n_\\r_\\_'_\"_"); + assert_eq!(format!("{escapes}"), "_\\t_\\n_\\r_\\_'_\"_"); let others = BStr::from_bytes(b"\x01"); - assert_eq!(format!("{}", others), "\\x01"); + assert_eq!(format!("{others}"), "\\x01"); let non_ascii = BStr::from_bytes(b"d\xe9j\xe0 vu"); - assert_eq!(format!("{}", non_ascii), "d\\xe9j\\xe0 vu"); + assert_eq!(format!("{non_ascii}"), "d\\xe9j\\xe0 vu"); let good_bytes = BStr::from_bytes(b"\xf0\x9f\xa6\x80"); - assert_eq!(format!("{}", good_bytes), "\\xf0\\x9f\\xa6\\x80"); + assert_eq!(format!("{good_bytes}"), "\\xf0\\x9f\\xa6\\x80"); + Ok(()) } #[test] - fn test_bstr_debug() { + fn test_bstr_debug() -> Result { let hello_world = BStr::from_bytes(b"hello, world!"); - assert_eq!(format!("{:?}", hello_world), "\"hello, world!\""); + assert_eq!(format!("{hello_world:?}"), "\"hello, world!\""); let escapes = BStr::from_bytes(b"_\t_\n_\r_\\_\'_\"_"); - assert_eq!(format!("{:?}", escapes), "\"_\\t_\\n_\\r_\\\\_'_\\\"_\""); + assert_eq!(format!("{escapes:?}"), "\"_\\t_\\n_\\r_\\\\_'_\\\"_\""); let others = BStr::from_bytes(b"\x01"); - assert_eq!(format!("{:?}", others), "\"\\x01\""); + assert_eq!(format!("{others:?}"), "\"\\x01\""); let non_ascii = BStr::from_bytes(b"d\xe9j\xe0 vu"); - assert_eq!(format!("{:?}", non_ascii), "\"d\\xe9j\\xe0 vu\""); + assert_eq!(format!("{non_ascii:?}"), "\"d\\xe9j\\xe0 vu\""); let good_bytes = BStr::from_bytes(b"\xf0\x9f\xa6\x80"); - assert_eq!(format!("{:?}", good_bytes), "\"\\xf0\\x9f\\xa6\\x80\""); + assert_eq!(format!("{good_bytes:?}"), "\"\\xf0\\x9f\\xa6\\x80\""); + Ok(()) } } @@ -685,9 +759,9 @@ impl RawFormatter { pub(crate) unsafe fn from_ptrs(pos: *mut u8, end: *mut u8) -> Self { // INVARIANT: The safety requirements guarantee the type invariants. Self { - beg: pos as _, - pos: pos as _, - end: end as _, + beg: pos as usize, + pos: pos as usize, + end: end as usize, } } @@ -699,7 +773,7 @@ impl RawFormatter { /// for the lifetime of the returned [`RawFormatter`]. pub(crate) unsafe fn from_buffer(buf: *mut u8, len: usize) -> Self { let pos = buf as usize; - // INVARIANT: We ensure that `end` is never less then `buf`, and the safety requirements + // INVARIANT: We ensure that `end` is never less than `buf`, and the safety requirements // guarantees that the memory region is valid for writes. Self { pos, @@ -712,7 +786,7 @@ impl RawFormatter { /// /// N.B. It may point to invalid memory. pub(crate) fn pos(&self) -> *mut u8 { - self.pos as _ + self.pos as *mut u8 } /// Returns the number of bytes written to the formatter. @@ -797,18 +871,19 @@ impl fmt::Write for Formatter { /// # Examples /// /// ``` -/// use kernel::{str::CString, fmt}; +/// use kernel::{str::CString, prelude::fmt}; /// -/// let s = CString::try_from_fmt(fmt!("{}{}{}", "abc", 10, 20)).unwrap(); -/// assert_eq!(s.as_bytes_with_nul(), "abc1020\0".as_bytes()); +/// let s = CString::try_from_fmt(fmt!("{}{}{}", "abc", 10, 20))?; +/// assert_eq!(s.to_bytes_with_nul(), "abc1020\0".as_bytes()); /// /// let tmp = "testing"; -/// let s = CString::try_from_fmt(fmt!("{tmp}{}", 123)).unwrap(); -/// assert_eq!(s.as_bytes_with_nul(), "testing123\0".as_bytes()); +/// let s = CString::try_from_fmt(fmt!("{tmp}{}", 123))?; +/// assert_eq!(s.to_bytes_with_nul(), "testing123\0".as_bytes()); /// /// // This fails because it has an embedded `NUL` byte. /// let s = CString::try_from_fmt(fmt!("a\0b{}", 123)); /// assert_eq!(s.is_ok(), false); +/// # Ok::<(), kernel::error::Error>(()) /// ``` pub struct CString { buf: KVec<u8>, @@ -832,7 +907,7 @@ impl CString { // SAFETY: The number of bytes that can be written to `f` is bounded by `size`, which is // `buf`'s capacity. The contents of the buffer have been initialised by writes to `f`. - unsafe { buf.set_len(f.bytes_written()) }; + unsafe { buf.inc_len(f.bytes_written()) }; // Check that there are no `NUL` bytes before the end. // SAFETY: The buffer is valid for read because `f.bytes_written()` is bounded by `size` @@ -873,7 +948,7 @@ impl<'a> TryFrom<&'a CStr> for CString { fn try_from(cstr: &'a CStr) -> Result<CString, AllocError> { let mut buf = KVec::new(); - buf.extend_from_slice(cstr.as_bytes_with_nul(), GFP_KERNEL)?; + buf.extend_from_slice(cstr.to_bytes_with_nul(), GFP_KERNEL)?; // INVARIANT: The `CStr` and `CString` types have the same invariants for // the string data, and we copied it over without changes. @@ -886,9 +961,3 @@ impl fmt::Debug for CString { fmt::Debug::fmt(&**self, f) } } - -/// A convenience alias for [`core::format_args`]. -#[macro_export] -macro_rules! fmt { - ($($f:tt)*) => ( core::format_args!($($f)*) ) -} diff --git a/rust/kernel/sync.rs b/rust/kernel/sync.rs index 1eab7ebf25fd..00f9b558a3ad 100644 --- a/rust/kernel/sync.rs +++ b/rust/kernel/sync.rs @@ -5,43 +5,88 @@ //! This module contains the kernel APIs related to synchronisation that have been ported or //! wrapped for usage by Rust code in the kernel. +use crate::prelude::*; use crate::types::Opaque; +use pin_init; mod arc; +pub mod aref; +pub mod completion; mod condvar; pub mod lock; mod locked_by; pub mod poll; +pub mod rcu; pub use arc::{Arc, ArcBorrow, UniqueArc}; +pub use completion::Completion; pub use condvar::{new_condvar, CondVar, CondVarTimeoutResult}; pub use lock::global::{global_lock, GlobalGuard, GlobalLock, GlobalLockBackend, GlobalLockedBy}; -pub use lock::mutex::{new_mutex, Mutex}; -pub use lock::spinlock::{new_spinlock, SpinLock}; +pub use lock::mutex::{new_mutex, Mutex, MutexGuard}; +pub use lock::spinlock::{new_spinlock, SpinLock, SpinLockGuard}; pub use locked_by::LockedBy; /// Represents a lockdep class. It's a wrapper around C's `lock_class_key`. #[repr(transparent)] -pub struct LockClassKey(Opaque<bindings::lock_class_key>); +#[pin_data(PinnedDrop)] +pub struct LockClassKey { + #[pin] + inner: Opaque<bindings::lock_class_key>, +} // SAFETY: `bindings::lock_class_key` is designed to be used concurrently from multiple threads and // provides its own synchronization. unsafe impl Sync for LockClassKey {} impl LockClassKey { - /// Creates a new lock class key. - pub const fn new() -> Self { - Self(Opaque::uninit()) + /// Initializes a dynamically allocated lock class key. In the common case of using a + /// statically allocated lock class key, the static_lock_class! macro should be used instead. + /// + /// # Examples + /// ``` + /// # use kernel::c_str; + /// # use kernel::alloc::KBox; + /// # use kernel::types::ForeignOwnable; + /// # use kernel::sync::{LockClassKey, SpinLock}; + /// # use pin_init::stack_pin_init; + /// + /// let key = KBox::pin_init(LockClassKey::new_dynamic(), GFP_KERNEL)?; + /// let key_ptr = key.into_foreign(); + /// + /// { + /// stack_pin_init!(let num: SpinLock<u32> = SpinLock::new( + /// 0, + /// c_str!("my_spinlock"), + /// // SAFETY: `key_ptr` is returned by the above `into_foreign()`, whose + /// // `from_foreign()` has not yet been called. + /// unsafe { <Pin<KBox<LockClassKey>> as ForeignOwnable>::borrow(key_ptr) } + /// )); + /// } + /// + /// // SAFETY: We dropped `num`, the only use of the key, so the result of the previous + /// // `borrow` has also been dropped. Thus, it's safe to use from_foreign. + /// unsafe { drop(<Pin<KBox<LockClassKey>> as ForeignOwnable>::from_foreign(key_ptr)) }; + /// + /// # Ok::<(), Error>(()) + /// ``` + pub fn new_dynamic() -> impl PinInit<Self> { + pin_init!(Self { + // SAFETY: lockdep_register_key expects an uninitialized block of memory + inner <- Opaque::ffi_init(|slot| unsafe { bindings::lockdep_register_key(slot) }) + }) } pub(crate) fn as_ptr(&self) -> *mut bindings::lock_class_key { - self.0.get() + self.inner.get() } } -impl Default for LockClassKey { - fn default() -> Self { - Self::new() +#[pinned_drop] +impl PinnedDrop for LockClassKey { + fn drop(self: Pin<&mut Self>) { + // SAFETY: self.as_ptr was registered with lockdep and self is pinned, so the address + // hasn't changed. Thus, it's safe to pass to unregister. + unsafe { bindings::lockdep_unregister_key(self.as_ptr()) } } } @@ -50,8 +95,14 @@ impl Default for LockClassKey { #[macro_export] macro_rules! static_lock_class { () => {{ - static CLASS: $crate::sync::LockClassKey = $crate::sync::LockClassKey::new(); - &CLASS + static CLASS: $crate::sync::LockClassKey = + // Lockdep expects uninitialized memory when it's handed a statically allocated `struct + // lock_class_key`. + // + // SAFETY: `LockClassKey` transparently wraps `Opaque` which permits uninitialized + // memory. + unsafe { ::core::mem::MaybeUninit::uninit().assume_init() }; + $crate::prelude::Pin::static_ref(&CLASS) }}; } diff --git a/rust/kernel/sync/arc.rs b/rust/kernel/sync/arc.rs index fa4509406ee9..63a66761d0c7 100644 --- a/rust/kernel/sync/arc.rs +++ b/rust/kernel/sync/arc.rs @@ -19,20 +19,22 @@ use crate::{ alloc::{AllocError, Flags, KBox}, bindings, - init::{self, InPlaceInit, Init, PinInit}, + ffi::c_void, + init::InPlaceInit, try_init, types::{ForeignOwnable, Opaque}, }; use core::{ alloc::Layout, + borrow::{Borrow, BorrowMut}, fmt, - marker::{PhantomData, Unsize}, + marker::PhantomData, mem::{ManuallyDrop, MaybeUninit}, ops::{Deref, DerefMut}, pin::Pin, ptr::NonNull, }; -use macros::pin_data; +use pin_init::{self, pin_data, InPlaceWrite, Init, PinInit}; mod std_vendor; @@ -125,8 +127,18 @@ mod std_vendor; /// let coerced: Arc<dyn MyTrait> = obj; /// # Ok::<(), Error>(()) /// ``` +#[repr(transparent)] +#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, derive(core::marker::CoercePointee))] pub struct Arc<T: ?Sized> { ptr: NonNull<ArcInner<T>>, + // NB: this informs dropck that objects of type `ArcInner<T>` may be used in `<Arc<T> as + // Drop>::drop`. Note that dropck already assumes that objects of type `T` may be used in + // `<Arc<T> as Drop>::drop` and the distinction between `T` and `ArcInner<T>` is not presently + // meaningful with respect to dropck - but this may change in the future so this is left here + // out of an abundance of caution. + // + // See <https://doc.rust-lang.org/nomicon/phantom-data.html#generic-parameters-and-drop-checking> + // for more detail on the semantics of dropck in the presence of `PhantomData`. _p: PhantomData<ArcInner<T>>, } @@ -172,10 +184,12 @@ impl<T: ?Sized> ArcInner<T> { // This is to allow coercion from `Arc<T>` to `Arc<U>` if `T` can be converted to the // dynamically-sized type (DST) `U`. -impl<T: ?Sized + Unsize<U>, U: ?Sized> core::ops::CoerceUnsized<Arc<U>> for Arc<T> {} +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T: ?Sized + core::marker::Unsize<U>, U: ?Sized> core::ops::CoerceUnsized<Arc<U>> for Arc<T> {} // This is to allow `Arc<U>` to be dispatched on when `Arc<T>` can be coerced into `Arc<U>`. -impl<T: ?Sized + Unsize<U>, U: ?Sized> core::ops::DispatchFromDyn<Arc<U>> for Arc<T> {} +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T: ?Sized + core::marker::Unsize<U>, U: ?Sized> core::ops::DispatchFromDyn<Arc<U>> for Arc<T> {} // SAFETY: It is safe to send `Arc<T>` to another thread when the underlying `T` is `Sync` because // it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, it needs @@ -190,6 +204,26 @@ unsafe impl<T: ?Sized + Sync + Send> Send for Arc<T> {} // the reference count reaches zero and `T` is dropped. unsafe impl<T: ?Sized + Sync + Send> Sync for Arc<T> {} +impl<T> InPlaceInit<T> for Arc<T> { + type PinnedSelf = Self; + + #[inline] + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> + where + E: From<AllocError>, + { + UniqueArc::try_pin_init(init, flags).map(|u| u.into()) + } + + #[inline] + fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + UniqueArc::try_init(init, flags).map(|u| u.into()) + } +} + impl<T> Arc<T> { /// Constructs a new reference counted instance of `T`. pub fn new(contents: T, flags: Flags) -> Result<Self, AllocError> { @@ -201,10 +235,11 @@ impl<T> Arc<T> { }; let inner = KBox::new(value, flags)?; + let inner = KBox::leak(inner).into(); // SAFETY: We just created `inner` with a reference count of 1, which is owned by the new // `Arc` object. - Ok(unsafe { Self::from_inner(KBox::leak(inner).into()) }) + Ok(unsafe { Self::from_inner(inner) }) } } @@ -233,6 +268,15 @@ impl<T: ?Sized> Arc<T> { unsafe { core::ptr::addr_of!((*ptr).data) } } + /// Return a raw pointer to the data in this arc. + pub fn as_ptr(this: &Self) -> *const T { + let ptr = this.ptr.as_ptr(); + + // SAFETY: As `ptr` points to a valid allocation of type `ArcInner`, + // field projection to `data`is within bounds of the allocation. + unsafe { core::ptr::addr_of!((*ptr).data) } + } + /// Recreates an [`Arc`] instance previously deconstructed via [`Arc::into_raw`]. /// /// # Safety @@ -329,28 +373,43 @@ impl<T: ?Sized> Arc<T> { } } -impl<T: 'static> ForeignOwnable for Arc<T> { +// SAFETY: The pointer returned by `into_foreign` comes from a well aligned +// pointer to `ArcInner<T>`. +unsafe impl<T: 'static> ForeignOwnable for Arc<T> { + const FOREIGN_ALIGN: usize = core::mem::align_of::<ArcInner<T>>(); + type Borrowed<'a> = ArcBorrow<'a, T>; + type BorrowedMut<'a> = Self::Borrowed<'a>; + + fn into_foreign(self) -> *mut c_void { + ManuallyDrop::new(self).ptr.as_ptr().cast() + } - fn into_foreign(self) -> *const crate::ffi::c_void { - ManuallyDrop::new(self).ptr.as_ptr() as _ + unsafe fn from_foreign(ptr: *mut c_void) -> Self { + // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous + // call to `Self::into_foreign`. + let inner = unsafe { NonNull::new_unchecked(ptr.cast::<ArcInner<T>>()) }; + + // SAFETY: By the safety requirement of this function, we know that `ptr` came from + // a previous call to `Arc::into_foreign`, which guarantees that `ptr` is valid and + // holds a reference count increment that is transferrable to us. + unsafe { Self::from_inner(inner) } } - unsafe fn borrow<'a>(ptr: *const crate::ffi::c_void) -> ArcBorrow<'a, T> { - // By the safety requirement of this function, we know that `ptr` came from - // a previous call to `Arc::into_foreign`. - let inner = NonNull::new(ptr as *mut ArcInner<T>).unwrap(); + unsafe fn borrow<'a>(ptr: *mut c_void) -> ArcBorrow<'a, T> { + // SAFETY: The safety requirements of this function ensure that `ptr` comes from a previous + // call to `Self::into_foreign`. + let inner = unsafe { NonNull::new_unchecked(ptr.cast::<ArcInner<T>>()) }; // SAFETY: The safety requirements of `from_foreign` ensure that the object remains alive // for the lifetime of the returned value. unsafe { ArcBorrow::new(inner) } } - unsafe fn from_foreign(ptr: *const crate::ffi::c_void) -> Self { - // SAFETY: By the safety requirement of this function, we know that `ptr` came from - // a previous call to `Arc::into_foreign`, which guarantees that `ptr` is valid and - // holds a reference count increment that is transferrable to us. - unsafe { Self::from_inner(NonNull::new(ptr as _).unwrap()) } + unsafe fn borrow_mut<'a>(ptr: *mut c_void) -> ArcBorrow<'a, T> { + // SAFETY: The safety requirements for `borrow_mut` are a superset of the safety + // requirements for `borrow`. + unsafe { <Self as ForeignOwnable>::borrow(ptr) } } } @@ -370,12 +429,41 @@ impl<T: ?Sized> AsRef<T> for Arc<T> { } } +/// # Examples +/// +/// ``` +/// # use core::borrow::Borrow; +/// # use kernel::sync::Arc; +/// struct Foo<B: Borrow<u32>>(B); +/// +/// // Owned instance. +/// let owned = Foo(1); +/// +/// // Shared instance. +/// let arc = Arc::new(1, GFP_KERNEL)?; +/// let shared = Foo(arc.clone()); +/// +/// let i = 1; +/// // Borrowed from `i`. +/// let borrowed = Foo(&i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T: ?Sized> Borrow<T> for Arc<T> { + fn borrow(&self) -> &T { + self.deref() + } +} + impl<T: ?Sized> Clone for Arc<T> { fn clone(&self) -> Self { + // SAFETY: By the type invariant, there is necessarily a reference to the object, so it is + // safe to dereference it. + let refcount = unsafe { self.ptr.as_ref() }.refcount.get(); + // INVARIANT: C `refcount_inc` saturates the refcount, so it cannot overflow to zero. // SAFETY: By the type invariant, there is necessarily a reference to the object, so it is // safe to increment the refcount. - unsafe { bindings::refcount_inc(self.ptr.as_ref().refcount.get()) }; + unsafe { bindings::refcount_inc(refcount) }; // SAFETY: We just incremented the refcount. This increment is now owned by the new `Arc`. unsafe { Self::from_inner(self.ptr) } @@ -432,7 +520,7 @@ impl<T: ?Sized> From<Pin<UniqueArc<T>>> for Arc<T> { /// There are no mutable references to the underlying [`Arc`], and it remains valid for the /// lifetime of the [`ArcBorrow`] instance. /// -/// # Example +/// # Examples /// /// ``` /// use kernel::sync::{Arc, ArcBorrow}; @@ -471,6 +559,8 @@ impl<T: ?Sized> From<Pin<UniqueArc<T>>> for Arc<T> { /// obj.as_arc_borrow().use_reference(); /// # Ok::<(), Error>(()) /// ``` +#[repr(transparent)] +#[cfg_attr(CONFIG_RUSTC_HAS_COERCE_POINTEE, derive(core::marker::CoercePointee))] pub struct ArcBorrow<'a, T: ?Sized + 'a> { inner: NonNull<ArcInner<T>>, _p: PhantomData<&'a ()>, @@ -478,7 +568,8 @@ pub struct ArcBorrow<'a, T: ?Sized + 'a> { // This is to allow `ArcBorrow<U>` to be dispatched on when `ArcBorrow<T>` can be coerced into // `ArcBorrow<U>`. -impl<T: ?Sized + Unsize<U>, U: ?Sized> core::ops::DispatchFromDyn<ArcBorrow<'_, U>> +#[cfg(not(CONFIG_RUSTC_HAS_COERCE_POINTEE))] +impl<T: ?Sized + core::marker::Unsize<U>, U: ?Sized> core::ops::DispatchFromDyn<ArcBorrow<'_, U>> for ArcBorrow<'_, T> { } @@ -508,11 +599,11 @@ impl<T: ?Sized> ArcBorrow<'_, T> { } /// Creates an [`ArcBorrow`] to an [`Arc`] that has previously been deconstructed with - /// [`Arc::into_raw`]. + /// [`Arc::into_raw`] or [`Arc::as_ptr`]. /// /// # Safety /// - /// * The provided pointer must originate from a call to [`Arc::into_raw`]. + /// * The provided pointer must originate from a call to [`Arc::into_raw`] or [`Arc::as_ptr`]. /// * For the duration of the lifetime annotated on this `ArcBorrow`, the reference count must /// not hit zero. /// * For the duration of the lifetime annotated on this `ArcBorrow`, there must not be a @@ -628,6 +719,48 @@ pub struct UniqueArc<T: ?Sized> { inner: Arc<T>, } +impl<T> InPlaceInit<T> for UniqueArc<T> { + type PinnedSelf = Pin<Self>; + + #[inline] + fn try_pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> Result<Self::PinnedSelf, E> + where + E: From<AllocError>, + { + UniqueArc::new_uninit(flags)?.write_pin_init(init) + } + + #[inline] + fn try_init<E>(init: impl Init<T, E>, flags: Flags) -> Result<Self, E> + where + E: From<AllocError>, + { + UniqueArc::new_uninit(flags)?.write_init(init) + } +} + +impl<T> InPlaceWrite<T> for UniqueArc<MaybeUninit<T>> { + type Initialized = UniqueArc<T>; + + fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid. + unsafe { init.__init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }) + } + + fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid and will not be moved, because we pin it later. + unsafe { init.__pinned_init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }.into()) + } +} + impl<T> UniqueArc<T> { /// Tries to allocate a new [`UniqueArc`] instance. pub fn new(value: T, flags: Flags) -> Result<Self, AllocError> { @@ -644,7 +777,7 @@ impl<T> UniqueArc<T> { try_init!(ArcInner { // SAFETY: There are no safety requirements for this FFI call. refcount: Opaque::new(unsafe { bindings::REFCOUNT_INIT(1) }), - data <- init::uninit::<T, AllocError>(), + data <- pin_init::uninit::<T, AllocError>(), }? AllocError), flags, )?; @@ -729,6 +862,56 @@ impl<T: ?Sized> DerefMut for UniqueArc<T> { } } +/// # Examples +/// +/// ``` +/// # use core::borrow::Borrow; +/// # use kernel::sync::UniqueArc; +/// struct Foo<B: Borrow<u32>>(B); +/// +/// // Owned instance. +/// let owned = Foo(1); +/// +/// // Owned instance using `UniqueArc`. +/// let arc = UniqueArc::new(1, GFP_KERNEL)?; +/// let shared = Foo(arc); +/// +/// let i = 1; +/// // Borrowed from `i`. +/// let borrowed = Foo(&i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T: ?Sized> Borrow<T> for UniqueArc<T> { + fn borrow(&self) -> &T { + self.deref() + } +} + +/// # Examples +/// +/// ``` +/// # use core::borrow::BorrowMut; +/// # use kernel::sync::UniqueArc; +/// struct Foo<B: BorrowMut<u32>>(B); +/// +/// // Owned instance. +/// let owned = Foo(1); +/// +/// // Owned instance using `UniqueArc`. +/// let arc = UniqueArc::new(1, GFP_KERNEL)?; +/// let shared = Foo(arc); +/// +/// let mut i = 1; +/// // Borrowed from `i`. +/// let borrowed = Foo(&mut i); +/// # Ok::<(), Error>(()) +/// ``` +impl<T: ?Sized> BorrowMut<T> for UniqueArc<T> { + fn borrow_mut(&mut self) -> &mut T { + self.deref_mut() + } +} + impl<T: fmt::Display + ?Sized> fmt::Display for UniqueArc<T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(self.deref(), f) diff --git a/rust/kernel/sync/aref.rs b/rust/kernel/sync/aref.rs new file mode 100644 index 000000000000..dbd77bb68617 --- /dev/null +++ b/rust/kernel/sync/aref.rs @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Internal reference counting support. + +use core::{marker::PhantomData, mem::ManuallyDrop, ops::Deref, ptr::NonNull}; + +/// Types that are _always_ reference counted. +/// +/// It allows such types to define their own custom ref increment and decrement functions. +/// Additionally, it allows users to convert from a shared reference `&T` to an owned reference +/// [`ARef<T>`]. +/// +/// This is usually implemented by wrappers to existing structures on the C side of the code. For +/// Rust code, the recommendation is to use [`Arc`](crate::sync::Arc) to create reference-counted +/// instances of a type. +/// +/// # Safety +/// +/// Implementers must ensure that increments to the reference count keep the object alive in memory +/// at least until matching decrements are performed. +/// +/// Implementers must also ensure that all instances are reference-counted. (Otherwise they +/// won't be able to honour the requirement that [`AlwaysRefCounted::inc_ref`] keep the object +/// alive.) +pub unsafe trait AlwaysRefCounted { + /// Increments the reference count on the object. + fn inc_ref(&self); + + /// Decrements the reference count on the object. + /// + /// Frees the object when the count reaches zero. + /// + /// # Safety + /// + /// Callers must ensure that there was a previous matching increment to the reference count, + /// and that the object is no longer used after its reference count is decremented (as it may + /// result in the object being freed), unless the caller owns another increment on the refcount + /// (e.g., it calls [`AlwaysRefCounted::inc_ref`] twice, then calls + /// [`AlwaysRefCounted::dec_ref`] once). + unsafe fn dec_ref(obj: NonNull<Self>); +} + +/// An owned reference to an always-reference-counted object. +/// +/// The object's reference count is automatically decremented when an instance of [`ARef`] is +/// dropped. It is also automatically incremented when a new instance is created via +/// [`ARef::clone`]. +/// +/// # Invariants +/// +/// The pointer stored in `ptr` is non-null and valid for the lifetime of the [`ARef`] instance. In +/// particular, the [`ARef`] instance owns an increment on the underlying object's reference count. +pub struct ARef<T: AlwaysRefCounted> { + ptr: NonNull<T>, + _p: PhantomData<T>, +} + +// SAFETY: It is safe to send `ARef<T>` to another thread when the underlying `T` is `Sync` because +// it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, it needs +// `T` to be `Send` because any thread that has an `ARef<T>` may ultimately access `T` using a +// mutable reference, for example, when the reference count reaches zero and `T` is dropped. +unsafe impl<T: AlwaysRefCounted + Sync + Send> Send for ARef<T> {} + +// SAFETY: It is safe to send `&ARef<T>` to another thread when the underlying `T` is `Sync` +// because it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, +// it needs `T` to be `Send` because any thread that has a `&ARef<T>` may clone it and get an +// `ARef<T>` on that thread, so the thread may ultimately access `T` using a mutable reference, for +// example, when the reference count reaches zero and `T` is dropped. +unsafe impl<T: AlwaysRefCounted + Sync + Send> Sync for ARef<T> {} + +impl<T: AlwaysRefCounted> ARef<T> { + /// Creates a new instance of [`ARef`]. + /// + /// It takes over an increment of the reference count on the underlying object. + /// + /// # Safety + /// + /// Callers must ensure that the reference count was incremented at least once, and that they + /// are properly relinquishing one increment. That is, if there is only one increment, callers + /// must not use the underlying object anymore -- it is only safe to do so via the newly + /// created [`ARef`]. + pub unsafe fn from_raw(ptr: NonNull<T>) -> Self { + // INVARIANT: The safety requirements guarantee that the new instance now owns the + // increment on the refcount. + Self { + ptr, + _p: PhantomData, + } + } + + /// Consumes the `ARef`, returning a raw pointer. + /// + /// This function does not change the refcount. After calling this function, the caller is + /// responsible for the refcount previously managed by the `ARef`. + /// + /// # Examples + /// + /// ``` + /// use core::ptr::NonNull; + /// use kernel::types::{ARef, AlwaysRefCounted}; + /// + /// struct Empty {} + /// + /// # // SAFETY: TODO. + /// unsafe impl AlwaysRefCounted for Empty { + /// fn inc_ref(&self) {} + /// unsafe fn dec_ref(_obj: NonNull<Self>) {} + /// } + /// + /// let mut data = Empty {}; + /// let ptr = NonNull::<Empty>::new(&mut data).unwrap(); + /// # // SAFETY: TODO. + /// let data_ref: ARef<Empty> = unsafe { ARef::from_raw(ptr) }; + /// let raw_ptr: NonNull<Empty> = ARef::into_raw(data_ref); + /// + /// assert_eq!(ptr, raw_ptr); + /// ``` + pub fn into_raw(me: Self) -> NonNull<T> { + ManuallyDrop::new(me).ptr + } +} + +impl<T: AlwaysRefCounted> Clone for ARef<T> { + fn clone(&self) -> Self { + self.inc_ref(); + // SAFETY: We just incremented the refcount above. + unsafe { Self::from_raw(self.ptr) } + } +} + +impl<T: AlwaysRefCounted> Deref for ARef<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: The type invariants guarantee that the object is valid. + unsafe { self.ptr.as_ref() } + } +} + +impl<T: AlwaysRefCounted> From<&T> for ARef<T> { + fn from(b: &T) -> Self { + b.inc_ref(); + // SAFETY: We just incremented the refcount above. + unsafe { Self::from_raw(NonNull::from(b)) } + } +} + +impl<T: AlwaysRefCounted> Drop for ARef<T> { + fn drop(&mut self) { + // SAFETY: The type invariants guarantee that the `ARef` owns the reference we're about to + // decrement. + unsafe { T::dec_ref(self.ptr) }; + } +} diff --git a/rust/kernel/sync/completion.rs b/rust/kernel/sync/completion.rs new file mode 100644 index 000000000000..c50012a940a3 --- /dev/null +++ b/rust/kernel/sync/completion.rs @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Completion support. +//! +//! Reference: <https://docs.kernel.org/scheduler/completion.html> +//! +//! C header: [`include/linux/completion.h`](srctree/include/linux/completion.h) + +use crate::{bindings, prelude::*, types::Opaque}; + +/// Synchronization primitive to signal when a certain task has been completed. +/// +/// The [`Completion`] synchronization primitive signals when a certain task has been completed by +/// waking up other tasks that have been queued up to wait for the [`Completion`] to be completed. +/// +/// # Examples +/// +/// ``` +/// use kernel::sync::{Arc, Completion}; +/// use kernel::workqueue::{self, impl_has_work, new_work, Work, WorkItem}; +/// +/// #[pin_data] +/// struct MyTask { +/// #[pin] +/// work: Work<MyTask>, +/// #[pin] +/// done: Completion, +/// } +/// +/// impl_has_work! { +/// impl HasWork<Self> for MyTask { self.work } +/// } +/// +/// impl MyTask { +/// fn new() -> Result<Arc<Self>> { +/// let this = Arc::pin_init(pin_init!(MyTask { +/// work <- new_work!("MyTask::work"), +/// done <- Completion::new(), +/// }), GFP_KERNEL)?; +/// +/// let _ = workqueue::system().enqueue(this.clone()); +/// +/// Ok(this) +/// } +/// +/// fn wait_for_completion(&self) { +/// self.done.wait_for_completion(); +/// +/// pr_info!("Completion: task complete\n"); +/// } +/// } +/// +/// impl WorkItem for MyTask { +/// type Pointer = Arc<MyTask>; +/// +/// fn run(this: Arc<MyTask>) { +/// // process this task +/// this.done.complete_all(); +/// } +/// } +/// +/// let task = MyTask::new()?; +/// task.wait_for_completion(); +/// # Ok::<(), Error>(()) +/// ``` +#[pin_data] +pub struct Completion { + #[pin] + inner: Opaque<bindings::completion>, +} + +// SAFETY: `Completion` is safe to be send to any task. +unsafe impl Send for Completion {} + +// SAFETY: `Completion` is safe to be accessed concurrently. +unsafe impl Sync for Completion {} + +impl Completion { + /// Create an initializer for a new [`Completion`]. + pub fn new() -> impl PinInit<Self> { + pin_init!(Self { + inner <- Opaque::ffi_init(|slot: *mut bindings::completion| { + // SAFETY: `slot` is a valid pointer to an uninitialized `struct completion`. + unsafe { bindings::init_completion(slot) }; + }), + }) + } + + fn as_raw(&self) -> *mut bindings::completion { + self.inner.get() + } + + /// Signal all tasks waiting on this completion. + /// + /// This method wakes up all tasks waiting on this completion; after this operation the + /// completion is permanently done, i.e. signals all current and future waiters. + pub fn complete_all(&self) { + // SAFETY: `self.as_raw()` is a pointer to a valid `struct completion`. + unsafe { bindings::complete_all(self.as_raw()) }; + } + + /// Wait for completion of a task. + /// + /// This method waits for the completion of a task; it is not interruptible and there is no + /// timeout. + /// + /// See also [`Completion::complete_all`]. + pub fn wait_for_completion(&self) { + // SAFETY: `self.as_raw()` is a pointer to a valid `struct completion`. + unsafe { bindings::wait_for_completion(self.as_raw()) }; + } +} diff --git a/rust/kernel/sync/condvar.rs b/rust/kernel/sync/condvar.rs index 7df565038d7d..c6ec64295c9f 100644 --- a/rust/kernel/sync/condvar.rs +++ b/rust/kernel/sync/condvar.rs @@ -8,16 +8,15 @@ use super::{lock::Backend, lock::Guard, LockClassKey}; use crate::{ ffi::{c_int, c_long}, - init::PinInit, - pin_init, str::CStr, - task::{MAX_SCHEDULE_TIMEOUT, TASK_INTERRUPTIBLE, TASK_NORMAL, TASK_UNINTERRUPTIBLE}, + task::{ + MAX_SCHEDULE_TIMEOUT, TASK_FREEZABLE, TASK_INTERRUPTIBLE, TASK_NORMAL, TASK_UNINTERRUPTIBLE, + }, time::Jiffies, types::Opaque, }; -use core::marker::PhantomPinned; -use core::ptr; -use macros::pin_data; +use core::{marker::PhantomPinned, pin::Pin, ptr}; +use pin_init::{pin_data, pin_init, PinInit}; /// Creates a [`CondVar`] initialiser with the given name and a newly-created lock class. #[macro_export] @@ -37,7 +36,7 @@ pub use new_condvar; /// spuriously. /// /// Instances of [`CondVar`] need a lock class and to be pinned. The recommended way to create such -/// instances is with the [`pin_init`](crate::pin_init) and [`new_condvar`] macros. +/// instances is with the [`pin_init`](crate::pin_init!) and [`new_condvar`] macros. /// /// # Examples /// @@ -101,7 +100,7 @@ unsafe impl Sync for CondVar {} impl CondVar { /// Constructs a new condvar initialiser. - pub fn new(name: &'static CStr, key: &'static LockClassKey) -> impl PinInit<Self> { + pub fn new(name: &'static CStr, key: Pin<&'static LockClassKey>) -> impl PinInit<Self> { pin_init!(Self { _pin: PhantomPinned, // SAFETY: `slot` is valid while the closure is called and both `name` and `key` have @@ -159,6 +158,25 @@ impl CondVar { crate::current!().signal_pending() } + /// Releases the lock and waits for a notification in interruptible and freezable mode. + /// + /// The process is allowed to be frozen during this sleep. No lock should be held when calling + /// this function, and there is a lockdep assertion for this. Freezing a task that holds a lock + /// can trivially deadlock vs another task that needs that lock to complete before it too can + /// hit freezable. + #[must_use = "wait_interruptible_freezable returns if a signal is pending, so the caller must check the return value"] + pub fn wait_interruptible_freezable<T: ?Sized, B: Backend>( + &self, + guard: &mut Guard<'_, T, B>, + ) -> bool { + self.wait_internal( + TASK_INTERRUPTIBLE | TASK_FREEZABLE, + guard, + MAX_SCHEDULE_TIMEOUT, + ); + crate::current!().signal_pending() + } + /// Releases the lock and waits for a notification in interruptible mode. /// /// Atomically releases the given lock (whose ownership is proven by the guard) and puts the @@ -198,6 +216,7 @@ impl CondVar { /// This method behaves like `notify_one`, except that it hints to the scheduler that the /// current thread is about to go to sleep, so it should schedule the target thread on the same /// CPU. + #[inline] pub fn notify_sync(&self) { // SAFETY: `wait_queue_head` points to valid memory. unsafe { bindings::__wake_up_sync(self.wait_queue_head.get(), TASK_NORMAL) }; @@ -207,6 +226,7 @@ impl CondVar { /// /// This is not 'sticky' in the sense that if no thread is waiting, the notification is lost /// completely (as opposed to automatically waking up the next waiter). + #[inline] pub fn notify_one(&self) { self.notify(1); } @@ -215,6 +235,7 @@ impl CondVar { /// /// This is not 'sticky' in the sense that if no thread is waiting, the notification is lost /// completely (as opposed to automatically waking up the next waiter). + #[inline] pub fn notify_all(&self) { self.notify(0); } diff --git a/rust/kernel/sync/lock.rs b/rust/kernel/sync/lock.rs index 41dcddac69e2..27202beef90c 100644 --- a/rust/kernel/sync/lock.rs +++ b/rust/kernel/sync/lock.rs @@ -7,13 +7,11 @@ use super::LockClassKey; use crate::{ - init::PinInit, - pin_init, str::CStr, types::{NotThreadSafe, Opaque, ScopeGuard}, }; -use core::{cell::UnsafeCell, marker::PhantomPinned}; -use macros::pin_data; +use core::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin}; +use pin_init::{pin_data, pin_init, PinInit}; pub mod mutex; pub mod spinlock; @@ -90,12 +88,20 @@ pub unsafe trait Backend { // SAFETY: The safety requirements ensure that the lock is initialised. *guard_state = unsafe { Self::lock(ptr) }; } + + /// Asserts that the lock is held using lockdep. + /// + /// # Safety + /// + /// Callers must ensure that [`Backend::init`] has been previously called. + unsafe fn assert_is_held(ptr: *mut Self::State); } /// A mutual exclusion primitive. /// /// Exposes one of the kernel locking primitives. Which one is exposed depends on the lock /// [`Backend`] specified as the generic parameter `B`. +#[repr(C)] #[pin_data] pub struct Lock<T: ?Sized, B: Backend> { /// The kernel lock object. @@ -121,7 +127,7 @@ unsafe impl<T: ?Sized + Send, B: Backend> Sync for Lock<T, B> {} impl<T, B: Backend> Lock<T, B> { /// Constructs a new lock initialiser. - pub fn new(t: T, name: &'static CStr, key: &'static LockClassKey) -> impl PinInit<Self> { + pub fn new(t: T, name: &'static CStr, key: Pin<&'static LockClassKey>) -> impl PinInit<Self> { pin_init!(Self { data: UnsafeCell::new(t), _pin: PhantomPinned, @@ -134,6 +140,28 @@ impl<T, B: Backend> Lock<T, B> { } } +impl<B: Backend> Lock<(), B> { + /// Constructs a [`Lock`] from a raw pointer. + /// + /// This can be useful for interacting with a lock which was initialised outside of Rust. + /// + /// # Safety + /// + /// The caller promises that `ptr` points to a valid initialised instance of [`State`] during + /// the whole lifetime of `'a`. + /// + /// [`State`]: Backend::State + pub unsafe fn from_raw<'a>(ptr: *mut B::State) -> &'a Self { + // SAFETY: + // - By the safety contract `ptr` must point to a valid initialised instance of `B::State` + // - Since the lock data type is `()` which is a ZST, `state` is the only non-ZST member of + // the struct + // - Combined with `#[repr(C)]`, this guarantees `Self` has an equivalent data layout to + // `B::State`. + unsafe { &*ptr.cast() } + } +} + impl<T: ?Sized, B: Backend> Lock<T, B> { /// Acquires the lock and gives the caller access to the data protected by it. pub fn lock(&self) -> Guard<'_, T, B> { @@ -147,6 +175,8 @@ impl<T: ?Sized, B: Backend> Lock<T, B> { /// Tries to acquire the lock. /// /// Returns a guard that can be used to access the data protected by the lock if successful. + // `Option<T>` is not `#[must_use]` even if `T` is, thus the attribute is needed here. + #[must_use = "if unused, the lock will be immediately unlocked"] pub fn try_lock(&self) -> Option<Guard<'_, T, B>> { // SAFETY: The constructor of the type calls `init`, so the existence of the object proves // that `init` was called. @@ -169,7 +199,37 @@ pub struct Guard<'a, T: ?Sized, B: Backend> { // SAFETY: `Guard` is sync when the data protected by the lock is also sync. unsafe impl<T: Sync + ?Sized, B: Backend> Sync for Guard<'_, T, B> {} -impl<T: ?Sized, B: Backend> Guard<'_, T, B> { +impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { + /// Returns the lock that this guard originates from. + /// + /// # Examples + /// + /// The following example shows how to use [`Guard::lock_ref()`] to assert the corresponding + /// lock is held. + /// + /// ``` + /// # use kernel::{new_spinlock, sync::lock::{Backend, Guard, Lock}}; + /// # use pin_init::stack_pin_init; + /// + /// fn assert_held<T, B: Backend>(guard: &Guard<'_, T, B>, lock: &Lock<T, B>) { + /// // Address-equal means the same lock. + /// assert!(core::ptr::eq(guard.lock_ref(), lock)); + /// } + /// + /// // Creates a new lock on the stack. + /// stack_pin_init!{ + /// let l = new_spinlock!(42) + /// } + /// + /// let g = l.lock(); + /// + /// // `g` originates from `l`. + /// assert_held(&g, &l); + /// ``` + pub fn lock_ref(&self) -> &'a Lock<T, B> { + self.lock + } + pub(crate) fn do_unlocked<U>(&mut self, cb: impl FnOnce() -> U) -> U { // SAFETY: The caller owns the lock, so it is safe to unlock it. unsafe { B::unlock(self.lock.state.get(), &self.state) }; @@ -211,7 +271,10 @@ impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { /// # Safety /// /// The caller must ensure that it owns the lock. - pub(crate) unsafe fn new(lock: &'a Lock<T, B>, state: B::GuardState) -> Self { + pub unsafe fn new(lock: &'a Lock<T, B>, state: B::GuardState) -> Self { + // SAFETY: The caller can only hold the lock if `Backend::init` has already been called. + unsafe { B::assert_is_held(lock.state.get()) }; + Self { lock, state, diff --git a/rust/kernel/sync/lock/global.rs b/rust/kernel/sync/lock/global.rs index 480ee724e3cc..d65f94b5caf2 100644 --- a/rust/kernel/sync/lock/global.rs +++ b/rust/kernel/sync/lock/global.rs @@ -13,6 +13,7 @@ use crate::{ use core::{ cell::UnsafeCell, marker::{PhantomData, PhantomPinned}, + pin::Pin, }; /// Trait implemented for marker types for global locks. @@ -26,7 +27,7 @@ pub trait GlobalLockBackend { /// The backend used for this global lock. type Backend: Backend + 'static; /// The class for this global lock. - fn get_lock_class() -> &'static LockClassKey; + fn get_lock_class() -> Pin<&'static LockClassKey>; } /// Type used for global locks. @@ -270,7 +271,7 @@ macro_rules! global_lock { type Item = $valuety; type Backend = $crate::global_lock_inner!(backend $kind); - fn get_lock_class() -> &'static $crate::sync::LockClassKey { + fn get_lock_class() -> Pin<&'static $crate::sync::LockClassKey> { $crate::static_lock_class!() } } diff --git a/rust/kernel/sync/lock/mutex.rs b/rust/kernel/sync/lock/mutex.rs index 0e946ebefce1..581cee7ab842 100644 --- a/rust/kernel/sync/lock/mutex.rs +++ b/rust/kernel/sync/lock/mutex.rs @@ -26,7 +26,7 @@ pub use new_mutex; /// Since it may block, [`Mutex`] needs to be used with care in atomic contexts. /// /// Instances of [`Mutex`] need a lock class and to be pinned. The recommended way to create such -/// instances is with the [`pin_init`](crate::pin_init) and [`new_mutex`] macros. +/// instances is with the [`pin_init`](pin_init::pin_init) and [`new_mutex`] macros. /// /// # Examples /// @@ -86,6 +86,14 @@ pub use new_mutex; /// [`struct mutex`]: srctree/include/linux/mutex.h pub type Mutex<T> = super::Lock<T, MutexBackend>; +/// A [`Guard`] acquired from locking a [`Mutex`]. +/// +/// This is simply a type alias for a [`Guard`] returned from locking a [`Mutex`]. It will unlock +/// the [`Mutex`] upon being dropped. +/// +/// [`Guard`]: super::Guard +pub type MutexGuard<'a, T> = super::Guard<'a, T, MutexBackend>; + /// A kernel `struct mutex` lock backend. pub struct MutexBackend; @@ -126,4 +134,9 @@ unsafe impl super::Backend for MutexBackend { None } } + + unsafe fn assert_is_held(ptr: *mut Self::State) { + // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. + unsafe { bindings::mutex_assert_is_held(ptr) } + } } diff --git a/rust/kernel/sync/lock/spinlock.rs b/rust/kernel/sync/lock/spinlock.rs index 9f4d128bed98..d7be38ccbdc7 100644 --- a/rust/kernel/sync/lock/spinlock.rs +++ b/rust/kernel/sync/lock/spinlock.rs @@ -24,7 +24,7 @@ pub use new_spinlock; /// unlocked, at which point another CPU will be allowed to make progress. /// /// Instances of [`SpinLock`] need a lock class and to be pinned. The recommended way to create such -/// instances is with the [`pin_init`](crate::pin_init) and [`new_spinlock`] macros. +/// instances is with the [`pin_init`](pin_init::pin_init) and [`new_spinlock`] macros. /// /// # Examples /// @@ -87,6 +87,14 @@ pub type SpinLock<T> = super::Lock<T, SpinLockBackend>; /// A kernel `spinlock_t` lock backend. pub struct SpinLockBackend; +/// A [`Guard`] acquired from locking a [`SpinLock`]. +/// +/// This is simply a type alias for a [`Guard`] returned from locking a [`SpinLock`]. It will unlock +/// the [`SpinLock`] upon being dropped. +/// +/// [`Guard`]: super::Guard +pub type SpinLockGuard<'a, T> = super::Guard<'a, T, SpinLockBackend>; + // SAFETY: The underlying kernel `spinlock_t` object ensures mutual exclusion. `relock` uses the // default implementation that always calls the same locking method. unsafe impl super::Backend for SpinLockBackend { @@ -125,4 +133,9 @@ unsafe impl super::Backend for SpinLockBackend { None } } + + unsafe fn assert_is_held(ptr: *mut Self::State) { + // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. + unsafe { bindings::spin_assert_is_held(ptr) } + } } diff --git a/rust/kernel/sync/locked_by.rs b/rust/kernel/sync/locked_by.rs index a7b244675c2b..61f100a45b35 100644 --- a/rust/kernel/sync/locked_by.rs +++ b/rust/kernel/sync/locked_by.rs @@ -55,7 +55,7 @@ use core::{cell::UnsafeCell, mem::size_of, ptr}; /// fn print_bytes_used(dir: &Directory, file: &File) { /// let guard = dir.inner.lock(); /// let inner_file = file.inner.access(&guard); -/// pr_info!("{} {}", guard.bytes_used, inner_file.bytes_used); +/// pr_info!("{} {}\n", guard.bytes_used, inner_file.bytes_used); /// } /// /// /// Increments `bytes_used` for both the directory and file. diff --git a/rust/kernel/sync/poll.rs b/rust/kernel/sync/poll.rs index d5f17153b424..0ec985d560c8 100644 --- a/rust/kernel/sync/poll.rs +++ b/rust/kernel/sync/poll.rs @@ -9,9 +9,8 @@ use crate::{ fs::File, prelude::*, sync::{CondVar, LockClassKey}, - types::Opaque, }; -use core::ops::Deref; +use core::{marker::PhantomData, ops::Deref}; /// Creates a [`PollCondVar`] initialiser with the given name and a newly-created lock class. #[macro_export] @@ -23,58 +22,43 @@ macro_rules! new_poll_condvar { }; } -/// Wraps the kernel's `struct poll_table`. +/// Wraps the kernel's `poll_table`. /// /// # Invariants /// -/// This struct contains a valid `struct poll_table`. -/// -/// For a `struct poll_table` to be valid, its `_qproc` function must follow the safety -/// requirements of `_qproc` functions: -/// -/// * The `_qproc` function is given permission to enqueue a waiter to the provided `poll_table` -/// during the call. Once the waiter is removed and an rcu grace period has passed, it must no -/// longer access the `wait_queue_head`. +/// The pointer must be null or reference a valid `poll_table`. #[repr(transparent)] -pub struct PollTable(Opaque<bindings::poll_table>); +pub struct PollTable<'a> { + table: *mut bindings::poll_table, + _lifetime: PhantomData<&'a bindings::poll_table>, +} -impl PollTable { - /// Creates a reference to a [`PollTable`] from a valid pointer. +impl<'a> PollTable<'a> { + /// Creates a [`PollTable`] from a valid pointer. /// /// # Safety /// - /// The caller must ensure that for the duration of 'a, the pointer will point at a valid poll - /// table (as defined in the type invariants). - /// - /// The caller must also ensure that the `poll_table` is only accessed via the returned - /// reference for the duration of 'a. - pub unsafe fn from_ptr<'a>(ptr: *mut bindings::poll_table) -> &'a mut PollTable { - // SAFETY: The safety requirements guarantee the validity of the dereference, while the - // `PollTable` type being transparent makes the cast ok. - unsafe { &mut *ptr.cast() } - } - - fn get_qproc(&self) -> bindings::poll_queue_proc { - let ptr = self.0.get(); - // SAFETY: The `ptr` is valid because it originates from a reference, and the `_qproc` - // field is not modified concurrently with this call since we have an immutable reference. - unsafe { (*ptr)._qproc } + /// The pointer must be null or reference a valid `poll_table` for the duration of `'a`. + pub unsafe fn from_raw(table: *mut bindings::poll_table) -> Self { + // INVARIANTS: The safety requirements are the same as the struct invariants. + PollTable { + table, + _lifetime: PhantomData, + } } /// Register this [`PollTable`] with the provided [`PollCondVar`], so that it can be notified /// using the condition variable. - pub fn register_wait(&mut self, file: &File, cv: &PollCondVar) { - if let Some(qproc) = self.get_qproc() { - // SAFETY: The pointers to `file` and `self` need to be valid for the duration of this - // call to `qproc`, which they are because they are references. - // - // The `cv.wait_queue_head` pointer must be valid until an rcu grace period after the - // waiter is removed. The `PollCondVar` is pinned, so before `cv.wait_queue_head` can - // be destroyed, the destructor must run. That destructor first removes all waiters, - // and then waits for an rcu grace period. Therefore, `cv.wait_queue_head` is valid for - // long enough. - unsafe { qproc(file.as_ptr() as _, cv.wait_queue_head.get(), self.0.get()) }; - } + pub fn register_wait(&self, file: &File, cv: &PollCondVar) { + // SAFETY: + // * `file.as_ptr()` references a valid file for the duration of this call. + // * `self.table` is null or references a valid poll_table for the duration of this call. + // * Since `PollCondVar` is pinned, its destructor is guaranteed to run before the memory + // containing `cv.wait_queue_head` is invalidated. Since the destructor clears all + // waiters and then waits for an rcu grace period, it's guaranteed that + // `cv.wait_queue_head` remains valid for at least an rcu grace period after the removal + // of the last waiter. + unsafe { bindings::poll_wait(file.as_ptr(), cv.wait_queue_head.get(), self.table) } } } @@ -89,7 +73,7 @@ pub struct PollCondVar { impl PollCondVar { /// Constructs a new condvar initialiser. - pub fn new(name: &'static CStr, key: &'static LockClassKey) -> impl PinInit<Self> { + pub fn new(name: &'static CStr, key: Pin<&'static LockClassKey>) -> impl PinInit<Self> { pin_init!(Self { inner <- CondVar::new(name, key), }) @@ -107,6 +91,7 @@ impl Deref for PollCondVar { #[pinned_drop] impl PinnedDrop for PollCondVar { + #[inline] fn drop(self: Pin<&mut Self>) { // Clear anything registered using `register_wait`. // diff --git a/rust/kernel/sync/rcu.rs b/rust/kernel/sync/rcu.rs new file mode 100644 index 000000000000..a32bef6e490b --- /dev/null +++ b/rust/kernel/sync/rcu.rs @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! RCU support. +//! +//! C header: [`include/linux/rcupdate.h`](srctree/include/linux/rcupdate.h) + +use crate::{bindings, types::NotThreadSafe}; + +/// Evidence that the RCU read side lock is held on the current thread/CPU. +/// +/// The type is explicitly not `Send` because this property is per-thread/CPU. +/// +/// # Invariants +/// +/// The RCU read side lock is actually held while instances of this guard exist. +pub struct Guard(NotThreadSafe); + +impl Guard { + /// Acquires the RCU read side lock and returns a guard. + #[inline] + pub fn new() -> Self { + // SAFETY: An FFI call with no additional requirements. + unsafe { bindings::rcu_read_lock() }; + // INVARIANT: The RCU read side lock was just acquired above. + Self(NotThreadSafe) + } + + /// Explicitly releases the RCU read side lock. + #[inline] + pub fn unlock(self) {} +} + +impl Default for Guard { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Drop for Guard { + #[inline] + fn drop(&mut self) { + // SAFETY: By the type invariants, the RCU read side is locked, so it is ok to unlock it. + unsafe { bindings::rcu_read_unlock() }; + } +} + +/// Acquires the RCU read side lock. +#[inline] +pub fn read_lock() -> Guard { + Guard::new() +} diff --git a/rust/kernel/task.rs b/rust/kernel/task.rs index 07bc22a7645c..7d0935bc325c 100644 --- a/rust/kernel/task.rs +++ b/rust/kernel/task.rs @@ -7,6 +7,7 @@ use crate::{ bindings, ffi::{c_int, c_long, c_uint}, + mm::MmWithUser, pid_namespace::PidNamespace, types::{ARef, NotThreadSafe, Opaque}, }; @@ -23,6 +24,8 @@ pub const MAX_SCHEDULE_TIMEOUT: c_long = c_long::MAX; pub const TASK_INTERRUPTIBLE: c_int = bindings::TASK_INTERRUPTIBLE as c_int; /// Bitmask for tasks that are sleeping in an uninterruptible state. pub const TASK_UNINTERRUPTIBLE: c_int = bindings::TASK_UNINTERRUPTIBLE as c_int; +/// Bitmask for tasks that are sleeping in a freezable state. +pub const TASK_FREEZABLE: c_int = bindings::TASK_FREEZABLE as c_int; /// Convenience constant for waking up tasks regardless of whether they are in interruptible or /// uninterruptible sleep. pub const TASK_NORMAL: c_uint = bindings::TASK_NORMAL as c_uint; @@ -31,22 +34,20 @@ pub const TASK_NORMAL: c_uint = bindings::TASK_NORMAL as c_uint; #[macro_export] macro_rules! current { () => { - // SAFETY: Deref + addr-of below create a temporary `TaskRef` that cannot outlive the - // caller. + // SAFETY: This expression creates a temporary value that is dropped at the end of the + // caller's scope. The following mechanisms ensure that the resulting `&CurrentTask` cannot + // leave current task context: + // + // * To return to userspace, the caller must leave the current scope. + // * Operations such as `begin_new_exec()` are necessarily unsafe and the caller of + // `begin_new_exec()` is responsible for safety. + // * Rust abstractions for things such as a `kthread_use_mm()` scope must require the + // closure to be `Send`, so the `NotThreadSafe` field of `CurrentTask` ensures that the + // `&CurrentTask` cannot cross the scope in either direction. unsafe { &*$crate::task::Task::current() } }; } -/// Returns the currently running task's pid namespace. -#[macro_export] -macro_rules! current_pid_ns { - () => { - // SAFETY: Deref + addr-of below create a temporary `PidNamespaceRef` that cannot outlive - // the caller. - unsafe { &*$crate::task::Task::current_pid_ns() } - }; -} - /// Wraps the kernel's `struct task_struct`. /// /// # Invariants @@ -85,7 +86,7 @@ macro_rules! current_pid_ns { /// impl State { /// fn new() -> Self { /// Self { -/// creator: current!().into(), +/// creator: ARef::from(&**current!()), /// index: 0, /// } /// } @@ -105,8 +106,46 @@ unsafe impl Send for Task {} // synchronised by C code (e.g., `signal_pending`). unsafe impl Sync for Task {} +/// Represents the [`Task`] in the `current` global. +/// +/// This type exists to provide more efficient operations that are only valid on the current task. +/// For example, to retrieve the pid-namespace of a task, you must use rcu protection unless it is +/// the current task. +/// +/// # Invariants +/// +/// Each value of this type must only be accessed from the task context it was created within. +/// +/// Of course, every thread is in a different task context, but for the purposes of this invariant, +/// these operations also permanently leave the task context: +/// +/// * Returning to userspace from system call context. +/// * Calling `release_task()`. +/// * Calling `begin_new_exec()` in a binary format loader. +/// +/// Other operations temporarily create a new sub-context: +/// +/// * Calling `kthread_use_mm()` creates a new context, and `kthread_unuse_mm()` returns to the +/// old context. +/// +/// This means that a `CurrentTask` obtained before a `kthread_use_mm()` call may be used again +/// once `kthread_unuse_mm()` is called, but it must not be used between these two calls. +/// Conversely, a `CurrentTask` obtained between a `kthread_use_mm()`/`kthread_unuse_mm()` pair +/// must not be used after `kthread_unuse_mm()`. +#[repr(transparent)] +pub struct CurrentTask(Task, NotThreadSafe); + +// Make all `Task` methods available on `CurrentTask`. +impl Deref for CurrentTask { + type Target = Task; + #[inline] + fn deref(&self) -> &Task { + &self.0 + } +} + /// The type of process identifiers (PIDs). -type Pid = bindings::pid_t; +pub type Pid = bindings::pid_t; /// The type of user identifiers (UIDs). #[derive(Copy, Clone)] @@ -131,119 +170,30 @@ impl Task { /// /// # Safety /// - /// Callers must ensure that the returned object doesn't outlive the current task/thread. - pub unsafe fn current() -> impl Deref<Target = Task> { - struct TaskRef<'a> { - task: &'a Task, - _not_send: NotThreadSafe, + /// Callers must ensure that the returned object is only used to access a [`CurrentTask`] + /// within the task context that was active when this function was called. For more details, + /// see the invariants section for [`CurrentTask`]. + #[inline] + pub unsafe fn current() -> impl Deref<Target = CurrentTask> { + struct TaskRef { + task: *const CurrentTask, } - impl Deref for TaskRef<'_> { - type Target = Task; + impl Deref for TaskRef { + type Target = CurrentTask; fn deref(&self) -> &Self::Target { - self.task + // SAFETY: The returned reference borrows from this `TaskRef`, so it cannot outlive + // the `TaskRef`, which the caller of `Task::current()` has promised will not + // outlive the task/thread for which `self.task` is the `current` pointer. Thus, it + // is okay to return a `CurrentTask` reference here. + unsafe { &*self.task } } } - let current = Task::current_raw(); TaskRef { - // SAFETY: If the current thread is still running, the current task is valid. Given - // that `TaskRef` is not `Send`, we know it cannot be transferred to another thread - // (where it could potentially outlive the caller). - task: unsafe { &*current.cast() }, - _not_send: NotThreadSafe, - } - } - - /// Returns a PidNamespace reference for the currently executing task's/thread's pid namespace. - /// - /// This function can be used to create an unbounded lifetime by e.g., storing the returned - /// PidNamespace in a global variable which would be a bug. So the recommended way to get the - /// current task's/thread's pid namespace is to use the [`current_pid_ns`] macro because it is - /// safe. - /// - /// # Safety - /// - /// Callers must ensure that the returned object doesn't outlive the current task/thread. - pub unsafe fn current_pid_ns() -> impl Deref<Target = PidNamespace> { - struct PidNamespaceRef<'a> { - task: &'a PidNamespace, - _not_send: NotThreadSafe, - } - - impl Deref for PidNamespaceRef<'_> { - type Target = PidNamespace; - - fn deref(&self) -> &Self::Target { - self.task - } - } - - // The lifetime of `PidNamespace` is bound to `Task` and `struct pid`. - // - // The `PidNamespace` of a `Task` doesn't ever change once the `Task` is alive. A - // `unshare(CLONE_NEWPID)` or `setns(fd_pidns/pidfd, CLONE_NEWPID)` will not have an effect - // on the calling `Task`'s pid namespace. It will only effect the pid namespace of children - // created by the calling `Task`. This invariant guarantees that after having acquired a - // reference to a `Task`'s pid namespace it will remain unchanged. - // - // When a task has exited and been reaped `release_task()` will be called. This will set - // the `PidNamespace` of the task to `NULL`. So retrieving the `PidNamespace` of a task - // that is dead will return `NULL`. Note, that neither holding the RCU lock nor holding a - // referencing count to - // the `Task` will prevent `release_task()` being called. - // - // In order to retrieve the `PidNamespace` of a `Task` the `task_active_pid_ns()` function - // can be used. There are two cases to consider: - // - // (1) retrieving the `PidNamespace` of the `current` task - // (2) retrieving the `PidNamespace` of a non-`current` task - // - // From system call context retrieving the `PidNamespace` for case (1) is always safe and - // requires neither RCU locking nor a reference count to be held. Retrieving the - // `PidNamespace` after `release_task()` for current will return `NULL` but no codepath - // like that is exposed to Rust. - // - // Retrieving the `PidNamespace` from system call context for (2) requires RCU protection. - // Accessing `PidNamespace` outside of RCU protection requires a reference count that - // must've been acquired while holding the RCU lock. Note that accessing a non-`current` - // task means `NULL` can be returned as the non-`current` task could have already passed - // through `release_task()`. - // - // To retrieve (1) the `current_pid_ns!()` macro should be used which ensure that the - // returned `PidNamespace` cannot outlive the calling scope. The associated - // `current_pid_ns()` function should not be called directly as it could be abused to - // created an unbounded lifetime for `PidNamespace`. The `current_pid_ns!()` macro allows - // Rust to handle the common case of accessing `current`'s `PidNamespace` without RCU - // protection and without having to acquire a reference count. - // - // For (2) the `task_get_pid_ns()` method must be used. This will always acquire a - // reference on `PidNamespace` and will return an `Option` to force the caller to - // explicitly handle the case where `PidNamespace` is `None`, something that tends to be - // forgotten when doing the equivalent operation in `C`. Missing RCU primitives make it - // difficult to perform operations that are otherwise safe without holding a reference - // count as long as RCU protection is guaranteed. But it is not important currently. But we - // do want it in the future. - // - // Note for (2) the required RCU protection around calling `task_active_pid_ns()` - // synchronizes against putting the last reference of the associated `struct pid` of - // `task->thread_pid`. The `struct pid` stored in that field is used to retrieve the - // `PidNamespace` of the caller. When `release_task()` is called `task->thread_pid` will be - // `NULL`ed and `put_pid()` on said `struct pid` will be delayed in `free_pid()` via - // `call_rcu()` allowing everyone with an RCU protected access to the `struct pid` acquired - // from `task->thread_pid` to finish. - // - // SAFETY: The current task's pid namespace is valid as long as the current task is running. - let pidns = unsafe { bindings::task_active_pid_ns(Task::current_raw()) }; - PidNamespaceRef { - // SAFETY: If the current thread is still running, the current task and its associated - // pid namespace are valid. `PidNamespaceRef` is not `Send`, so we know it cannot be - // transferred to another thread (where it could potentially outlive the current - // `Task`). The caller needs to ensure that the PidNamespaceRef doesn't outlive the - // current task/thread. - task: unsafe { PidNamespace::from_ptr(pidns) }, - _not_send: NotThreadSafe, + // CAST: The layout of `struct task_struct` and `CurrentTask` is identical. + task: Task::current_raw().cast(), } } @@ -273,24 +223,28 @@ impl Task { } /// Returns the UID of the given task. + #[inline] pub fn uid(&self) -> Kuid { // SAFETY: It's always safe to call `task_uid` on a valid task. Kuid::from_raw(unsafe { bindings::task_uid(self.as_ptr()) }) } /// Returns the effective UID of the given task. + #[inline] pub fn euid(&self) -> Kuid { // SAFETY: It's always safe to call `task_euid` on a valid task. Kuid::from_raw(unsafe { bindings::task_euid(self.as_ptr()) }) } /// Determines whether the given task has pending signals. + #[inline] pub fn signal_pending(&self) -> bool { // SAFETY: It's always safe to call `signal_pending` on a valid task. unsafe { bindings::signal_pending(self.as_ptr()) != 0 } } /// Returns task's pid namespace with elevated reference count + #[inline] pub fn get_pid_ns(&self) -> Option<ARef<PidNamespace>> { // SAFETY: By the type invariant, we know that `self.0` is valid. let ptr = unsafe { bindings::task_get_pid_ns(self.as_ptr()) }; @@ -306,6 +260,7 @@ impl Task { /// Returns the given task's pid in the provided pid namespace. #[doc(alias = "task_tgid_nr_ns")] + #[inline] pub fn tgid_nr_ns(&self, pidns: Option<&PidNamespace>) -> Pid { let pidns = match pidns { Some(pidns) => pidns.as_ptr(), @@ -319,20 +274,87 @@ impl Task { } /// Wakes up the task. + #[inline] pub fn wake_up(&self) { - // SAFETY: It's always safe to call `signal_pending` on a valid task, even if the task + // SAFETY: It's always safe to call `wake_up_process` on a valid task, even if the task // running. unsafe { bindings::wake_up_process(self.as_ptr()) }; } } +impl CurrentTask { + /// Access the address space of the current task. + /// + /// This function does not touch the refcount of the mm. + #[inline] + pub fn mm(&self) -> Option<&MmWithUser> { + // SAFETY: The `mm` field of `current` is not modified from other threads, so reading it is + // not a data race. + let mm = unsafe { (*self.as_ptr()).mm }; + + if mm.is_null() { + return None; + } + + // SAFETY: If `current->mm` is non-null, then it references a valid mm with a non-zero + // value of `mm_users`. Furthermore, the returned `&MmWithUser` borrows from this + // `CurrentTask`, so it cannot escape the scope in which the current pointer was obtained. + // + // This is safe even if `kthread_use_mm()`/`kthread_unuse_mm()` are used. There are two + // relevant cases: + // * If the `&CurrentTask` was created before `kthread_use_mm()`, then it cannot be + // accessed during the `kthread_use_mm()`/`kthread_unuse_mm()` scope due to the + // `NotThreadSafe` field of `CurrentTask`. + // * If the `&CurrentTask` was created within a `kthread_use_mm()`/`kthread_unuse_mm()` + // scope, then the `&CurrentTask` cannot escape that scope, so the returned `&MmWithUser` + // also cannot escape that scope. + // In either case, it's not possible to read `current->mm` and keep using it after the + // scope is ended with `kthread_unuse_mm()`. + Some(unsafe { MmWithUser::from_raw(mm) }) + } + + /// Access the pid namespace of the current task. + /// + /// This function does not touch the refcount of the namespace or use RCU protection. + /// + /// To access the pid namespace of another task, see [`Task::get_pid_ns`]. + #[doc(alias = "task_active_pid_ns")] + #[inline] + pub fn active_pid_ns(&self) -> Option<&PidNamespace> { + // SAFETY: It is safe to call `task_active_pid_ns` without RCU protection when calling it + // on the current task. + let active_ns = unsafe { bindings::task_active_pid_ns(self.as_ptr()) }; + + if active_ns.is_null() { + return None; + } + + // The lifetime of `PidNamespace` is bound to `Task` and `struct pid`. + // + // The `PidNamespace` of a `Task` doesn't ever change once the `Task` is alive. + // + // From system call context retrieving the `PidNamespace` for the current task is always + // safe and requires neither RCU locking nor a reference count to be held. Retrieving the + // `PidNamespace` after `release_task()` for current will return `NULL` but no codepath + // like that is exposed to Rust. + // + // SAFETY: If `current`'s pid ns is non-null, then it references a valid pid ns. + // Furthermore, the returned `&PidNamespace` borrows from this `CurrentTask`, so it cannot + // escape the scope in which the current pointer was obtained, e.g. it cannot live past a + // `release_task()` call. + Some(unsafe { PidNamespace::from_ptr(active_ns) }) + } +} + // SAFETY: The type invariants guarantee that `Task` is always refcounted. unsafe impl crate::types::AlwaysRefCounted for Task { + #[inline] fn inc_ref(&self) { // SAFETY: The existence of a shared reference means that the refcount is nonzero. unsafe { bindings::get_task_struct(self.as_ptr()) }; } + #[inline] unsafe fn dec_ref(obj: ptr::NonNull<Self>) { // SAFETY: The safety requirements guarantee that the refcount is nonzero. unsafe { bindings::put_task_struct(obj.cast().as_ptr()) } @@ -378,3 +400,27 @@ impl PartialEq for Kuid { } impl Eq for Kuid {} + +/// Annotation for functions that can sleep. +/// +/// Equivalent to the C side [`might_sleep()`], this function serves as +/// a debugging aid and a potential scheduling point. +/// +/// This function can only be used in a nonatomic context. +/// +/// [`might_sleep()`]: https://docs.kernel.org/driver-api/basics.html#c.might_sleep +#[track_caller] +#[inline] +pub fn might_sleep() { + #[cfg(CONFIG_DEBUG_ATOMIC_SLEEP)] + { + let loc = core::panic::Location::caller(); + let file = kernel::file_from_location(loc); + + // SAFETY: `file.as_ptr()` is valid for reading and guaranteed to be nul-terminated. + unsafe { crate::bindings::__might_sleep(file.as_ptr().cast(), loc.line() as i32) } + } + + // SAFETY: Always safe to call. + unsafe { crate::bindings::might_resched() } +} diff --git a/rust/kernel/time.rs b/rust/kernel/time.rs index 379c0f5772e5..64c8dcf548d6 100644 --- a/rust/kernel/time.rs +++ b/rust/kernel/time.rs @@ -5,12 +5,39 @@ //! This module contains the kernel APIs related to time and timers that //! have been ported or wrapped for usage by Rust code in the kernel. //! +//! There are two types in this module: +//! +//! - The [`Instant`] type represents a specific point in time. +//! - The [`Delta`] type represents a span of time. +//! +//! Note that the C side uses `ktime_t` type to represent both. However, timestamp +//! and timedelta are different. To avoid confusion, we use two different types. +//! +//! A [`Instant`] object can be created by calling the [`Instant::now()`] function. +//! It represents a point in time at which the object was created. +//! By calling the [`Instant::elapsed()`] method, a [`Delta`] object representing +//! the elapsed time can be created. The [`Delta`] object can also be created +//! by subtracting two [`Instant`] objects. +//! +//! A [`Delta`] type supports methods to retrieve the duration in various units. +//! //! C header: [`include/linux/jiffies.h`](srctree/include/linux/jiffies.h). //! C header: [`include/linux/ktime.h`](srctree/include/linux/ktime.h). +use core::marker::PhantomData; + +pub mod delay; +pub mod hrtimer; + +/// The number of nanoseconds per microsecond. +pub const NSEC_PER_USEC: i64 = bindings::NSEC_PER_USEC as i64; + /// The number of nanoseconds per millisecond. pub const NSEC_PER_MSEC: i64 = bindings::NSEC_PER_MSEC as i64; +/// The number of nanoseconds per second. +pub const NSEC_PER_SEC: i64 = bindings::NSEC_PER_SEC as i64; + /// The time unit of Linux kernel. One jiffy equals (1/HZ) second. pub type Jiffies = crate::ffi::c_ulong; @@ -25,59 +52,264 @@ pub fn msecs_to_jiffies(msecs: Msecs) -> Jiffies { unsafe { bindings::__msecs_to_jiffies(msecs) } } -/// A Rust wrapper around a `ktime_t`. +/// Trait for clock sources. +/// +/// Selection of the clock source depends on the use case. In some cases the usage of a +/// particular clock is mandatory, e.g. in network protocols, filesystems. In other +/// cases the user of the clock has to decide which clock is best suited for the +/// purpose. In most scenarios clock [`Monotonic`] is the best choice as it +/// provides a accurate monotonic notion of time (leap second smearing ignored). +pub trait ClockSource { + /// The kernel clock ID associated with this clock source. + /// + /// This constant corresponds to the C side `clockid_t` value. + const ID: bindings::clockid_t; + + /// Get the current time from the clock source. + /// + /// The function must return a value in the range from 0 to `KTIME_MAX`. + fn ktime_get() -> bindings::ktime_t; +} + +/// A monotonically increasing clock. +/// +/// A nonsettable system-wide clock that represents monotonic time since as +/// described by POSIX, "some unspecified point in the past". On Linux, that +/// point corresponds to the number of seconds that the system has been +/// running since it was booted. +/// +/// The CLOCK_MONOTONIC clock is not affected by discontinuous jumps in the +/// CLOCK_REAL (e.g., if the system administrator manually changes the +/// clock), but is affected by frequency adjustments. This clock does not +/// count time that the system is suspended. +pub struct Monotonic; + +impl ClockSource for Monotonic { + const ID: bindings::clockid_t = bindings::CLOCK_MONOTONIC as bindings::clockid_t; + + fn ktime_get() -> bindings::ktime_t { + // SAFETY: It is always safe to call `ktime_get()` outside of NMI context. + unsafe { bindings::ktime_get() } + } +} + +/// A settable system-wide clock that measures real (i.e., wall-clock) time. +/// +/// Setting this clock requires appropriate privileges. This clock is +/// affected by discontinuous jumps in the system time (e.g., if the system +/// administrator manually changes the clock), and by frequency adjustments +/// performed by NTP and similar applications via adjtime(3), adjtimex(2), +/// clock_adjtime(2), and ntp_adjtime(3). This clock normally counts the +/// number of seconds since 1970-01-01 00:00:00 Coordinated Universal Time +/// (UTC) except that it ignores leap seconds; near a leap second it may be +/// adjusted by leap second smearing to stay roughly in sync with UTC. Leap +/// second smearing applies frequency adjustments to the clock to speed up +/// or slow down the clock to account for the leap second without +/// discontinuities in the clock. If leap second smearing is not applied, +/// the clock will experience discontinuity around leap second adjustment. +pub struct RealTime; + +impl ClockSource for RealTime { + const ID: bindings::clockid_t = bindings::CLOCK_REALTIME as bindings::clockid_t; + + fn ktime_get() -> bindings::ktime_t { + // SAFETY: It is always safe to call `ktime_get_real()` outside of NMI context. + unsafe { bindings::ktime_get_real() } + } +} + +/// A monotonic that ticks while system is suspended. +/// +/// A nonsettable system-wide clock that is identical to CLOCK_MONOTONIC, +/// except that it also includes any time that the system is suspended. This +/// allows applications to get a suspend-aware monotonic clock without +/// having to deal with the complications of CLOCK_REALTIME, which may have +/// discontinuities if the time is changed using settimeofday(2) or similar. +pub struct BootTime; + +impl ClockSource for BootTime { + const ID: bindings::clockid_t = bindings::CLOCK_BOOTTIME as bindings::clockid_t; + + fn ktime_get() -> bindings::ktime_t { + // SAFETY: It is always safe to call `ktime_get_boottime()` outside of NMI context. + unsafe { bindings::ktime_get_boottime() } + } +} + +/// International Atomic Time. +/// +/// A system-wide clock derived from wall-clock time but counting leap seconds. +/// +/// This clock is coupled to CLOCK_REALTIME and will be set when CLOCK_REALTIME is +/// set, or when the offset to CLOCK_REALTIME is changed via adjtimex(2). This +/// usually happens during boot and **should** not happen during normal operations. +/// However, if NTP or another application adjusts CLOCK_REALTIME by leap second +/// smearing, this clock will not be precise during leap second smearing. +/// +/// The acronym TAI refers to International Atomic Time. +pub struct Tai; + +impl ClockSource for Tai { + const ID: bindings::clockid_t = bindings::CLOCK_TAI as bindings::clockid_t; + + fn ktime_get() -> bindings::ktime_t { + // SAFETY: It is always safe to call `ktime_get_tai()` outside of NMI context. + unsafe { bindings::ktime_get_clocktai() } + } +} + +/// A specific point in time. +/// +/// # Invariants +/// +/// The `inner` value is in the range from 0 to `KTIME_MAX`. #[repr(transparent)] -#[derive(Copy, Clone)] -pub struct Ktime { +#[derive(PartialEq, PartialOrd, Eq, Ord)] +pub struct Instant<C: ClockSource> { inner: bindings::ktime_t, + _c: PhantomData<C>, } -impl Ktime { - /// Create a `Ktime` from a raw `ktime_t`. - #[inline] - pub fn from_raw(inner: bindings::ktime_t) -> Self { - Self { inner } +impl<C: ClockSource> Clone for Instant<C> { + fn clone(&self) -> Self { + *self } +} - /// Get the current time using `CLOCK_MONOTONIC`. +impl<C: ClockSource> Copy for Instant<C> {} + +impl<C: ClockSource> Instant<C> { + /// Get the current time from the clock source. #[inline] - pub fn ktime_get() -> Self { - // SAFETY: It is always safe to call `ktime_get` outside of NMI context. - Self::from_raw(unsafe { bindings::ktime_get() }) + pub fn now() -> Self { + // INVARIANT: The `ClockSource::ktime_get()` function returns a value in the range + // from 0 to `KTIME_MAX`. + Self { + inner: C::ktime_get(), + _c: PhantomData, + } } - /// Divide the number of nanoseconds by a compile-time constant. + /// Return the amount of time elapsed since the [`Instant`]. #[inline] - fn divns_constant<const DIV: i64>(self) -> i64 { - self.to_ns() / DIV + pub fn elapsed(&self) -> Delta { + Self::now() - *self } - /// Returns the number of nanoseconds. #[inline] - pub fn to_ns(self) -> i64 { + pub(crate) fn as_nanos(&self) -> i64 { self.inner } +} + +impl<C: ClockSource> core::ops::Sub for Instant<C> { + type Output = Delta; - /// Returns the number of milliseconds. + // By the type invariant, it never overflows. #[inline] - pub fn to_ms(self) -> i64 { - self.divns_constant::<NSEC_PER_MSEC>() + fn sub(self, other: Instant<C>) -> Delta { + Delta { + nanos: self.inner - other.inner, + } } } -/// Returns the number of milliseconds between two ktimes. -#[inline] -pub fn ktime_ms_delta(later: Ktime, earlier: Ktime) -> i64 { - (later - earlier).to_ms() +/// A span of time. +/// +/// This struct represents a span of time, with its value stored as nanoseconds. +/// The value can represent any valid i64 value, including negative, zero, and +/// positive numbers. +#[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Debug)] +pub struct Delta { + nanos: i64, } -impl core::ops::Sub for Ktime { - type Output = Ktime; +impl Delta { + /// A span of time equal to zero. + pub const ZERO: Self = Self { nanos: 0 }; + + /// Create a new [`Delta`] from a number of microseconds. + /// + /// The `micros` can range from -9_223_372_036_854_775 to 9_223_372_036_854_775. + /// If `micros` is outside this range, `i64::MIN` is used for negative values, + /// and `i64::MAX` is used for positive values due to saturation. + #[inline] + pub const fn from_micros(micros: i64) -> Self { + Self { + nanos: micros.saturating_mul(NSEC_PER_USEC), + } + } + + /// Create a new [`Delta`] from a number of milliseconds. + /// + /// The `millis` can range from -9_223_372_036_854 to 9_223_372_036_854. + /// If `millis` is outside this range, `i64::MIN` is used for negative values, + /// and `i64::MAX` is used for positive values due to saturation. + #[inline] + pub const fn from_millis(millis: i64) -> Self { + Self { + nanos: millis.saturating_mul(NSEC_PER_MSEC), + } + } + /// Create a new [`Delta`] from a number of seconds. + /// + /// The `secs` can range from -9_223_372_036 to 9_223_372_036. + /// If `secs` is outside this range, `i64::MIN` is used for negative values, + /// and `i64::MAX` is used for positive values due to saturation. #[inline] - fn sub(self, other: Ktime) -> Ktime { + pub const fn from_secs(secs: i64) -> Self { Self { - inner: self.inner - other.inner, + nanos: secs.saturating_mul(NSEC_PER_SEC), + } + } + + /// Return `true` if the [`Delta`] spans no time. + #[inline] + pub fn is_zero(self) -> bool { + self.as_nanos() == 0 + } + + /// Return `true` if the [`Delta`] spans a negative amount of time. + #[inline] + pub fn is_negative(self) -> bool { + self.as_nanos() < 0 + } + + /// Return the number of nanoseconds in the [`Delta`]. + #[inline] + pub const fn as_nanos(self) -> i64 { + self.nanos + } + + /// Return the smallest number of microseconds greater than or equal + /// to the value in the [`Delta`]. + #[inline] + pub fn as_micros_ceil(self) -> i64 { + #[cfg(CONFIG_64BIT)] + { + self.as_nanos().saturating_add(NSEC_PER_USEC - 1) / NSEC_PER_USEC + } + + #[cfg(not(CONFIG_64BIT))] + // SAFETY: It is always safe to call `ktime_to_us()` with any value. + unsafe { + bindings::ktime_to_us(self.as_nanos().saturating_add(NSEC_PER_USEC - 1)) + } + } + + /// Return the number of milliseconds in the [`Delta`]. + #[inline] + pub fn as_millis(self) -> i64 { + #[cfg(CONFIG_64BIT)] + { + self.as_nanos() / NSEC_PER_MSEC + } + + #[cfg(not(CONFIG_64BIT))] + // SAFETY: It is always safe to call `ktime_to_ms()` with any value. + unsafe { + bindings::ktime_to_ms(self.as_nanos()) } } } diff --git a/rust/kernel/time/delay.rs b/rust/kernel/time/delay.rs new file mode 100644 index 000000000000..eb8838da62bc --- /dev/null +++ b/rust/kernel/time/delay.rs @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Delay and sleep primitives. +//! +//! This module contains the kernel APIs related to delay and sleep that +//! have been ported or wrapped for usage by Rust code in the kernel. +//! +//! C header: [`include/linux/delay.h`](srctree/include/linux/delay.h). + +use super::Delta; +use crate::prelude::*; + +/// Sleeps for a given duration at least. +/// +/// Equivalent to the C side [`fsleep()`], flexible sleep function, +/// which automatically chooses the best sleep method based on a duration. +/// +/// `delta` must be within `[0, i32::MAX]` microseconds; +/// otherwise, it is erroneous behavior. That is, it is considered a bug +/// to call this function with an out-of-range value, in which case the function +/// will sleep for at least the maximum value in the range and may warn +/// in the future. +/// +/// The behavior above differs from the C side [`fsleep()`] for which out-of-range +/// values mean "infinite timeout" instead. +/// +/// This function can only be used in a nonatomic context. +/// +/// [`fsleep()`]: https://docs.kernel.org/timers/delay_sleep_functions.html#c.fsleep +pub fn fsleep(delta: Delta) { + // The maximum value is set to `i32::MAX` microseconds to prevent integer + // overflow inside fsleep, which could lead to unintentional infinite sleep. + const MAX_DELTA: Delta = Delta::from_micros(i32::MAX as i64); + + let delta = if (Delta::ZERO..=MAX_DELTA).contains(&delta) { + delta + } else { + // TODO: Add WARN_ONCE() when it's supported. + MAX_DELTA + }; + + // SAFETY: It is always safe to call `fsleep()` with any duration. + unsafe { + // Convert the duration to microseconds and round up to preserve + // the guarantee; `fsleep()` sleeps for at least the provided duration, + // but that it may sleep for longer under some circumstances. + bindings::fsleep(delta.as_micros_ceil() as c_ulong) + } +} diff --git a/rust/kernel/time/hrtimer.rs b/rust/kernel/time/hrtimer.rs new file mode 100644 index 000000000000..144e3b57cc78 --- /dev/null +++ b/rust/kernel/time/hrtimer.rs @@ -0,0 +1,638 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Intrusive high resolution timers. +//! +//! Allows running timer callbacks without doing allocations at the time of +//! starting the timer. For now, only one timer per type is allowed. +//! +//! # Vocabulary +//! +//! States: +//! +//! - Stopped: initialized but not started, or cancelled, or not restarted. +//! - Started: initialized and started or restarted. +//! - Running: executing the callback. +//! +//! Operations: +//! +//! * Start +//! * Cancel +//! * Restart +//! +//! Events: +//! +//! * Expire +//! +//! ## State Diagram +//! +//! ```text +//! Return NoRestart +//! +---------------------------------------------------------------------+ +//! | | +//! | | +//! | | +//! | Return Restart | +//! | +------------------------+ | +//! | | | | +//! | | | | +//! v v | | +//! +-----------------+ Start +------------------+ +--------+-----+--+ +//! | +---------------->| | | | +//! Init | | | | Expire | | +//! --------->| Stopped | | Started +---------->| Running | +//! | | Cancel | | | | +//! | |<----------------+ | | | +//! +-----------------+ +---------------+--+ +-----------------+ +//! ^ | +//! | | +//! +---------+ +//! Restart +//! ``` +//! +//! +//! A timer is initialized in the **stopped** state. A stopped timer can be +//! **started** by the `start` operation, with an **expiry** time. After the +//! `start` operation, the timer is in the **started** state. When the timer +//! **expires**, the timer enters the **running** state and the handler is +//! executed. After the handler has returned, the timer may enter the +//! **started* or **stopped** state, depending on the return value of the +//! handler. A timer in the **started** or **running** state may be **canceled** +//! by the `cancel` operation. A timer that is cancelled enters the **stopped** +//! state. +//! +//! A `cancel` or `restart` operation on a timer in the **running** state takes +//! effect after the handler has returned and the timer has transitioned +//! out of the **running** state. +//! +//! A `restart` operation on a timer in the **stopped** state is equivalent to a +//! `start` operation. + +use super::{ClockSource, Delta, Instant}; +use crate::{prelude::*, types::Opaque}; +use core::marker::PhantomData; +use pin_init::PinInit; + +/// A timer backed by a C `struct hrtimer`. +/// +/// # Invariants +/// +/// * `self.timer` is initialized by `bindings::hrtimer_setup`. +#[pin_data] +#[repr(C)] +pub struct HrTimer<T> { + #[pin] + timer: Opaque<bindings::hrtimer>, + _t: PhantomData<T>, +} + +// SAFETY: Ownership of an `HrTimer` can be moved to other threads and +// used/dropped from there. +unsafe impl<T> Send for HrTimer<T> {} + +// SAFETY: Timer operations are locked on the C side, so it is safe to operate +// on a timer from multiple threads. +unsafe impl<T> Sync for HrTimer<T> {} + +impl<T> HrTimer<T> { + /// Return an initializer for a new timer instance. + pub fn new() -> impl PinInit<Self> + where + T: HrTimerCallback, + T: HasHrTimer<T>, + { + pin_init!(Self { + // INVARIANT: We initialize `timer` with `hrtimer_setup` below. + timer <- Opaque::ffi_init(move |place: *mut bindings::hrtimer| { + // SAFETY: By design of `pin_init!`, `place` is a pointer to a + // live allocation. hrtimer_setup will initialize `place` and + // does not require `place` to be initialized prior to the call. + unsafe { + bindings::hrtimer_setup( + place, + Some(T::Pointer::run), + <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Clock::ID, + <T as HasHrTimer<T>>::TimerMode::C_MODE, + ); + } + }), + _t: PhantomData, + }) + } + + /// Get a pointer to the contained `bindings::hrtimer`. + /// + /// This function is useful to get access to the value without creating + /// intermediate references. + /// + /// # Safety + /// + /// `this` must point to a live allocation of at least the size of `Self`. + unsafe fn raw_get(this: *const Self) -> *mut bindings::hrtimer { + // SAFETY: The field projection to `timer` does not go out of bounds, + // because the caller of this function promises that `this` points to an + // allocation of at least the size of `Self`. + unsafe { Opaque::cast_into(core::ptr::addr_of!((*this).timer)) } + } + + /// Cancel an initialized and potentially running timer. + /// + /// If the timer handler is running, this function will block until the + /// handler returns. + /// + /// Note that the timer might be started by a concurrent start operation. If + /// so, the timer might not be in the **stopped** state when this function + /// returns. + /// + /// Users of the `HrTimer` API would not usually call this method directly. + /// Instead they would use the safe [`HrTimerHandle::cancel`] on the handle + /// returned when the timer was started. + /// + /// This function is useful to get access to the value without creating + /// intermediate references. + /// + /// # Safety + /// + /// `this` must point to a valid `Self`. + pub(crate) unsafe fn raw_cancel(this: *const Self) -> bool { + // SAFETY: `this` points to an allocation of at least `HrTimer` size. + let c_timer_ptr = unsafe { HrTimer::raw_get(this) }; + + // If the handler is running, this will wait for the handler to return + // before returning. + // SAFETY: `c_timer_ptr` is initialized and valid. Synchronization is + // handled on the C side. + unsafe { bindings::hrtimer_cancel(c_timer_ptr) != 0 } + } +} + +/// Implemented by pointer types that point to structs that contain a [`HrTimer`]. +/// +/// `Self` must be [`Sync`] because it is passed to timer callbacks in another +/// thread of execution (hard or soft interrupt context). +/// +/// Starting a timer returns a [`HrTimerHandle`] that can be used to manipulate +/// the timer. Note that it is OK to call the start function repeatedly, and +/// that more than one [`HrTimerHandle`] associated with a [`HrTimerPointer`] may +/// exist. A timer can be manipulated through any of the handles, and a handle +/// may represent a cancelled timer. +pub trait HrTimerPointer: Sync + Sized { + /// The operational mode associated with this timer. + /// + /// This defines how the expiration value is interpreted. + type TimerMode: HrTimerMode; + + /// A handle representing a started or restarted timer. + /// + /// If the timer is running or if the timer callback is executing when the + /// handle is dropped, the drop method of [`HrTimerHandle`] should not return + /// until the timer is stopped and the callback has completed. + /// + /// Note: When implementing this trait, consider that it is not unsafe to + /// leak the handle. + type TimerHandle: HrTimerHandle; + + /// Start the timer with expiry after `expires` time units. If the timer was + /// already running, it is restarted with the new expiry time. + fn start(self, expires: <Self::TimerMode as HrTimerMode>::Expires) -> Self::TimerHandle; +} + +/// Unsafe version of [`HrTimerPointer`] for situations where leaking the +/// [`HrTimerHandle`] returned by `start` would be unsound. This is the case for +/// stack allocated timers. +/// +/// Typical implementers are pinned references such as [`Pin<&T>`]. +/// +/// # Safety +/// +/// Implementers of this trait must ensure that instances of types implementing +/// [`UnsafeHrTimerPointer`] outlives any associated [`HrTimerPointer::TimerHandle`] +/// instances. +pub unsafe trait UnsafeHrTimerPointer: Sync + Sized { + /// The operational mode associated with this timer. + /// + /// This defines how the expiration value is interpreted. + type TimerMode: HrTimerMode; + + /// A handle representing a running timer. + /// + /// # Safety + /// + /// If the timer is running, or if the timer callback is executing when the + /// handle is dropped, the drop method of [`Self::TimerHandle`] must not return + /// until the timer is stopped and the callback has completed. + type TimerHandle: HrTimerHandle; + + /// Start the timer after `expires` time units. If the timer was already + /// running, it is restarted at the new expiry time. + /// + /// # Safety + /// + /// Caller promises keep the timer structure alive until the timer is dead. + /// Caller can ensure this by not leaking the returned [`Self::TimerHandle`]. + unsafe fn start(self, expires: <Self::TimerMode as HrTimerMode>::Expires) -> Self::TimerHandle; +} + +/// A trait for stack allocated timers. +/// +/// # Safety +/// +/// Implementers must ensure that `start_scoped` does not return until the +/// timer is dead and the timer handler is not running. +pub unsafe trait ScopedHrTimerPointer { + /// The operational mode associated with this timer. + /// + /// This defines how the expiration value is interpreted. + type TimerMode: HrTimerMode; + + /// Start the timer to run after `expires` time units and immediately + /// after call `f`. When `f` returns, the timer is cancelled. + fn start_scoped<T, F>(self, expires: <Self::TimerMode as HrTimerMode>::Expires, f: F) -> T + where + F: FnOnce() -> T; +} + +// SAFETY: By the safety requirement of [`UnsafeHrTimerPointer`], dropping the +// handle returned by [`UnsafeHrTimerPointer::start`] ensures that the timer is +// killed. +unsafe impl<T> ScopedHrTimerPointer for T +where + T: UnsafeHrTimerPointer, +{ + type TimerMode = T::TimerMode; + + fn start_scoped<U, F>( + self, + expires: <<T as UnsafeHrTimerPointer>::TimerMode as HrTimerMode>::Expires, + f: F, + ) -> U + where + F: FnOnce() -> U, + { + // SAFETY: We drop the timer handle below before returning. + let handle = unsafe { UnsafeHrTimerPointer::start(self, expires) }; + let t = f(); + drop(handle); + t + } +} + +/// Implemented by [`HrTimerPointer`] implementers to give the C timer callback a +/// function to call. +// This is split from `HrTimerPointer` to make it easier to specify trait bounds. +pub trait RawHrTimerCallback { + /// Type of the parameter passed to [`HrTimerCallback::run`]. It may be + /// [`Self`], or a pointer type derived from [`Self`]. + type CallbackTarget<'a>; + + /// Callback to be called from C when timer fires. + /// + /// # Safety + /// + /// Only to be called by C code in the `hrtimer` subsystem. `this` must point + /// to the `bindings::hrtimer` structure that was used to start the timer. + unsafe extern "C" fn run(this: *mut bindings::hrtimer) -> bindings::hrtimer_restart; +} + +/// Implemented by structs that can be the target of a timer callback. +pub trait HrTimerCallback { + /// The type whose [`RawHrTimerCallback::run`] method will be invoked when + /// the timer expires. + type Pointer<'a>: RawHrTimerCallback; + + /// Called by the timer logic when the timer fires. + fn run(this: <Self::Pointer<'_> as RawHrTimerCallback>::CallbackTarget<'_>) -> HrTimerRestart + where + Self: Sized; +} + +/// A handle representing a potentially running timer. +/// +/// More than one handle representing the same timer might exist. +/// +/// # Safety +/// +/// When dropped, the timer represented by this handle must be cancelled, if it +/// is running. If the timer handler is running when the handle is dropped, the +/// drop method must wait for the handler to return before returning. +/// +/// Note: One way to satisfy the safety requirement is to call `Self::cancel` in +/// the drop implementation for `Self.` +pub unsafe trait HrTimerHandle { + /// Cancel the timer. If the timer is in the running state, block till the + /// handler has returned. + /// + /// Note that the timer might be started by a concurrent start operation. If + /// so, the timer might not be in the **stopped** state when this function + /// returns. + fn cancel(&mut self) -> bool; +} + +/// Implemented by structs that contain timer nodes. +/// +/// Clients of the timer API would usually safely implement this trait by using +/// the [`crate::impl_has_hr_timer`] macro. +/// +/// # Safety +/// +/// Implementers of this trait must ensure that the implementer has a +/// [`HrTimer`] field and that all trait methods are implemented according to +/// their documentation. All the methods of this trait must operate on the same +/// field. +pub unsafe trait HasHrTimer<T> { + /// The operational mode associated with this timer. + /// + /// This defines how the expiration value is interpreted. + type TimerMode: HrTimerMode; + + /// Return a pointer to the [`HrTimer`] within `Self`. + /// + /// This function is useful to get access to the value without creating + /// intermediate references. + /// + /// # Safety + /// + /// `this` must be a valid pointer. + unsafe fn raw_get_timer(this: *const Self) -> *const HrTimer<T>; + + /// Return a pointer to the struct that is containing the [`HrTimer`] pointed + /// to by `ptr`. + /// + /// This function is useful to get access to the value without creating + /// intermediate references. + /// + /// # Safety + /// + /// `ptr` must point to a [`HrTimer<T>`] field in a struct of type `Self`. + unsafe fn timer_container_of(ptr: *mut HrTimer<T>) -> *mut Self + where + Self: Sized; + + /// Get pointer to the contained `bindings::hrtimer` struct. + /// + /// This function is useful to get access to the value without creating + /// intermediate references. + /// + /// # Safety + /// + /// `this` must be a valid pointer. + unsafe fn c_timer_ptr(this: *const Self) -> *const bindings::hrtimer { + // SAFETY: `this` is a valid pointer to a `Self`. + let timer_ptr = unsafe { Self::raw_get_timer(this) }; + + // SAFETY: timer_ptr points to an allocation of at least `HrTimer` size. + unsafe { HrTimer::raw_get(timer_ptr) } + } + + /// Start the timer contained in the `Self` pointed to by `self_ptr`. If + /// it is already running it is removed and inserted. + /// + /// # Safety + /// + /// - `this` must point to a valid `Self`. + /// - Caller must ensure that the pointee of `this` lives until the timer + /// fires or is canceled. + unsafe fn start(this: *const Self, expires: <Self::TimerMode as HrTimerMode>::Expires) { + // SAFETY: By function safety requirement, `this` is a valid `Self`. + unsafe { + bindings::hrtimer_start_range_ns( + Self::c_timer_ptr(this).cast_mut(), + expires.as_nanos(), + 0, + <Self::TimerMode as HrTimerMode>::C_MODE, + ); + } + } +} + +/// Restart policy for timers. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[repr(u32)] +pub enum HrTimerRestart { + /// Timer should not be restarted. + NoRestart = bindings::hrtimer_restart_HRTIMER_NORESTART, + /// Timer should be restarted. + Restart = bindings::hrtimer_restart_HRTIMER_RESTART, +} + +impl HrTimerRestart { + fn into_c(self) -> bindings::hrtimer_restart { + self as bindings::hrtimer_restart + } +} + +/// Time representations that can be used as expiration values in [`HrTimer`]. +pub trait HrTimerExpires { + /// Converts the expiration time into a nanosecond representation. + /// + /// This value corresponds to a raw ktime_t value, suitable for passing to kernel + /// timer functions. The interpretation (absolute vs relative) depends on the + /// associated [HrTimerMode] in use. + fn as_nanos(&self) -> i64; +} + +impl<C: ClockSource> HrTimerExpires for Instant<C> { + #[inline] + fn as_nanos(&self) -> i64 { + Instant::<C>::as_nanos(self) + } +} + +impl HrTimerExpires for Delta { + #[inline] + fn as_nanos(&self) -> i64 { + Delta::as_nanos(*self) + } +} + +mod private { + use crate::time::ClockSource; + + pub trait Sealed {} + + impl<C: ClockSource> Sealed for super::AbsoluteMode<C> {} + impl<C: ClockSource> Sealed for super::RelativeMode<C> {} + impl<C: ClockSource> Sealed for super::AbsolutePinnedMode<C> {} + impl<C: ClockSource> Sealed for super::RelativePinnedMode<C> {} + impl<C: ClockSource> Sealed for super::AbsoluteSoftMode<C> {} + impl<C: ClockSource> Sealed for super::RelativeSoftMode<C> {} + impl<C: ClockSource> Sealed for super::AbsolutePinnedSoftMode<C> {} + impl<C: ClockSource> Sealed for super::RelativePinnedSoftMode<C> {} + impl<C: ClockSource> Sealed for super::AbsoluteHardMode<C> {} + impl<C: ClockSource> Sealed for super::RelativeHardMode<C> {} + impl<C: ClockSource> Sealed for super::AbsolutePinnedHardMode<C> {} + impl<C: ClockSource> Sealed for super::RelativePinnedHardMode<C> {} +} + +/// Operational mode of [`HrTimer`]. +pub trait HrTimerMode: private::Sealed { + /// The C representation of hrtimer mode. + const C_MODE: bindings::hrtimer_mode; + + /// Type representing the clock source. + type Clock: ClockSource; + + /// Type representing the expiration specification (absolute or relative time). + type Expires: HrTimerExpires; +} + +/// Timer that expires at a fixed point in time. +pub struct AbsoluteMode<C: ClockSource>(PhantomData<C>); + +impl<C: ClockSource> HrTimerMode for AbsoluteMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS; + + type Clock = C; + type Expires = Instant<C>; +} + +/// Timer that expires after a delay from now. +pub struct RelativeMode<C: ClockSource>(PhantomData<C>); + +impl<C: ClockSource> HrTimerMode for RelativeMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL; + + type Clock = C; + type Expires = Delta; +} + +/// Timer with absolute expiration time, pinned to its current CPU. +pub struct AbsolutePinnedMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for AbsolutePinnedMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_PINNED; + + type Clock = C; + type Expires = Instant<C>; +} + +/// Timer with relative expiration time, pinned to its current CPU. +pub struct RelativePinnedMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for RelativePinnedMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_PINNED; + + type Clock = C; + type Expires = Delta; +} + +/// Timer with absolute expiration, handled in soft irq context. +pub struct AbsoluteSoftMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for AbsoluteSoftMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_SOFT; + + type Clock = C; + type Expires = Instant<C>; +} + +/// Timer with relative expiration, handled in soft irq context. +pub struct RelativeSoftMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for RelativeSoftMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_SOFT; + + type Clock = C; + type Expires = Delta; +} + +/// Timer with absolute expiration, pinned to CPU and handled in soft irq context. +pub struct AbsolutePinnedSoftMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for AbsolutePinnedSoftMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_PINNED_SOFT; + + type Clock = C; + type Expires = Instant<C>; +} + +/// Timer with absolute expiration, pinned to CPU and handled in soft irq context. +pub struct RelativePinnedSoftMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for RelativePinnedSoftMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_PINNED_SOFT; + + type Clock = C; + type Expires = Delta; +} + +/// Timer with absolute expiration, handled in hard irq context. +pub struct AbsoluteHardMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for AbsoluteHardMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_HARD; + + type Clock = C; + type Expires = Instant<C>; +} + +/// Timer with relative expiration, handled in hard irq context. +pub struct RelativeHardMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for RelativeHardMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_HARD; + + type Clock = C; + type Expires = Delta; +} + +/// Timer with absolute expiration, pinned to CPU and handled in hard irq context. +pub struct AbsolutePinnedHardMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for AbsolutePinnedHardMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_ABS_PINNED_HARD; + + type Clock = C; + type Expires = Instant<C>; +} + +/// Timer with relative expiration, pinned to CPU and handled in hard irq context. +pub struct RelativePinnedHardMode<C: ClockSource>(PhantomData<C>); +impl<C: ClockSource> HrTimerMode for RelativePinnedHardMode<C> { + const C_MODE: bindings::hrtimer_mode = bindings::hrtimer_mode_HRTIMER_MODE_REL_PINNED_HARD; + + type Clock = C; + type Expires = Delta; +} + +/// Use to implement the [`HasHrTimer<T>`] trait. +/// +/// See [`module`] documentation for an example. +/// +/// [`module`]: crate::time::hrtimer +#[macro_export] +macro_rules! impl_has_hr_timer { + ( + impl$({$($generics:tt)*})? + HasHrTimer<$timer_type:ty> + for $self:ty + { + mode : $mode:ty, + field : self.$field:ident $(,)? + } + $($rest:tt)* + ) => { + // SAFETY: This implementation of `raw_get_timer` only compiles if the + // field has the right type. + unsafe impl$(<$($generics)*>)? $crate::time::hrtimer::HasHrTimer<$timer_type> for $self { + type TimerMode = $mode; + + #[inline] + unsafe fn raw_get_timer( + this: *const Self, + ) -> *const $crate::time::hrtimer::HrTimer<$timer_type> { + // SAFETY: The caller promises that the pointer is not dangling. + unsafe { ::core::ptr::addr_of!((*this).$field) } + } + + #[inline] + unsafe fn timer_container_of( + ptr: *mut $crate::time::hrtimer::HrTimer<$timer_type>, + ) -> *mut Self { + // SAFETY: As per the safety requirement of this function, `ptr` + // is pointing inside a `$timer_type`. + unsafe { ::kernel::container_of!(ptr, $timer_type, $field) } + } + } + } +} + +mod arc; +pub use arc::ArcHrTimerHandle; +mod pin; +pub use pin::PinHrTimerHandle; +mod pin_mut; +pub use pin_mut::PinMutHrTimerHandle; +// `box` is a reserved keyword, so prefix with `t` for timer +mod tbox; +pub use tbox::BoxHrTimerHandle; diff --git a/rust/kernel/time/hrtimer/arc.rs b/rust/kernel/time/hrtimer/arc.rs new file mode 100644 index 000000000000..ed490a7a8950 --- /dev/null +++ b/rust/kernel/time/hrtimer/arc.rs @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: GPL-2.0 + +use super::HasHrTimer; +use super::HrTimer; +use super::HrTimerCallback; +use super::HrTimerHandle; +use super::HrTimerMode; +use super::HrTimerPointer; +use super::RawHrTimerCallback; +use crate::sync::Arc; +use crate::sync::ArcBorrow; + +/// A handle for an `Arc<HasHrTimer<T>>` returned by a call to +/// [`HrTimerPointer::start`]. +pub struct ArcHrTimerHandle<T> +where + T: HasHrTimer<T>, +{ + pub(crate) inner: Arc<T>, +} + +// SAFETY: We implement drop below, and we cancel the timer in the drop +// implementation. +unsafe impl<T> HrTimerHandle for ArcHrTimerHandle<T> +where + T: HasHrTimer<T>, +{ + fn cancel(&mut self) -> bool { + let self_ptr = Arc::as_ptr(&self.inner); + + // SAFETY: As we obtained `self_ptr` from a valid reference above, it + // must point to a valid `T`. + let timer_ptr = unsafe { <T as HasHrTimer<T>>::raw_get_timer(self_ptr) }; + + // SAFETY: As `timer_ptr` points into `T` and `T` is valid, `timer_ptr` + // must point to a valid `HrTimer` instance. + unsafe { HrTimer::<T>::raw_cancel(timer_ptr) } + } +} + +impl<T> Drop for ArcHrTimerHandle<T> +where + T: HasHrTimer<T>, +{ + fn drop(&mut self) { + self.cancel(); + } +} + +impl<T> HrTimerPointer for Arc<T> +where + T: 'static, + T: Send + Sync, + T: HasHrTimer<T>, + T: for<'a> HrTimerCallback<Pointer<'a> = Self>, +{ + type TimerMode = <T as HasHrTimer<T>>::TimerMode; + type TimerHandle = ArcHrTimerHandle<T>; + + fn start( + self, + expires: <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Expires, + ) -> ArcHrTimerHandle<T> { + // SAFETY: + // - We keep `self` alive by wrapping it in a handle below. + // - Since we generate the pointer passed to `start` from a valid + // reference, it is a valid pointer. + unsafe { T::start(Arc::as_ptr(&self), expires) }; + ArcHrTimerHandle { inner: self } + } +} + +impl<T> RawHrTimerCallback for Arc<T> +where + T: 'static, + T: HasHrTimer<T>, + T: for<'a> HrTimerCallback<Pointer<'a> = Self>, +{ + type CallbackTarget<'a> = ArcBorrow<'a, T>; + + unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart { + // `HrTimer` is `repr(C)` + let timer_ptr = ptr.cast::<super::HrTimer<T>>(); + + // SAFETY: By C API contract `ptr` is the pointer we passed when + // queuing the timer, so it is a `HrTimer<T>` embedded in a `T`. + let data_ptr = unsafe { T::timer_container_of(timer_ptr) }; + + // SAFETY: + // - `data_ptr` is derived form the pointer to the `T` that was used to + // queue the timer. + // - As per the safety requirements of the trait `HrTimerHandle`, the + // `ArcHrTimerHandle` associated with this timer is guaranteed to + // be alive until this method returns. That handle borrows the `T` + // behind `data_ptr` thus guaranteeing the validity of + // the `ArcBorrow` created below. + // - We own one refcount in the `ArcTimerHandle` associated with this + // timer, so it is not possible to get a `UniqueArc` to this + // allocation from other `Arc` clones. + let receiver = unsafe { ArcBorrow::from_raw(data_ptr) }; + + T::run(receiver).into_c() + } +} diff --git a/rust/kernel/time/hrtimer/pin.rs b/rust/kernel/time/hrtimer/pin.rs new file mode 100644 index 000000000000..aef16d9ee2f0 --- /dev/null +++ b/rust/kernel/time/hrtimer/pin.rs @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: GPL-2.0 + +use super::HasHrTimer; +use super::HrTimer; +use super::HrTimerCallback; +use super::HrTimerHandle; +use super::HrTimerMode; +use super::RawHrTimerCallback; +use super::UnsafeHrTimerPointer; +use core::pin::Pin; + +/// A handle for a `Pin<&HasHrTimer>`. When the handle exists, the timer might be +/// running. +pub struct PinHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + pub(crate) inner: Pin<&'a T>, +} + +// SAFETY: We cancel the timer when the handle is dropped. The implementation of +// the `cancel` method will block if the timer handler is running. +unsafe impl<'a, T> HrTimerHandle for PinHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + fn cancel(&mut self) -> bool { + let self_ptr: *const T = self.inner.get_ref(); + + // SAFETY: As we got `self_ptr` from a reference above, it must point to + // a valid `T`. + let timer_ptr = unsafe { <T as HasHrTimer<T>>::raw_get_timer(self_ptr) }; + + // SAFETY: As `timer_ptr` is derived from a reference, it must point to + // a valid and initialized `HrTimer`. + unsafe { HrTimer::<T>::raw_cancel(timer_ptr) } + } +} + +impl<'a, T> Drop for PinHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + fn drop(&mut self) { + self.cancel(); + } +} + +// SAFETY: We capture the lifetime of `Self` when we create a `PinHrTimerHandle`, +// so `Self` will outlive the handle. +unsafe impl<'a, T> UnsafeHrTimerPointer for Pin<&'a T> +where + T: Send + Sync, + T: HasHrTimer<T>, + T: HrTimerCallback<Pointer<'a> = Self>, +{ + type TimerMode = <T as HasHrTimer<T>>::TimerMode; + type TimerHandle = PinHrTimerHandle<'a, T>; + + unsafe fn start( + self, + expires: <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Expires, + ) -> Self::TimerHandle { + // Cast to pointer + let self_ptr: *const T = self.get_ref(); + + // SAFETY: + // - As we derive `self_ptr` from a reference above, it must point to a + // valid `T`. + // - We keep `self` alive by wrapping it in a handle below. + unsafe { T::start(self_ptr, expires) }; + + PinHrTimerHandle { inner: self } + } +} + +impl<'a, T> RawHrTimerCallback for Pin<&'a T> +where + T: HasHrTimer<T>, + T: HrTimerCallback<Pointer<'a> = Self>, +{ + type CallbackTarget<'b> = Self; + + unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart { + // `HrTimer` is `repr(C)` + let timer_ptr = ptr.cast::<HrTimer<T>>(); + + // SAFETY: By the safety requirement of this function, `timer_ptr` + // points to a `HrTimer<T>` contained in an `T`. + let receiver_ptr = unsafe { T::timer_container_of(timer_ptr) }; + + // SAFETY: + // - By the safety requirement of this function, `timer_ptr` + // points to a `HrTimer<T>` contained in an `T`. + // - As per the safety requirements of the trait `HrTimerHandle`, the + // `PinHrTimerHandle` associated with this timer is guaranteed to + // be alive until this method returns. That handle borrows the `T` + // behind `receiver_ptr`, thus guaranteeing the validity of + // the reference created below. + let receiver_ref = unsafe { &*receiver_ptr }; + + // SAFETY: `receiver_ref` only exists as pinned, so it is safe to pin it + // here. + let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) }; + + T::run(receiver_pin).into_c() + } +} diff --git a/rust/kernel/time/hrtimer/pin_mut.rs b/rust/kernel/time/hrtimer/pin_mut.rs new file mode 100644 index 000000000000..767d0a4e8a2c --- /dev/null +++ b/rust/kernel/time/hrtimer/pin_mut.rs @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: GPL-2.0 + +use super::{ + HasHrTimer, HrTimer, HrTimerCallback, HrTimerHandle, HrTimerMode, RawHrTimerCallback, + UnsafeHrTimerPointer, +}; +use core::{marker::PhantomData, pin::Pin, ptr::NonNull}; + +/// A handle for a `Pin<&mut HasHrTimer>`. When the handle exists, the timer might +/// be running. +pub struct PinMutHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + pub(crate) inner: NonNull<T>, + _p: PhantomData<&'a mut T>, +} + +// SAFETY: We cancel the timer when the handle is dropped. The implementation of +// the `cancel` method will block if the timer handler is running. +unsafe impl<'a, T> HrTimerHandle for PinMutHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + fn cancel(&mut self) -> bool { + let self_ptr = self.inner.as_ptr(); + + // SAFETY: As we got `self_ptr` from a reference above, it must point to + // a valid `T`. + let timer_ptr = unsafe { <T as HasHrTimer<T>>::raw_get_timer(self_ptr) }; + + // SAFETY: As `timer_ptr` is derived from a reference, it must point to + // a valid and initialized `HrTimer`. + unsafe { HrTimer::<T>::raw_cancel(timer_ptr) } + } +} + +impl<'a, T> Drop for PinMutHrTimerHandle<'a, T> +where + T: HasHrTimer<T>, +{ + fn drop(&mut self) { + self.cancel(); + } +} + +// SAFETY: We capture the lifetime of `Self` when we create a +// `PinMutHrTimerHandle`, so `Self` will outlive the handle. +unsafe impl<'a, T> UnsafeHrTimerPointer for Pin<&'a mut T> +where + T: Send + Sync, + T: HasHrTimer<T>, + T: HrTimerCallback<Pointer<'a> = Self>, +{ + type TimerMode = <T as HasHrTimer<T>>::TimerMode; + type TimerHandle = PinMutHrTimerHandle<'a, T>; + + unsafe fn start( + mut self, + expires: <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Expires, + ) -> Self::TimerHandle { + // SAFETY: + // - We promise not to move out of `self`. We only pass `self` + // back to the caller as a `Pin<&mut self>`. + // - The return value of `get_unchecked_mut` is guaranteed not to be null. + let self_ptr = unsafe { NonNull::new_unchecked(self.as_mut().get_unchecked_mut()) }; + + // SAFETY: + // - As we derive `self_ptr` from a reference above, it must point to a + // valid `T`. + // - We keep `self` alive by wrapping it in a handle below. + unsafe { T::start(self_ptr.as_ptr(), expires) }; + + PinMutHrTimerHandle { + inner: self_ptr, + _p: PhantomData, + } + } +} + +impl<'a, T> RawHrTimerCallback for Pin<&'a mut T> +where + T: HasHrTimer<T>, + T: HrTimerCallback<Pointer<'a> = Self>, +{ + type CallbackTarget<'b> = Self; + + unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart { + // `HrTimer` is `repr(C)` + let timer_ptr = ptr.cast::<HrTimer<T>>(); + + // SAFETY: By the safety requirement of this function, `timer_ptr` + // points to a `HrTimer<T>` contained in an `T`. + let receiver_ptr = unsafe { T::timer_container_of(timer_ptr) }; + + // SAFETY: + // - By the safety requirement of this function, `timer_ptr` + // points to a `HrTimer<T>` contained in an `T`. + // - As per the safety requirements of the trait `HrTimerHandle`, the + // `PinMutHrTimerHandle` associated with this timer is guaranteed to + // be alive until this method returns. That handle borrows the `T` + // behind `receiver_ptr` mutably thus guaranteeing the validity of + // the reference created below. + let receiver_ref = unsafe { &mut *receiver_ptr }; + + // SAFETY: `receiver_ref` only exists as pinned, so it is safe to pin it + // here. + let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) }; + + T::run(receiver_pin).into_c() + } +} diff --git a/rust/kernel/time/hrtimer/tbox.rs b/rust/kernel/time/hrtimer/tbox.rs new file mode 100644 index 000000000000..ec08303315f2 --- /dev/null +++ b/rust/kernel/time/hrtimer/tbox.rs @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: GPL-2.0 + +use super::HasHrTimer; +use super::HrTimer; +use super::HrTimerCallback; +use super::HrTimerHandle; +use super::HrTimerMode; +use super::HrTimerPointer; +use super::RawHrTimerCallback; +use crate::prelude::*; +use core::ptr::NonNull; + +/// A handle for a [`Box<HasHrTimer<T>>`] returned by a call to +/// [`HrTimerPointer::start`]. +/// +/// # Invariants +/// +/// - `self.inner` comes from a `Box::into_raw` call. +pub struct BoxHrTimerHandle<T, A> +where + T: HasHrTimer<T>, + A: crate::alloc::Allocator, +{ + pub(crate) inner: NonNull<T>, + _p: core::marker::PhantomData<A>, +} + +// SAFETY: We implement drop below, and we cancel the timer in the drop +// implementation. +unsafe impl<T, A> HrTimerHandle for BoxHrTimerHandle<T, A> +where + T: HasHrTimer<T>, + A: crate::alloc::Allocator, +{ + fn cancel(&mut self) -> bool { + // SAFETY: As we obtained `self.inner` from a valid reference when we + // created `self`, it must point to a valid `T`. + let timer_ptr = unsafe { <T as HasHrTimer<T>>::raw_get_timer(self.inner.as_ptr()) }; + + // SAFETY: As `timer_ptr` points into `T` and `T` is valid, `timer_ptr` + // must point to a valid `HrTimer` instance. + unsafe { HrTimer::<T>::raw_cancel(timer_ptr) } + } +} + +impl<T, A> Drop for BoxHrTimerHandle<T, A> +where + T: HasHrTimer<T>, + A: crate::alloc::Allocator, +{ + fn drop(&mut self) { + self.cancel(); + // SAFETY: By type invariant, `self.inner` came from a `Box::into_raw` + // call. + drop(unsafe { Box::<T, A>::from_raw(self.inner.as_ptr()) }) + } +} + +impl<T, A> HrTimerPointer for Pin<Box<T, A>> +where + T: 'static, + T: Send + Sync, + T: HasHrTimer<T>, + T: for<'a> HrTimerCallback<Pointer<'a> = Pin<Box<T, A>>>, + A: crate::alloc::Allocator, +{ + type TimerMode = <T as HasHrTimer<T>>::TimerMode; + type TimerHandle = BoxHrTimerHandle<T, A>; + + fn start( + self, + expires: <<T as HasHrTimer<T>>::TimerMode as HrTimerMode>::Expires, + ) -> Self::TimerHandle { + // SAFETY: + // - We will not move out of this box during timer callback (we pass an + // immutable reference to the callback). + // - `Box::into_raw` is guaranteed to return a valid pointer. + let inner = + unsafe { NonNull::new_unchecked(Box::into_raw(Pin::into_inner_unchecked(self))) }; + + // SAFETY: + // - We keep `self` alive by wrapping it in a handle below. + // - Since we generate the pointer passed to `start` from a valid + // reference, it is a valid pointer. + unsafe { T::start(inner.as_ptr(), expires) }; + + // INVARIANT: `inner` came from `Box::into_raw` above. + BoxHrTimerHandle { + inner, + _p: core::marker::PhantomData, + } + } +} + +impl<T, A> RawHrTimerCallback for Pin<Box<T, A>> +where + T: 'static, + T: HasHrTimer<T>, + T: for<'a> HrTimerCallback<Pointer<'a> = Pin<Box<T, A>>>, + A: crate::alloc::Allocator, +{ + type CallbackTarget<'a> = Pin<&'a mut T>; + + unsafe extern "C" fn run(ptr: *mut bindings::hrtimer) -> bindings::hrtimer_restart { + // `HrTimer` is `repr(C)` + let timer_ptr = ptr.cast::<super::HrTimer<T>>(); + + // SAFETY: By C API contract `ptr` is the pointer we passed when + // queuing the timer, so it is a `HrTimer<T>` embedded in a `T`. + let data_ptr = unsafe { T::timer_container_of(timer_ptr) }; + + // SAFETY: + // - As per the safety requirements of the trait `HrTimerHandle`, the + // `BoxHrTimerHandle` associated with this timer is guaranteed to + // be alive until this method returns. That handle owns the `T` + // behind `data_ptr` thus guaranteeing the validity of + // the reference created below. + // - As `data_ptr` comes from a `Pin<Box<T>>`, only pinned references to + // `data_ptr` exist. + let data_mut_ref = unsafe { Pin::new_unchecked(&mut *data_ptr) }; + + T::run(data_mut_ref).into_c() + } +} diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs index ec6457bb3084..dc0a02f5c3cf 100644 --- a/rust/kernel/types.rs +++ b/rust/kernel/types.rs @@ -2,14 +2,16 @@ //! Kernel types. -use crate::init::{self, PinInit}; +use crate::ffi::c_void; use core::{ cell::UnsafeCell, marker::{PhantomData, PhantomPinned}, - mem::{ManuallyDrop, MaybeUninit}, + mem::MaybeUninit, ops::{Deref, DerefMut}, - ptr::NonNull, }; +use pin_init::{PinInit, Wrapper, Zeroable}; + +pub use crate::sync::aref::{ARef, AlwaysRefCounted}; /// Used to transfer ownership to and from foreign (non-Rust) languages. /// @@ -18,36 +20,47 @@ use core::{ /// /// This trait is meant to be used in cases when Rust objects are stored in C objects and /// eventually "freed" back to Rust. -pub trait ForeignOwnable: Sized { - /// Type of values borrowed between calls to [`ForeignOwnable::into_foreign`] and - /// [`ForeignOwnable::from_foreign`]. +/// +/// # Safety +/// +/// - Implementations must satisfy the guarantees of [`Self::into_foreign`]. +pub unsafe trait ForeignOwnable: Sized { + /// The alignment of pointers returned by `into_foreign`. + const FOREIGN_ALIGN: usize; + + /// Type used to immutably borrow a value that is currently foreign-owned. type Borrowed<'a>; + /// Type used to mutably borrow a value that is currently foreign-owned. + type BorrowedMut<'a>; + /// Converts a Rust-owned object to a foreign-owned one. /// - /// The foreign representation is a pointer to void. There are no guarantees for this pointer. - /// For example, it might be invalid, dangling or pointing to uninitialized memory. Using it in - /// any way except for [`ForeignOwnable::from_foreign`], [`ForeignOwnable::borrow`], - /// [`ForeignOwnable::try_from_foreign`] can result in undefined behavior. - fn into_foreign(self) -> *const crate::ffi::c_void; - - /// Borrows a foreign-owned object. + /// The foreign representation is a pointer to void. Aside from the guarantees listed below, + /// there are no other guarantees for this pointer. For example, it might be invalid, dangling + /// or pointing to uninitialized memory. Using it in any way except for [`from_foreign`], + /// [`try_from_foreign`], [`borrow`], or [`borrow_mut`] can result in undefined behavior. /// - /// # Safety + /// # Guarantees + /// + /// - Minimum alignment of returned pointer is [`Self::FOREIGN_ALIGN`]. + /// - The returned pointer is not null. /// - /// `ptr` must have been returned by a previous call to [`ForeignOwnable::into_foreign`] for - /// which a previous matching [`ForeignOwnable::from_foreign`] hasn't been called yet. - unsafe fn borrow<'a>(ptr: *const crate::ffi::c_void) -> Self::Borrowed<'a>; + /// [`from_foreign`]: Self::from_foreign + /// [`try_from_foreign`]: Self::try_from_foreign + /// [`borrow`]: Self::borrow + /// [`borrow_mut`]: Self::borrow_mut + fn into_foreign(self) -> *mut c_void; /// Converts a foreign-owned object back to a Rust-owned one. /// /// # Safety /// - /// `ptr` must have been returned by a previous call to [`ForeignOwnable::into_foreign`] for - /// which a previous matching [`ForeignOwnable::from_foreign`] hasn't been called yet. - /// Additionally, all instances (if any) of values returned by [`ForeignOwnable::borrow`] for - /// this object must have been dropped. - unsafe fn from_foreign(ptr: *const crate::ffi::c_void) -> Self; + /// The provided pointer must have been returned by a previous call to [`into_foreign`], and it + /// must not be passed to `from_foreign` more than once. + /// + /// [`into_foreign`]: Self::into_foreign + unsafe fn from_foreign(ptr: *mut c_void) -> Self; /// Tries to convert a foreign-owned object back to a Rust-owned one. /// @@ -56,9 +69,10 @@ pub trait ForeignOwnable: Sized { /// /// # Safety /// - /// `ptr` must either be null or satisfy the safety requirements for - /// [`ForeignOwnable::from_foreign`]. - unsafe fn try_from_foreign(ptr: *const crate::ffi::c_void) -> Option<Self> { + /// `ptr` must either be null or satisfy the safety requirements for [`from_foreign`]. + /// + /// [`from_foreign`]: Self::from_foreign + unsafe fn try_from_foreign(ptr: *mut c_void) -> Option<Self> { if ptr.is_null() { None } else { @@ -67,18 +81,66 @@ pub trait ForeignOwnable: Sized { unsafe { Some(Self::from_foreign(ptr)) } } } + + /// Borrows a foreign-owned object immutably. + /// + /// This method provides a way to access a foreign-owned value from Rust immutably. It provides + /// you with exactly the same abilities as an `&Self` when the value is Rust-owned. + /// + /// # Safety + /// + /// The provided pointer must have been returned by a previous call to [`into_foreign`], and if + /// the pointer is ever passed to [`from_foreign`], then that call must happen after the end of + /// the lifetime `'a`. + /// + /// [`into_foreign`]: Self::into_foreign + /// [`from_foreign`]: Self::from_foreign + unsafe fn borrow<'a>(ptr: *mut c_void) -> Self::Borrowed<'a>; + + /// Borrows a foreign-owned object mutably. + /// + /// This method provides a way to access a foreign-owned value from Rust mutably. It provides + /// you with exactly the same abilities as an `&mut Self` when the value is Rust-owned, except + /// that the address of the object must not be changed. + /// + /// Note that for types like [`Arc`], an `&mut Arc<T>` only gives you immutable access to the + /// inner value, so this method also only provides immutable access in that case. + /// + /// In the case of `Box<T>`, this method gives you the ability to modify the inner `T`, but it + /// does not let you change the box itself. That is, you cannot change which allocation the box + /// points at. + /// + /// # Safety + /// + /// The provided pointer must have been returned by a previous call to [`into_foreign`], and if + /// the pointer is ever passed to [`from_foreign`], then that call must happen after the end of + /// the lifetime `'a`. + /// + /// The lifetime `'a` must not overlap with the lifetime of any other call to [`borrow`] or + /// `borrow_mut` on the same object. + /// + /// [`into_foreign`]: Self::into_foreign + /// [`from_foreign`]: Self::from_foreign + /// [`borrow`]: Self::borrow + /// [`Arc`]: crate::sync::Arc + unsafe fn borrow_mut<'a>(ptr: *mut c_void) -> Self::BorrowedMut<'a>; } -impl ForeignOwnable for () { +// SAFETY: The pointer returned by `into_foreign` comes from a well aligned +// pointer to `()`. +unsafe impl ForeignOwnable for () { + const FOREIGN_ALIGN: usize = core::mem::align_of::<()>(); type Borrowed<'a> = (); + type BorrowedMut<'a> = (); - fn into_foreign(self) -> *const crate::ffi::c_void { + fn into_foreign(self) -> *mut c_void { core::ptr::NonNull::dangling().as_ptr() } - unsafe fn borrow<'a>(_: *const crate::ffi::c_void) -> Self::Borrowed<'a> {} + unsafe fn from_foreign(_: *mut c_void) -> Self {} - unsafe fn from_foreign(_: *const crate::ffi::c_void) -> Self {} + unsafe fn borrow<'a>(_: *mut c_void) -> Self::Borrowed<'a> {} + unsafe fn borrow_mut<'a>(_: *mut c_void) -> Self::BorrowedMut<'a> {} } /// Runs a cleanup function/closure when dropped. @@ -206,7 +268,7 @@ impl<T, F: FnOnce(T)> Drop for ScopeGuard<T, F> { /// Stores an opaque value. /// -/// `Opaque<T>` is meant to be used with FFI objects that are never interpreted by Rust code. +/// [`Opaque<T>`] is meant to be used with FFI objects that are never interpreted by Rust code. /// /// It is used to wrap structs from the C side, like for example `Opaque<bindings::mutex>`. /// It gets rid of all the usual assumptions that Rust has for a value: @@ -221,7 +283,7 @@ impl<T, F: FnOnce(T)> Drop for ScopeGuard<T, F> { /// This has to be used for all values that the C side has access to, because it can't be ensured /// that the C side is adhering to the usual constraints that Rust needs. /// -/// Using `Opaque<T>` allows to continue to use references on the Rust side even for values shared +/// Using [`Opaque<T>`] allows to continue to use references on the Rust side even for values shared /// with C. /// /// # Examples @@ -264,6 +326,9 @@ pub struct Opaque<T> { _pin: PhantomPinned, } +// SAFETY: `Opaque<T>` allows the inner value to be any bit pattern, including all zeros. +unsafe impl<T> Zeroable for Opaque<T> {} + impl<T> Opaque<T> { /// Creates a new opaque value. pub const fn new(value: T) -> Self { @@ -281,6 +346,14 @@ impl<T> Opaque<T> { } } + /// Creates a new zeroed opaque value. + pub const fn zeroed() -> Self { + Self { + value: UnsafeCell::new(MaybeUninit::zeroed()), + _pin: PhantomPinned, + } + } + /// Creates a pin-initializer from the given initializer closure. /// /// The returned initializer calls the given closure with the pointer to the inner `T` of this @@ -293,8 +366,8 @@ impl<T> Opaque<T> { // SAFETY: We contain a `MaybeUninit`, so it is OK for the `init_func` to not fully // initialize the `T`. unsafe { - init::pin_init_from_closure::<_, ::core::convert::Infallible>(move |slot| { - init_func(Self::raw_get(slot)); + pin_init::pin_init_from_closure::<_, ::core::convert::Infallible>(move |slot| { + init_func(Self::cast_into(slot)); Ok(()) }) } @@ -313,7 +386,9 @@ impl<T> Opaque<T> { ) -> impl PinInit<Self, E> { // SAFETY: We contain a `MaybeUninit`, so it is OK for the `init_func` to not fully // initialize the `T`. - unsafe { init::pin_init_from_closure::<_, E>(move |slot| init_func(Self::raw_get(slot))) } + unsafe { + pin_init::pin_init_from_closure::<_, E>(move |slot| init_func(Self::cast_into(slot))) + } } /// Returns a raw pointer to the opaque data. @@ -325,178 +400,29 @@ impl<T> Opaque<T> { /// /// This function is useful to get access to the value without creating intermediate /// references. - pub const fn raw_get(this: *const Self) -> *mut T { + pub const fn cast_into(this: *const Self) -> *mut T { UnsafeCell::raw_get(this.cast::<UnsafeCell<MaybeUninit<T>>>()).cast::<T>() } -} - -/// Types that are _always_ reference counted. -/// -/// It allows such types to define their own custom ref increment and decrement functions. -/// Additionally, it allows users to convert from a shared reference `&T` to an owned reference -/// [`ARef<T>`]. -/// -/// This is usually implemented by wrappers to existing structures on the C side of the code. For -/// Rust code, the recommendation is to use [`Arc`](crate::sync::Arc) to create reference-counted -/// instances of a type. -/// -/// # Safety -/// -/// Implementers must ensure that increments to the reference count keep the object alive in memory -/// at least until matching decrements are performed. -/// -/// Implementers must also ensure that all instances are reference-counted. (Otherwise they -/// won't be able to honour the requirement that [`AlwaysRefCounted::inc_ref`] keep the object -/// alive.) -pub unsafe trait AlwaysRefCounted { - /// Increments the reference count on the object. - fn inc_ref(&self); - - /// Decrements the reference count on the object. - /// - /// Frees the object when the count reaches zero. - /// - /// # Safety - /// - /// Callers must ensure that there was a previous matching increment to the reference count, - /// and that the object is no longer used after its reference count is decremented (as it may - /// result in the object being freed), unless the caller owns another increment on the refcount - /// (e.g., it calls [`AlwaysRefCounted::inc_ref`] twice, then calls - /// [`AlwaysRefCounted::dec_ref`] once). - unsafe fn dec_ref(obj: NonNull<Self>); -} - -/// An owned reference to an always-reference-counted object. -/// -/// The object's reference count is automatically decremented when an instance of [`ARef`] is -/// dropped. It is also automatically incremented when a new instance is created via -/// [`ARef::clone`]. -/// -/// # Invariants -/// -/// The pointer stored in `ptr` is non-null and valid for the lifetime of the [`ARef`] instance. In -/// particular, the [`ARef`] instance owns an increment on the underlying object's reference count. -pub struct ARef<T: AlwaysRefCounted> { - ptr: NonNull<T>, - _p: PhantomData<T>, -} - -// SAFETY: It is safe to send `ARef<T>` to another thread when the underlying `T` is `Sync` because -// it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, it needs -// `T` to be `Send` because any thread that has an `ARef<T>` may ultimately access `T` using a -// mutable reference, for example, when the reference count reaches zero and `T` is dropped. -unsafe impl<T: AlwaysRefCounted + Sync + Send> Send for ARef<T> {} - -// SAFETY: It is safe to send `&ARef<T>` to another thread when the underlying `T` is `Sync` -// because it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, -// it needs `T` to be `Send` because any thread that has a `&ARef<T>` may clone it and get an -// `ARef<T>` on that thread, so the thread may ultimately access `T` using a mutable reference, for -// example, when the reference count reaches zero and `T` is dropped. -unsafe impl<T: AlwaysRefCounted + Sync + Send> Sync for ARef<T> {} - -impl<T: AlwaysRefCounted> ARef<T> { - /// Creates a new instance of [`ARef`]. - /// - /// It takes over an increment of the reference count on the underlying object. - /// - /// # Safety - /// - /// Callers must ensure that the reference count was incremented at least once, and that they - /// are properly relinquishing one increment. That is, if there is only one increment, callers - /// must not use the underlying object anymore -- it is only safe to do so via the newly - /// created [`ARef`]. - pub unsafe fn from_raw(ptr: NonNull<T>) -> Self { - // INVARIANT: The safety requirements guarantee that the new instance now owns the - // increment on the refcount. - Self { - ptr, - _p: PhantomData, - } - } - /// Consumes the `ARef`, returning a raw pointer. - /// - /// This function does not change the refcount. After calling this function, the caller is - /// responsible for the refcount previously managed by the `ARef`. - /// - /// # Examples - /// - /// ``` - /// use core::ptr::NonNull; - /// use kernel::types::{ARef, AlwaysRefCounted}; - /// - /// struct Empty {} - /// - /// # // SAFETY: TODO. - /// unsafe impl AlwaysRefCounted for Empty { - /// fn inc_ref(&self) {} - /// unsafe fn dec_ref(_obj: NonNull<Self>) {} - /// } - /// - /// let mut data = Empty {}; - /// let ptr = NonNull::<Empty>::new(&mut data as *mut _).unwrap(); - /// # // SAFETY: TODO. - /// let data_ref: ARef<Empty> = unsafe { ARef::from_raw(ptr) }; - /// let raw_ptr: NonNull<Empty> = ARef::into_raw(data_ref); - /// - /// assert_eq!(ptr, raw_ptr); - /// ``` - pub fn into_raw(me: Self) -> NonNull<T> { - ManuallyDrop::new(me).ptr - } -} - -impl<T: AlwaysRefCounted> Clone for ARef<T> { - fn clone(&self) -> Self { - self.inc_ref(); - // SAFETY: We just incremented the refcount above. - unsafe { Self::from_raw(self.ptr) } - } -} - -impl<T: AlwaysRefCounted> Deref for ARef<T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - // SAFETY: The type invariants guarantee that the object is valid. - unsafe { self.ptr.as_ref() } - } -} - -impl<T: AlwaysRefCounted> From<&T> for ARef<T> { - fn from(b: &T) -> Self { - b.inc_ref(); - // SAFETY: We just incremented the refcount above. - unsafe { Self::from_raw(NonNull::from(b)) } + /// The opposite operation of [`Opaque::cast_into`]. + pub const fn cast_from(this: *const T) -> *const Self { + this.cast() } } -impl<T: AlwaysRefCounted> Drop for ARef<T> { - fn drop(&mut self) { - // SAFETY: The type invariants guarantee that the `ARef` owns the reference we're about to - // decrement. - unsafe { T::dec_ref(self.ptr) }; +impl<T> Wrapper<T> for Opaque<T> { + /// Create an opaque pin-initializer from the given pin-initializer. + fn pin_init<E>(slot: impl PinInit<T, E>) -> impl PinInit<Self, E> { + Self::try_ffi_init(|ptr: *mut T| { + // SAFETY: + // - `ptr` is a valid pointer to uninitialized memory, + // - `slot` is not accessed on error, + // - `slot` is pinned in memory. + unsafe { PinInit::<T, E>::__pinned_init(slot, ptr) } + }) } } -/// A sum type that always holds either a value of type `L` or `R`. -/// -/// # Examples -/// -/// ``` -/// use kernel::types::Either; -/// -/// let left_value: Either<i32, &str> = Either::Left(7); -/// let right_value: Either<i32, &str> = Either::Right("right value"); -/// ``` -pub enum Either<L, R> { - /// Constructs an instance of [`Either`] containing a value of type `L`. - Left(L), - - /// Constructs an instance of [`Either`] containing a value of type `R`. - Right(R), -} - /// Zero-sized type to mark types not [`Send`]. /// /// Add this type as a field to your struct if your type should not be sent to a different task. diff --git a/rust/kernel/uaccess.rs b/rust/kernel/uaccess.rs index cc044924867b..a8fb4764185a 100644 --- a/rust/kernel/uaccess.rs +++ b/rust/kernel/uaccess.rs @@ -5,17 +5,60 @@ //! C header: [`include/linux/uaccess.h`](srctree/include/linux/uaccess.h) use crate::{ - alloc::Flags, + alloc::{Allocator, Flags}, bindings, error::Result, - ffi::c_void, + ffi::{c_char, c_void}, prelude::*, transmute::{AsBytes, FromBytes}, }; use core::mem::{size_of, MaybeUninit}; -/// The type used for userspace addresses. -pub type UserPtr = usize; +/// A pointer into userspace. +/// +/// This is the Rust equivalent to C pointers tagged with `__user`. +#[repr(transparent)] +#[derive(Copy, Clone)] +pub struct UserPtr(*mut c_void); + +impl UserPtr { + /// Create a `UserPtr` from an integer representing the userspace address. + #[inline] + pub fn from_addr(addr: usize) -> Self { + Self(addr as *mut c_void) + } + + /// Create a `UserPtr` from a pointer representing the userspace address. + #[inline] + pub fn from_ptr(addr: *mut c_void) -> Self { + Self(addr) + } + + /// Cast this userspace pointer to a raw const void pointer. + /// + /// It is up to the caller to use the returned pointer correctly. + #[inline] + pub fn as_const_ptr(self) -> *const c_void { + self.0 + } + + /// Cast this userspace pointer to a raw mutable void pointer. + /// + /// It is up to the caller to use the returned pointer correctly. + #[inline] + pub fn as_mut_ptr(self) -> *mut c_void { + self.0 + } + + /// Increment this user pointer by `add` bytes. + /// + /// This addition is wrapping, so wrapping around the address space does not result in a panic + /// even if `CONFIG_RUST_OVERFLOW_CHECKS` is enabled. + #[inline] + pub fn wrapping_byte_add(self, add: usize) -> UserPtr { + UserPtr(self.0.wrapping_byte_add(add)) + } +} /// A pointer to an area in userspace memory, which can be either read-only or read-write. /// @@ -46,10 +89,9 @@ pub type UserPtr = usize; /// /// ```no_run /// use kernel::ffi::c_void; -/// use kernel::error::Result; /// use kernel::uaccess::{UserPtr, UserSlice}; /// -/// fn bytes_add_one(uptr: UserPtr, len: usize) -> Result<()> { +/// fn bytes_add_one(uptr: UserPtr, len: usize) -> Result { /// let (read, mut write) = UserSlice::new(uptr, len).reader_writer(); /// /// let mut buf = KVec::new(); @@ -68,7 +110,6 @@ pub type UserPtr = usize; /// /// ```no_run /// use kernel::ffi::c_void; -/// use kernel::error::{code::EINVAL, Result}; /// use kernel::uaccess::{UserPtr, UserSlice}; /// /// /// Returns whether the data in this region is valid. @@ -127,7 +168,7 @@ impl UserSlice { /// Reads the entirety of the user slice, appending it to the end of the provided buffer. /// /// Fails with [`EFAULT`] if the read happens on a bad address. - pub fn read_all(self, buf: &mut KVec<u8>, flags: Flags) -> Result { + pub fn read_all<A: Allocator>(self, buf: &mut Vec<u8, A>, flags: Flags) -> Result { self.reader().read_all(buf, flags) } @@ -179,7 +220,7 @@ impl UserSliceReader { pub fn skip(&mut self, num_skip: usize) -> Result { // Update `self.length` first since that's the fallible part of this operation. self.length = self.length.checked_sub(num_skip).ok_or(EFAULT)?; - self.ptr = self.ptr.wrapping_add(num_skip); + self.ptr = self.ptr.wrapping_byte_add(num_skip); Ok(()) } @@ -226,11 +267,11 @@ impl UserSliceReader { } // SAFETY: `out_ptr` points into a mutable slice of length `len`, so we may write // that many bytes to it. - let res = unsafe { bindings::copy_from_user(out_ptr, self.ptr as *const c_void, len) }; + let res = unsafe { bindings::copy_from_user(out_ptr, self.ptr.as_const_ptr(), len) }; if res != 0 { return Err(EFAULT); } - self.ptr = self.ptr.wrapping_add(len); + self.ptr = self.ptr.wrapping_byte_add(len); self.length -= len; Ok(()) } @@ -242,7 +283,7 @@ impl UserSliceReader { pub fn read_slice(&mut self, out: &mut [u8]) -> Result { // SAFETY: The types are compatible and `read_raw` doesn't write uninitialized bytes to // `out`. - let out = unsafe { &mut *(out as *mut [u8] as *mut [MaybeUninit<u8>]) }; + let out = unsafe { &mut *(core::ptr::from_mut(out) as *mut [MaybeUninit<u8>]) }; self.read_raw(out) } @@ -264,14 +305,14 @@ impl UserSliceReader { let res = unsafe { bindings::_copy_from_user( out.as_mut_ptr().cast::<c_void>(), - self.ptr as *const c_void, + self.ptr.as_const_ptr(), len, ) }; if res != 0 { return Err(EFAULT); } - self.ptr = self.ptr.wrapping_add(len); + self.ptr = self.ptr.wrapping_byte_add(len); self.length -= len; // SAFETY: The read above has initialized all bytes in `out`, and since `T` implements // `FromBytes`, any bit-pattern is a valid value for this type. @@ -281,19 +322,77 @@ impl UserSliceReader { /// Reads the entirety of the user slice, appending it to the end of the provided buffer. /// /// Fails with [`EFAULT`] if the read happens on a bad address. - pub fn read_all(mut self, buf: &mut KVec<u8>, flags: Flags) -> Result { + pub fn read_all<A: Allocator>(mut self, buf: &mut Vec<u8, A>, flags: Flags) -> Result { let len = self.length; buf.reserve(len, flags)?; - // The call to `try_reserve` was successful, so the spare capacity is at least `len` bytes - // long. + // The call to `reserve` was successful, so the spare capacity is at least `len` bytes long. self.read_raw(&mut buf.spare_capacity_mut()[..len])?; // SAFETY: Since the call to `read_raw` was successful, so the next `len` bytes of the // vector have been initialized. - unsafe { buf.set_len(buf.len() + len) }; + unsafe { buf.inc_len(len) }; Ok(()) } + + /// Read a NUL-terminated string from userspace and return it. + /// + /// The string is read into `buf` and a NUL-terminator is added if the end of `buf` is reached. + /// Since there must be space to add a NUL-terminator, the buffer must not be empty. The + /// returned `&CStr` points into `buf`. + /// + /// Fails with [`EFAULT`] if the read happens on a bad address (some data may have been + /// copied). + #[doc(alias = "strncpy_from_user")] + pub fn strcpy_into_buf<'buf>(self, buf: &'buf mut [u8]) -> Result<&'buf CStr> { + if buf.is_empty() { + return Err(EINVAL); + } + + // SAFETY: The types are compatible and `strncpy_from_user` doesn't write uninitialized + // bytes to `buf`. + let mut dst = unsafe { &mut *(core::ptr::from_mut(buf) as *mut [MaybeUninit<u8>]) }; + + // We never read more than `self.length` bytes. + if dst.len() > self.length { + dst = &mut dst[..self.length]; + } + + let mut len = raw_strncpy_from_user(dst, self.ptr)?; + if len < dst.len() { + // Add one to include the NUL-terminator. + len += 1; + } else if len < buf.len() { + // This implies that `len == dst.len() < buf.len()`. + // + // This means that we could not fill the entire buffer, but we had to stop reading + // because we hit the `self.length` limit of this `UserSliceReader`. Since we did not + // fill the buffer, we treat this case as if we tried to read past the `self.length` + // limit and received a page fault, which is consistent with other `UserSliceReader` + // methods that also return page faults when you exceed `self.length`. + return Err(EFAULT); + } else { + // This implies that `len == buf.len()`. + // + // This means that we filled the buffer exactly. In this case, we add a NUL-terminator + // and return it. Unlike the `len < dst.len()` branch, don't modify `len` because it + // already represents the length including the NUL-terminator. + // + // SAFETY: Due to the check at the beginning, the buffer is not empty. + unsafe { *buf.last_mut().unwrap_unchecked() = 0 }; + } + + // This method consumes `self`, so it can only be called once, thus we do not need to + // update `self.length`. This sidesteps concerns such as whether `self.length` should be + // incremented by `len` or `len-1` in the `len == buf.len()` case. + + // SAFETY: There are two cases: + // * If we hit the `len < dst.len()` case, then `raw_strncpy_from_user` guarantees that + // this slice contains exactly one NUL byte at the end of the string. + // * Otherwise, `raw_strncpy_from_user` guarantees that the string contained no NUL bytes, + // and we have since added a NUL byte at the end. + Ok(unsafe { CStr::from_bytes_with_nul_unchecked(&buf[..len]) }) + } } /// A writer for [`UserSlice`]. @@ -330,11 +429,11 @@ impl UserSliceWriter { } // SAFETY: `data_ptr` points into an immutable slice of length `len`, so we may read // that many bytes from it. - let res = unsafe { bindings::copy_to_user(self.ptr as *mut c_void, data_ptr, len) }; + let res = unsafe { bindings::copy_to_user(self.ptr.as_mut_ptr(), data_ptr, len) }; if res != 0 { return Err(EFAULT); } - self.ptr = self.ptr.wrapping_add(len); + self.ptr = self.ptr.wrapping_byte_add(len); self.length -= len; Ok(()) } @@ -357,16 +456,53 @@ impl UserSliceWriter { // is a compile-time constant. let res = unsafe { bindings::_copy_to_user( - self.ptr as *mut c_void, - (value as *const T).cast::<c_void>(), + self.ptr.as_mut_ptr(), + core::ptr::from_ref(value).cast::<c_void>(), len, ) }; if res != 0 { return Err(EFAULT); } - self.ptr = self.ptr.wrapping_add(len); + self.ptr = self.ptr.wrapping_byte_add(len); self.length -= len; Ok(()) } } + +/// Reads a nul-terminated string into `dst` and returns the length. +/// +/// This reads from userspace until a NUL byte is encountered, or until `dst.len()` bytes have been +/// read. Fails with [`EFAULT`] if a read happens on a bad address (some data may have been +/// copied). When the end of the buffer is encountered, no NUL byte is added, so the string is +/// *not* guaranteed to be NUL-terminated when `Ok(dst.len())` is returned. +/// +/// # Guarantees +/// +/// When this function returns `Ok(len)`, it is guaranteed that the first `len` bytes of `dst` are +/// initialized and non-zero. Furthermore, if `len < dst.len()`, then `dst[len]` is a NUL byte. +#[inline] +fn raw_strncpy_from_user(dst: &mut [MaybeUninit<u8>], src: UserPtr) -> Result<usize> { + // CAST: Slice lengths are guaranteed to be `<= isize::MAX`. + let len = dst.len() as isize; + + // SAFETY: `dst` is valid for writing `dst.len()` bytes. + let res = unsafe { + bindings::strncpy_from_user( + dst.as_mut_ptr().cast::<c_char>(), + src.as_const_ptr().cast::<c_char>(), + len, + ) + }; + + if res < 0 { + return Err(Error::from_errno(res as i32)); + } + + #[cfg(CONFIG_RUST_OVERFLOW_CHECKS)] + assert!(res <= len); + + // GUARANTEES: `strncpy_from_user` was successful, so `dst` has contents in accordance with the + // guarantees of this function. + Ok(res as usize) +} diff --git a/rust/kernel/workqueue.rs b/rust/kernel/workqueue.rs index fd3e97192ed8..b9343d5bc00f 100644 --- a/rust/kernel/workqueue.rs +++ b/rust/kernel/workqueue.rs @@ -26,7 +26,7 @@ //! * The [`WorkItemPointer`] trait is implemented for the pointer type that points at a something //! that implements [`WorkItem`]. //! -//! ## Example +//! ## Examples //! //! This example defines a struct that holds an integer and can be scheduled on the workqueue. When //! the struct is executed, it will print the integer. Since there is only one `work_struct` field, @@ -60,7 +60,7 @@ //! type Pointer = Arc<MyStruct>; //! //! fn run(this: Arc<MyStruct>) { -//! pr_info!("The value is: {}", this.value); +//! pr_info!("The value is: {}\n", this.value); //! } //! } //! @@ -69,6 +69,7 @@ //! fn print_later(val: Arc<MyStruct>) { //! let _ = workqueue::system().enqueue(val); //! } +//! # print_later(MyStruct::new(42).unwrap()); //! ``` //! //! The following example shows how multiple `work_struct` fields can be used: @@ -107,7 +108,7 @@ //! type Pointer = Arc<MyStruct>; //! //! fn run(this: Arc<MyStruct>) { -//! pr_info!("The value is: {}", this.value_1); +//! pr_info!("The value is: {}\n", this.value_1); //! } //! } //! @@ -115,7 +116,7 @@ //! type Pointer = Arc<MyStruct>; //! //! fn run(this: Arc<MyStruct>) { -//! pr_info!("The second value is: {}", this.value_2); +//! pr_info!("The second value is: {}\n", this.value_2); //! } //! } //! @@ -126,12 +127,73 @@ //! fn print_2_later(val: Arc<MyStruct>) { //! let _ = workqueue::system().enqueue::<Arc<MyStruct>, 2>(val); //! } +//! # print_1_later(MyStruct::new(24, 25).unwrap()); +//! # print_2_later(MyStruct::new(41, 42).unwrap()); +//! ``` +//! +//! This example shows how you can schedule delayed work items: +//! +//! ``` +//! use kernel::sync::Arc; +//! use kernel::workqueue::{self, impl_has_delayed_work, new_delayed_work, DelayedWork, WorkItem}; +//! +//! #[pin_data] +//! struct MyStruct { +//! value: i32, +//! #[pin] +//! work: DelayedWork<MyStruct>, +//! } +//! +//! impl_has_delayed_work! { +//! impl HasDelayedWork<Self> for MyStruct { self.work } +//! } +//! +//! impl MyStruct { +//! fn new(value: i32) -> Result<Arc<Self>> { +//! Arc::pin_init( +//! pin_init!(MyStruct { +//! value, +//! work <- new_delayed_work!("MyStruct::work"), +//! }), +//! GFP_KERNEL, +//! ) +//! } +//! } +//! +//! impl WorkItem for MyStruct { +//! type Pointer = Arc<MyStruct>; +//! +//! fn run(this: Arc<MyStruct>) { +//! pr_info!("The value is: {}\n", this.value); +//! } +//! } +//! +//! /// This method will enqueue the struct for execution on the system workqueue, where its value +//! /// will be printed 12 jiffies later. +//! fn print_later(val: Arc<MyStruct>) { +//! let _ = workqueue::system().enqueue_delayed(val, 12); +//! } +//! +//! /// It is also possible to use the ordinary `enqueue` method together with `DelayedWork`. This +//! /// is equivalent to calling `enqueue_delayed` with a delay of zero. +//! fn print_now(val: Arc<MyStruct>) { +//! let _ = workqueue::system().enqueue(val); +//! } +//! # print_later(MyStruct::new(42).unwrap()); +//! # print_now(MyStruct::new(42).unwrap()); //! ``` //! //! C header: [`include/linux/workqueue.h`](srctree/include/linux/workqueue.h) -use crate::alloc::{AllocError, Flags}; -use crate::{prelude::*, sync::Arc, sync::LockClassKey, types::Opaque}; +use crate::{ + alloc::{AllocError, Flags}, + container_of, + prelude::*, + sync::Arc, + sync::LockClassKey, + time::Jiffies, + types::Opaque, +}; use core::marker::PhantomData; /// Creates a [`Work`] initialiser with the given name and a newly-created lock class. @@ -143,6 +205,33 @@ macro_rules! new_work { } pub use new_work; +/// Creates a [`DelayedWork`] initialiser with the given name and a newly-created lock class. +#[macro_export] +macro_rules! new_delayed_work { + () => { + $crate::workqueue::DelayedWork::new( + $crate::optional_name!(), + $crate::static_lock_class!(), + $crate::c_str!(::core::concat!( + ::core::file!(), + ":", + ::core::line!(), + "_timer" + )), + $crate::static_lock_class!(), + ) + }; + ($name:literal) => { + $crate::workqueue::DelayedWork::new( + $crate::c_str!($name), + $crate::static_lock_class!(), + $crate::c_str!(::core::concat!($name, "_timer")), + $crate::static_lock_class!(), + ) + }; +} +pub use new_delayed_work; + /// A kernel work queue. /// /// Wraps the kernel's C `struct workqueue_struct`. @@ -167,7 +256,7 @@ impl Queue { pub unsafe fn from_raw<'a>(ptr: *const bindings::workqueue_struct) -> &'a Queue { // SAFETY: The `Queue` type is `#[repr(transparent)]`, so the pointer cast is valid. The // caller promises that the pointer is not dangling. - unsafe { &*(ptr as *const Queue) } + unsafe { &*ptr.cast::<Queue>() } } /// Enqueues a work item. @@ -195,7 +284,7 @@ impl Queue { unsafe { w.__enqueue(move |work_ptr| { bindings::queue_work_on( - bindings::wq_misc_consts_WORK_CPU_UNBOUND as _, + bindings::wq_misc_consts_WORK_CPU_UNBOUND as ffi::c_int, queue_ptr, work_ptr, ) @@ -203,6 +292,42 @@ impl Queue { } } + /// Enqueues a delayed work item. + /// + /// This may fail if the work item is already enqueued in a workqueue. + /// + /// The work item will be submitted using `WORK_CPU_UNBOUND`. + pub fn enqueue_delayed<W, const ID: u64>(&self, w: W, delay: Jiffies) -> W::EnqueueOutput + where + W: RawDelayedWorkItem<ID> + Send + 'static, + { + let queue_ptr = self.0.get(); + + // SAFETY: We only return `false` if the `work_struct` is already in a workqueue. The other + // `__enqueue` requirements are not relevant since `W` is `Send` and static. + // + // The call to `bindings::queue_delayed_work_on` will dereference the provided raw pointer, + // which is ok because `__enqueue` guarantees that the pointer is valid for the duration of + // this closure, and the safety requirements of `RawDelayedWorkItem` expands this + // requirement to apply to the entire `delayed_work`. + // + // Furthermore, if the C workqueue code accesses the pointer after this call to + // `__enqueue`, then the work item was successfully enqueued, and + // `bindings::queue_delayed_work_on` will have returned true. In this case, `__enqueue` + // promises that the raw pointer will stay valid until we call the function pointer in the + // `work_struct`, so the access is ok. + unsafe { + w.__enqueue(move |work_ptr| { + bindings::queue_delayed_work_on( + bindings::wq_misc_consts_WORK_CPU_UNBOUND as ffi::c_int, + queue_ptr, + container_of!(work_ptr, bindings::delayed_work, work), + delay, + ) + }) + } + } + /// Tries to spawn the given function or closure as a work item. /// /// This method can fail because it allocates memory to store the work item. @@ -295,6 +420,16 @@ pub unsafe trait RawWorkItem<const ID: u64> { F: FnOnce(*mut bindings::work_struct) -> bool; } +/// A raw delayed work item. +/// +/// # Safety +/// +/// If the `__enqueue` method in the `RawWorkItem` implementation calls the closure, then the +/// provided pointer must point at the `work` field of a valid `delayed_work`, and the guarantees +/// that `__enqueue` provides about accessing the `work_struct` must also apply to the rest of the +/// `delayed_work` struct. +pub unsafe trait RawDelayedWorkItem<const ID: u64>: RawWorkItem<ID> {} + /// Defines the method that should be called directly when a work item is executed. /// /// This trait is implemented by `Pin<KBox<T>>` and [`Arc<T>`], and is mainly intended to be @@ -366,7 +501,7 @@ unsafe impl<T: ?Sized, const ID: u64> Sync for Work<T, ID> {} impl<T: ?Sized, const ID: u64> Work<T, ID> { /// Creates a new instance of [`Work`]. #[inline] - pub fn new(name: &'static CStr, key: &'static LockClassKey) -> impl PinInit<Self> + pub fn new(name: &'static CStr, key: Pin<&'static LockClassKey>) -> impl PinInit<Self> where T: WorkItem<ID>, { @@ -400,11 +535,11 @@ impl<T: ?Sized, const ID: u64> Work<T, ID> { // // A pointer cast would also be ok due to `#[repr(transparent)]`. We use `addr_of!` so that // the compiler does not complain that the `work` field is unused. - unsafe { Opaque::raw_get(core::ptr::addr_of!((*ptr).work)) } + unsafe { Opaque::cast_into(core::ptr::addr_of!((*ptr).work)) } } } -/// Declares that a type has a [`Work<T, ID>`] field. +/// Declares that a type contains a [`Work<T, ID>`]. /// /// The intended way of using this trait is via the [`impl_has_work!`] macro. You can use the macro /// like this: @@ -426,51 +561,28 @@ impl<T: ?Sized, const ID: u64> Work<T, ID> { /// /// # Safety /// -/// The [`OFFSET`] constant must be the offset of a field in `Self` of type [`Work<T, ID>`]. The -/// methods on this trait must have exactly the behavior that the definitions given below have. +/// The methods [`raw_get_work`] and [`work_container_of`] must return valid pointers and must be +/// true inverses of each other; that is, they must satisfy the following invariants: +/// - `work_container_of(raw_get_work(ptr)) == ptr` for any `ptr: *mut Self`. +/// - `raw_get_work(work_container_of(ptr)) == ptr` for any `ptr: *mut Work<T, ID>`. /// /// [`impl_has_work!`]: crate::impl_has_work -/// [`OFFSET`]: HasWork::OFFSET +/// [`raw_get_work`]: HasWork::raw_get_work +/// [`work_container_of`]: HasWork::work_container_of pub unsafe trait HasWork<T, const ID: u64 = 0> { - /// The offset of the [`Work<T, ID>`] field. - const OFFSET: usize; - - /// Returns the offset of the [`Work<T, ID>`] field. - /// - /// This method exists because the [`OFFSET`] constant cannot be accessed if the type is not - /// [`Sized`]. - /// - /// [`OFFSET`]: HasWork::OFFSET - #[inline] - fn get_work_offset(&self) -> usize { - Self::OFFSET - } - /// Returns a pointer to the [`Work<T, ID>`] field. /// /// # Safety /// /// The provided pointer must point at a valid struct of type `Self`. - #[inline] - unsafe fn raw_get_work(ptr: *mut Self) -> *mut Work<T, ID> { - // SAFETY: The caller promises that the pointer is valid. - unsafe { (ptr as *mut u8).add(Self::OFFSET) as *mut Work<T, ID> } - } + unsafe fn raw_get_work(ptr: *mut Self) -> *mut Work<T, ID>; /// Returns a pointer to the struct containing the [`Work<T, ID>`] field. /// /// # Safety /// /// The pointer must point at a [`Work<T, ID>`] field in a struct of type `Self`. - #[inline] - unsafe fn work_container_of(ptr: *mut Work<T, ID>) -> *mut Self - where - Self: Sized, - { - // SAFETY: The caller promises that the pointer points at a field of the right type in the - // right kind of struct. - unsafe { (ptr as *mut u8).sub(Self::OFFSET) as *mut Self } - } + unsafe fn work_container_of(ptr: *mut Work<T, ID>) -> *mut Self; } /// Used to safely implement the [`HasWork<T, ID>`] trait. @@ -501,8 +613,6 @@ macro_rules! impl_has_work { // SAFETY: The implementation of `raw_get_work` only compiles if the field has the right // type. unsafe impl$(<$($generics)+>)? $crate::workqueue::HasWork<$work_type $(, $id)?> for $self { - const OFFSET: usize = ::core::mem::offset_of!(Self, $field) as usize; - #[inline] unsafe fn raw_get_work(ptr: *mut Self) -> *mut $crate::workqueue::Work<$work_type $(, $id)?> { // SAFETY: The caller promises that the pointer is not dangling. @@ -510,6 +620,15 @@ macro_rules! impl_has_work { ::core::ptr::addr_of_mut!((*ptr).$field) } } + + #[inline] + unsafe fn work_container_of( + ptr: *mut $crate::workqueue::Work<$work_type $(, $id)?>, + ) -> *mut Self { + // SAFETY: The caller promises that the pointer points at a field of the right type + // in the right kind of struct. + unsafe { $crate::container_of!(ptr, Self, $field) } + } } )*}; } @@ -519,6 +638,178 @@ impl_has_work! { impl{T} HasWork<Self> for ClosureWork<T> { self.work } } +/// Links for a delayed work item. +/// +/// This struct contains a function pointer to the [`run`] function from the [`WorkItemPointer`] +/// trait, and defines the linked list pointers necessary to enqueue a work item in a workqueue in +/// a delayed manner. +/// +/// Wraps the kernel's C `struct delayed_work`. +/// +/// This is a helper type used to associate a `delayed_work` with the [`WorkItem`] that uses it. +/// +/// [`run`]: WorkItemPointer::run +#[pin_data] +#[repr(transparent)] +pub struct DelayedWork<T: ?Sized, const ID: u64 = 0> { + #[pin] + dwork: Opaque<bindings::delayed_work>, + _inner: PhantomData<T>, +} + +// SAFETY: Kernel work items are usable from any thread. +// +// We do not need to constrain `T` since the work item does not actually contain a `T`. +unsafe impl<T: ?Sized, const ID: u64> Send for DelayedWork<T, ID> {} +// SAFETY: Kernel work items are usable from any thread. +// +// We do not need to constrain `T` since the work item does not actually contain a `T`. +unsafe impl<T: ?Sized, const ID: u64> Sync for DelayedWork<T, ID> {} + +impl<T: ?Sized, const ID: u64> DelayedWork<T, ID> { + /// Creates a new instance of [`DelayedWork`]. + #[inline] + pub fn new( + work_name: &'static CStr, + work_key: Pin<&'static LockClassKey>, + timer_name: &'static CStr, + timer_key: Pin<&'static LockClassKey>, + ) -> impl PinInit<Self> + where + T: WorkItem<ID>, + { + pin_init!(Self { + dwork <- Opaque::ffi_init(|slot: *mut bindings::delayed_work| { + // SAFETY: The `WorkItemPointer` implementation promises that `run` can be used as + // the work item function. + unsafe { + bindings::init_work_with_key( + core::ptr::addr_of_mut!((*slot).work), + Some(T::Pointer::run), + false, + work_name.as_char_ptr(), + work_key.as_ptr(), + ) + } + + // SAFETY: The `delayed_work_timer_fn` function pointer can be used here because + // the timer is embedded in a `struct delayed_work`, and only ever scheduled via + // the core workqueue code, and configured to run in irqsafe context. + unsafe { + bindings::timer_init_key( + core::ptr::addr_of_mut!((*slot).timer), + Some(bindings::delayed_work_timer_fn), + bindings::TIMER_IRQSAFE, + timer_name.as_char_ptr(), + timer_key.as_ptr(), + ) + } + }), + _inner: PhantomData, + }) + } + + /// Get a pointer to the inner `delayed_work`. + /// + /// # Safety + /// + /// The provided pointer must not be dangling and must be properly aligned. (But the memory + /// need not be initialized.) + #[inline] + pub unsafe fn raw_as_work(ptr: *const Self) -> *mut Work<T, ID> { + // SAFETY: The caller promises that the pointer is aligned and not dangling. + let dw: *mut bindings::delayed_work = + unsafe { Opaque::cast_into(core::ptr::addr_of!((*ptr).dwork)) }; + // SAFETY: The caller promises that the pointer is aligned and not dangling. + let wrk: *mut bindings::work_struct = unsafe { core::ptr::addr_of_mut!((*dw).work) }; + // CAST: Work and work_struct have compatible layouts. + wrk.cast() + } +} + +/// Declares that a type contains a [`DelayedWork<T, ID>`]. +/// +/// # Safety +/// +/// The `HasWork<T, ID>` implementation must return a `work_struct` that is stored in the `work` +/// field of a `delayed_work` with the same access rules as the `work_struct`. +pub unsafe trait HasDelayedWork<T, const ID: u64 = 0>: HasWork<T, ID> {} + +/// Used to safely implement the [`HasDelayedWork<T, ID>`] trait. +/// +/// This macro also implements the [`HasWork`] trait, so you do not need to use [`impl_has_work!`] +/// when using this macro. +/// +/// # Examples +/// +/// ``` +/// use kernel::sync::Arc; +/// use kernel::workqueue::{self, impl_has_delayed_work, DelayedWork}; +/// +/// struct MyStruct<'a, T, const N: usize> { +/// work_field: DelayedWork<MyStruct<'a, T, N>, 17>, +/// f: fn(&'a [T; N]), +/// } +/// +/// impl_has_delayed_work! { +/// impl{'a, T, const N: usize} HasDelayedWork<MyStruct<'a, T, N>, 17> +/// for MyStruct<'a, T, N> { self.work_field } +/// } +/// ``` +#[macro_export] +macro_rules! impl_has_delayed_work { + ($(impl$({$($generics:tt)*})? + HasDelayedWork<$work_type:ty $(, $id:tt)?> + for $self:ty + { self.$field:ident } + )*) => {$( + // SAFETY: The implementation of `raw_get_work` only compiles if the field has the right + // type. + unsafe impl$(<$($generics)+>)? + $crate::workqueue::HasDelayedWork<$work_type $(, $id)?> for $self {} + + // SAFETY: The implementation of `raw_get_work` only compiles if the field has the right + // type. + unsafe impl$(<$($generics)+>)? $crate::workqueue::HasWork<$work_type $(, $id)?> for $self { + #[inline] + unsafe fn raw_get_work( + ptr: *mut Self + ) -> *mut $crate::workqueue::Work<$work_type $(, $id)?> { + // SAFETY: The caller promises that the pointer is not dangling. + let ptr: *mut $crate::workqueue::DelayedWork<$work_type $(, $id)?> = unsafe { + ::core::ptr::addr_of_mut!((*ptr).$field) + }; + + // SAFETY: The caller promises that the pointer is not dangling. + unsafe { $crate::workqueue::DelayedWork::raw_as_work(ptr) } + } + + #[inline] + unsafe fn work_container_of( + ptr: *mut $crate::workqueue::Work<$work_type $(, $id)?>, + ) -> *mut Self { + // SAFETY: The caller promises that the pointer points at a field of the right type + // in the right kind of struct. + let ptr = unsafe { $crate::workqueue::Work::raw_get(ptr) }; + + // SAFETY: The caller promises that the pointer points at a field of the right type + // in the right kind of struct. + let delayed_work = unsafe { + $crate::container_of!(ptr, $crate::bindings::delayed_work, work) + }; + + let delayed_work: *mut $crate::workqueue::DelayedWork<$work_type $(, $id)?> = + delayed_work.cast(); + + // SAFETY: The caller promises that the pointer points at a field of the right type + // in the right kind of struct. + unsafe { $crate::container_of!(delayed_work, Self, $field) } + } + } + )*}; +} +pub use impl_has_delayed_work; + // SAFETY: The `__enqueue` implementation in RawWorkItem uses a `work_struct` initialized with the // `run` method of this trait as the function pointer because: // - `__enqueue` gets the `work_struct` from the `Work` field, using `T::raw_get_work`. @@ -535,7 +826,7 @@ where { unsafe extern "C" fn run(ptr: *mut bindings::work_struct) { // The `__enqueue` method always uses a `work_struct` stored in a `Work<T, ID>`. - let ptr = ptr as *mut Work<T, ID>; + let ptr = ptr.cast::<Work<T, ID>>(); // SAFETY: This computes the pointer that `__enqueue` got from `Arc::into_raw`. let ptr = unsafe { T::work_container_of(ptr) }; // SAFETY: This pointer comes from `Arc::into_raw` and we've been given back ownership. @@ -580,6 +871,16 @@ where } } +// SAFETY: By the safety requirements of `HasDelayedWork`, the `work_struct` returned by methods in +// `HasWork` provides a `work_struct` that is the `work` field of a `delayed_work`, and the rest of +// the `delayed_work` has the same access rules as its `work` field. +unsafe impl<T, const ID: u64> RawDelayedWorkItem<ID> for Arc<T> +where + T: WorkItem<ID, Pointer = Self>, + T: HasDelayedWork<T, ID>, +{ +} + // SAFETY: TODO. unsafe impl<T, const ID: u64> WorkItemPointer<ID> for Pin<KBox<T>> where @@ -588,7 +889,7 @@ where { unsafe extern "C" fn run(ptr: *mut bindings::work_struct) { // The `__enqueue` method always uses a `work_struct` stored in a `Work<T, ID>`. - let ptr = ptr as *mut Work<T, ID>; + let ptr = ptr.cast::<Work<T, ID>>(); // SAFETY: This computes the pointer that `__enqueue` got from `Arc::into_raw`. let ptr = unsafe { T::work_container_of(ptr) }; // SAFETY: This pointer comes from `Arc::into_raw` and we've been given back ownership. @@ -630,6 +931,16 @@ where } } +// SAFETY: By the safety requirements of `HasDelayedWork`, the `work_struct` returned by methods in +// `HasWork` provides a `work_struct` that is the `work` field of a `delayed_work`, and the rest of +// the `delayed_work` has the same access rules as its `work` field. +unsafe impl<T, const ID: u64> RawDelayedWorkItem<ID> for Pin<KBox<T>> +where + T: WorkItem<ID, Pointer = Self>, + T: HasDelayedWork<T, ID>, +{ +} + /// Returns the system work queue (`system_wq`). /// /// It is the one used by `schedule[_delayed]_work[_on]()`. Multi-CPU multi-threaded. There are @@ -700,3 +1011,21 @@ pub fn system_freezable_power_efficient() -> &'static Queue { // SAFETY: `system_freezable_power_efficient_wq` is a C global, always available. unsafe { Queue::from_raw(bindings::system_freezable_power_efficient_wq) } } + +/// Returns the system bottom halves work queue (`system_bh_wq`). +/// +/// It is similar to the one returned by [`system`] but for work items which +/// need to run from a softirq context. +pub fn system_bh() -> &'static Queue { + // SAFETY: `system_bh_wq` is a C global, always available. + unsafe { Queue::from_raw(bindings::system_bh_wq) } +} + +/// Returns the system bottom halves high-priority work queue (`system_bh_highpri_wq`). +/// +/// It is similar to the one returned by [`system_bh`] but for work items which +/// require higher scheduling priority. +pub fn system_bh_highpri() -> &'static Queue { + // SAFETY: `system_bh_highpri_wq` is a C global, always available. + unsafe { Queue::from_raw(bindings::system_bh_highpri_wq) } +} diff --git a/rust/kernel/xarray.rs b/rust/kernel/xarray.rs new file mode 100644 index 000000000000..a49d6db28845 --- /dev/null +++ b/rust/kernel/xarray.rs @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! XArray abstraction. +//! +//! C header: [`include/linux/xarray.h`](srctree/include/linux/xarray.h) + +use crate::{ + alloc, bindings, build_assert, + error::{Error, Result}, + ffi::c_void, + types::{ForeignOwnable, NotThreadSafe, Opaque}, +}; +use core::{iter, marker::PhantomData, pin::Pin, ptr::NonNull}; +use pin_init::{pin_data, pin_init, pinned_drop, PinInit}; + +/// An array which efficiently maps sparse integer indices to owned objects. +/// +/// This is similar to a [`crate::alloc::kvec::Vec<Option<T>>`], but more efficient when there are +/// holes in the index space, and can be efficiently grown. +/// +/// # Invariants +/// +/// `self.xa` is always an initialized and valid [`bindings::xarray`] whose entries are either +/// `XA_ZERO_ENTRY` or came from `T::into_foreign`. +/// +/// # Examples +/// +/// ```rust +/// use kernel::alloc::KBox; +/// use kernel::xarray::{AllocKind, XArray}; +/// +/// let xa = KBox::pin_init(XArray::new(AllocKind::Alloc1), GFP_KERNEL)?; +/// +/// let dead = KBox::new(0xdead, GFP_KERNEL)?; +/// let beef = KBox::new(0xbeef, GFP_KERNEL)?; +/// +/// let mut guard = xa.lock(); +/// +/// assert_eq!(guard.get(0), None); +/// +/// assert_eq!(guard.store(0, dead, GFP_KERNEL)?.as_deref(), None); +/// assert_eq!(guard.get(0).copied(), Some(0xdead)); +/// +/// *guard.get_mut(0).unwrap() = 0xffff; +/// assert_eq!(guard.get(0).copied(), Some(0xffff)); +/// +/// assert_eq!(guard.store(0, beef, GFP_KERNEL)?.as_deref().copied(), Some(0xffff)); +/// assert_eq!(guard.get(0).copied(), Some(0xbeef)); +/// +/// guard.remove(0); +/// assert_eq!(guard.get(0), None); +/// +/// # Ok::<(), Error>(()) +/// ``` +#[pin_data(PinnedDrop)] +pub struct XArray<T: ForeignOwnable> { + #[pin] + xa: Opaque<bindings::xarray>, + _p: PhantomData<T>, +} + +#[pinned_drop] +impl<T: ForeignOwnable> PinnedDrop for XArray<T> { + fn drop(self: Pin<&mut Self>) { + self.iter().for_each(|ptr| { + let ptr = ptr.as_ptr(); + // SAFETY: `ptr` came from `T::into_foreign`. + // + // INVARIANT: we own the only reference to the array which is being dropped so the + // broken invariant is not observable on function exit. + drop(unsafe { T::from_foreign(ptr) }) + }); + + // SAFETY: `self.xa` is always valid by the type invariant. + unsafe { bindings::xa_destroy(self.xa.get()) }; + } +} + +/// Flags passed to [`XArray::new`] to configure the array's allocation tracking behavior. +pub enum AllocKind { + /// Consider the first element to be at index 0. + Alloc, + /// Consider the first element to be at index 1. + Alloc1, +} + +impl<T: ForeignOwnable> XArray<T> { + /// Creates a new initializer for this type. + pub fn new(kind: AllocKind) -> impl PinInit<Self> { + let flags = match kind { + AllocKind::Alloc => bindings::XA_FLAGS_ALLOC, + AllocKind::Alloc1 => bindings::XA_FLAGS_ALLOC1, + }; + pin_init!(Self { + // SAFETY: `xa` is valid while the closure is called. + // + // INVARIANT: `xa` is initialized here to an empty, valid [`bindings::xarray`]. + xa <- Opaque::ffi_init(|xa| unsafe { + bindings::xa_init_flags(xa, flags) + }), + _p: PhantomData, + }) + } + + fn iter(&self) -> impl Iterator<Item = NonNull<c_void>> + '_ { + let mut index = 0; + + // SAFETY: `self.xa` is always valid by the type invariant. + iter::once(unsafe { + bindings::xa_find(self.xa.get(), &mut index, usize::MAX, bindings::XA_PRESENT) + }) + .chain(iter::from_fn(move || { + // SAFETY: `self.xa` is always valid by the type invariant. + Some(unsafe { + bindings::xa_find_after(self.xa.get(), &mut index, usize::MAX, bindings::XA_PRESENT) + }) + })) + .map_while(|ptr| NonNull::new(ptr.cast())) + } + + /// Attempts to lock the [`XArray`] for exclusive access. + pub fn try_lock(&self) -> Option<Guard<'_, T>> { + // SAFETY: `self.xa` is always valid by the type invariant. + if (unsafe { bindings::xa_trylock(self.xa.get()) } != 0) { + Some(Guard { + xa: self, + _not_send: NotThreadSafe, + }) + } else { + None + } + } + + /// Locks the [`XArray`] for exclusive access. + pub fn lock(&self) -> Guard<'_, T> { + // SAFETY: `self.xa` is always valid by the type invariant. + unsafe { bindings::xa_lock(self.xa.get()) }; + + Guard { + xa: self, + _not_send: NotThreadSafe, + } + } +} + +/// A lock guard. +/// +/// The lock is unlocked when the guard goes out of scope. +#[must_use = "the lock unlocks immediately when the guard is unused"] +pub struct Guard<'a, T: ForeignOwnable> { + xa: &'a XArray<T>, + _not_send: NotThreadSafe, +} + +impl<T: ForeignOwnable> Drop for Guard<'_, T> { + fn drop(&mut self) { + // SAFETY: + // - `self.xa.xa` is always valid by the type invariant. + // - The caller holds the lock, so it is safe to unlock it. + unsafe { bindings::xa_unlock(self.xa.xa.get()) }; + } +} + +/// The error returned by [`store`](Guard::store). +/// +/// Contains the underlying error and the value that was not stored. +pub struct StoreError<T> { + /// The error that occurred. + pub error: Error, + /// The value that was not stored. + pub value: T, +} + +impl<T> From<StoreError<T>> for Error { + fn from(value: StoreError<T>) -> Self { + value.error + } +} + +impl<'a, T: ForeignOwnable> Guard<'a, T> { + fn load<F, U>(&self, index: usize, f: F) -> Option<U> + where + F: FnOnce(NonNull<c_void>) -> U, + { + // SAFETY: `self.xa.xa` is always valid by the type invariant. + let ptr = unsafe { bindings::xa_load(self.xa.xa.get(), index) }; + let ptr = NonNull::new(ptr.cast())?; + Some(f(ptr)) + } + + /// Provides a reference to the element at the given index. + pub fn get(&self, index: usize) -> Option<T::Borrowed<'_>> { + self.load(index, |ptr| { + // SAFETY: `ptr` came from `T::into_foreign`. + unsafe { T::borrow(ptr.as_ptr()) } + }) + } + + /// Provides a mutable reference to the element at the given index. + pub fn get_mut(&mut self, index: usize) -> Option<T::BorrowedMut<'_>> { + self.load(index, |ptr| { + // SAFETY: `ptr` came from `T::into_foreign`. + unsafe { T::borrow_mut(ptr.as_ptr()) } + }) + } + + /// Removes and returns the element at the given index. + pub fn remove(&mut self, index: usize) -> Option<T> { + // SAFETY: + // - `self.xa.xa` is always valid by the type invariant. + // - The caller holds the lock. + let ptr = unsafe { bindings::__xa_erase(self.xa.xa.get(), index) }.cast(); + // SAFETY: + // - `ptr` is either NULL or came from `T::into_foreign`. + // - `&mut self` guarantees that the lifetimes of [`T::Borrowed`] and [`T::BorrowedMut`] + // borrowed from `self` have ended. + unsafe { T::try_from_foreign(ptr) } + } + + /// Stores an element at the given index. + /// + /// May drop the lock if needed to allocate memory, and then reacquire it afterwards. + /// + /// On success, returns the element which was previously at the given index. + /// + /// On failure, returns the element which was attempted to be stored. + pub fn store( + &mut self, + index: usize, + value: T, + gfp: alloc::Flags, + ) -> Result<Option<T>, StoreError<T>> { + build_assert!( + T::FOREIGN_ALIGN >= 4, + "pointers stored in XArray must be 4-byte aligned" + ); + let new = value.into_foreign(); + + let old = { + let new = new.cast(); + // SAFETY: + // - `self.xa.xa` is always valid by the type invariant. + // - The caller holds the lock. + // + // INVARIANT: `new` came from `T::into_foreign`. + unsafe { bindings::__xa_store(self.xa.xa.get(), index, new, gfp.as_raw()) } + }; + + // SAFETY: `__xa_store` returns the old entry at this index on success or `xa_err` if an + // error happened. + let errno = unsafe { bindings::xa_err(old) }; + if errno != 0 { + // SAFETY: `new` came from `T::into_foreign` and `__xa_store` does not take + // ownership of the value on error. + let value = unsafe { T::from_foreign(new) }; + Err(StoreError { + value, + error: Error::from_errno(errno), + }) + } else { + let old = old.cast(); + // SAFETY: `ptr` is either NULL or came from `T::into_foreign`. + // + // NB: `XA_ZERO_ENTRY` is never returned by functions belonging to the Normal XArray + // API; such entries present as `NULL`. + Ok(unsafe { T::try_from_foreign(old) }) + } + } +} + +// SAFETY: `XArray<T>` has no shared mutable state so it is `Send` iff `T` is `Send`. +unsafe impl<T: ForeignOwnable + Send> Send for XArray<T> {} + +// SAFETY: `XArray<T>` serialises the interior mutability it provides so it is `Sync` iff `T` is +// `Send`. +unsafe impl<T: ForeignOwnable + Send> Sync for XArray<T> {} diff --git a/rust/macros/export.rs b/rust/macros/export.rs new file mode 100644 index 000000000000..a08f6337d5c8 --- /dev/null +++ b/rust/macros/export.rs @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: GPL-2.0 + +use crate::helpers::function_name; +use proc_macro::TokenStream; + +/// Please see [`crate::export`] for documentation. +pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream { + let Some(name) = function_name(ts.clone()) else { + return "::core::compile_error!(\"The #[export] attribute must be used on a function.\");" + .parse::<TokenStream>() + .unwrap(); + }; + + // This verifies that the function has the same signature as the declaration generated by + // bindgen. It makes use of the fact that all branches of an if/else must have the same type. + let signature_check = quote!( + const _: () = { + if true { + ::kernel::bindings::#name + } else { + #name + }; + }; + ); + + let no_mangle = quote!(#[no_mangle]); + + TokenStream::from_iter([signature_check, no_mangle, ts]) +} diff --git a/rust/macros/helpers.rs b/rust/macros/helpers.rs index 563dcd2b7ace..e2602be402c1 100644 --- a/rust/macros/helpers.rs +++ b/rust/macros/helpers.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{token_stream, Group, TokenStream, TokenTree}; +use proc_macro::{token_stream, Group, Ident, TokenStream, TokenTree}; pub(crate) fn try_ident(it: &mut token_stream::IntoIter) -> Option<String> { if let Some(TokenTree::Ident(ident)) = it.next() { @@ -70,148 +70,36 @@ pub(crate) fn expect_end(it: &mut token_stream::IntoIter) { } } -/// Parsed generics. -/// -/// See the field documentation for an explanation what each of the fields represents. -/// -/// # Examples -/// -/// ```rust,ignore -/// # let input = todo!(); -/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input); -/// quote! { -/// struct Foo<$($decl_generics)*> { -/// // ... -/// } -/// -/// impl<$impl_generics> Foo<$ty_generics> { -/// fn foo() { -/// // ... -/// } -/// } -/// } -/// ``` -pub(crate) struct Generics { - /// The generics with bounds and default values (e.g. `T: Clone, const N: usize = 0`). - /// - /// Use this on type definitions e.g. `struct Foo<$decl_generics> ...` (or `union`/`enum`). - pub(crate) decl_generics: Vec<TokenTree>, - /// The generics with bounds (e.g. `T: Clone, const N: usize`). - /// - /// Use this on `impl` blocks e.g. `impl<$impl_generics> Trait for ...`. - pub(crate) impl_generics: Vec<TokenTree>, - /// The generics without bounds and without default values (e.g. `T, N`). - /// - /// Use this when you use the type that is declared with these generics e.g. - /// `Foo<$ty_generics>`. - pub(crate) ty_generics: Vec<TokenTree>, -} - -/// Parses the given `TokenStream` into `Generics` and the rest. -/// -/// The generics are not present in the rest, but a where clause might remain. -pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) { - // The generics with bounds and default values. - let mut decl_generics = vec![]; - // `impl_generics`, the declared generics with their bounds. - let mut impl_generics = vec![]; - // Only the names of the generics, without any bounds. - let mut ty_generics = vec![]; - // Tokens not related to the generics e.g. the `where` token and definition. - let mut rest = vec![]; - // The current level of `<`. - let mut nesting = 0; - let mut toks = input.into_iter(); - // If we are at the beginning of a generic parameter. - let mut at_start = true; - let mut skip_until_comma = false; - while let Some(tt) = toks.next() { - if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') { - // Found the end of the generics. - break; - } else if nesting >= 1 { - decl_generics.push(tt.clone()); - } - match tt.clone() { - TokenTree::Punct(p) if p.as_char() == '<' => { - if nesting >= 1 && !skip_until_comma { - // This is inside of the generics and part of some bound. - impl_generics.push(tt); - } - nesting += 1; - } - TokenTree::Punct(p) if p.as_char() == '>' => { - // This is a parsing error, so we just end it here. - if nesting == 0 { - break; - } else { - nesting -= 1; - if nesting >= 1 && !skip_until_comma { - // We are still inside of the generics and part of some bound. - impl_generics.push(tt); - } - } - } - TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => { - if nesting == 1 { - impl_generics.push(tt.clone()); - impl_generics.push(tt); - skip_until_comma = false; +/// Given a function declaration, finds the name of the function. +pub(crate) fn function_name(input: TokenStream) -> Option<Ident> { + let mut input = input.into_iter(); + while let Some(token) = input.next() { + match token { + TokenTree::Ident(i) if i.to_string() == "fn" => { + if let Some(TokenTree::Ident(i)) = input.next() { + return Some(i); } + return None; } - _ if !skip_until_comma => { - match nesting { - // If we haven't entered the generics yet, we still want to keep these tokens. - 0 => rest.push(tt), - 1 => { - // Here depending on the token, it might be a generic variable name. - match tt.clone() { - TokenTree::Ident(i) if at_start && i.to_string() == "const" => { - let Some(name) = toks.next() else { - // Parsing error. - break; - }; - impl_generics.push(tt); - impl_generics.push(name.clone()); - ty_generics.push(name.clone()); - decl_generics.push(name); - at_start = false; - } - TokenTree::Ident(_) if at_start => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - at_start = false; - } - TokenTree::Punct(p) if p.as_char() == ',' => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - at_start = true; - } - // Lifetimes begin with `'`. - TokenTree::Punct(p) if p.as_char() == '\'' && at_start => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - } - // Generics can have default values, we skip these. - TokenTree::Punct(p) if p.as_char() == '=' => { - skip_until_comma = true; - } - _ => impl_generics.push(tt), - } - } - _ => impl_generics.push(tt), - } - } - _ => {} + _ => continue, } } - rest.extend(toks); - ( - Generics { - impl_generics, - decl_generics, - ty_generics, - }, - rest, - ) + None +} + +pub(crate) fn file() -> String { + #[cfg(not(CONFIG_RUSTC_HAS_SPAN_FILE))] + { + proc_macro::Span::call_site() + .source_file() + .path() + .to_string_lossy() + .into_owned() + } + + #[cfg(CONFIG_RUSTC_HAS_SPAN_FILE)] + #[allow(clippy::incompatible_msrv)] + { + proc_macro::Span::call_site().file() + } } diff --git a/rust/macros/kunit.rs b/rust/macros/kunit.rs new file mode 100644 index 000000000000..81d18149a0cc --- /dev/null +++ b/rust/macros/kunit.rs @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Procedural macro to run KUnit tests using a user-space like syntax. +//! +//! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> + +use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; +use std::fmt::Write; + +pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { + let attr = attr.to_string(); + + if attr.is_empty() { + panic!("Missing test name in `#[kunit_tests(test_name)]` macro") + } + + if attr.len() > 255 { + panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") + } + + let mut tokens: Vec<_> = ts.into_iter().collect(); + + // Scan for the `mod` keyword. + tokens + .iter() + .find_map(|token| match token { + TokenTree::Ident(ident) => match ident.to_string().as_str() { + "mod" => Some(true), + _ => None, + }, + _ => None, + }) + .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); + + // Retrieve the main body. The main body should be the last token tree. + let body = match tokens.pop() { + Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, + _ => panic!("Cannot locate main body of module"), + }; + + // Get the functions set as tests. Search for `[test]` -> `fn`. + let mut body_it = body.stream().into_iter(); + let mut tests = Vec::new(); + while let Some(token) = body_it.next() { + match token { + TokenTree::Group(ident) if ident.to_string() == "[test]" => match body_it.next() { + Some(TokenTree::Ident(ident)) if ident.to_string() == "fn" => { + let test_name = match body_it.next() { + Some(TokenTree::Ident(ident)) => ident.to_string(), + _ => continue, + }; + tests.push(test_name); + } + _ => continue, + }, + _ => (), + } + } + + // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. + let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); + tokens.insert( + 0, + TokenTree::Group(Group::new(Delimiter::None, config_kunit)), + ); + + // Generate the test KUnit test suite and a test case for each `#[test]`. + // The code generated for the following test module: + // + // ``` + // #[kunit_tests(kunit_test_suit_name)] + // mod tests { + // #[test] + // fn foo() { + // assert_eq!(1, 1); + // } + // + // #[test] + // fn bar() { + // assert_eq!(2, 2); + // } + // } + // ``` + // + // Looks like: + // + // ``` + // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); } + // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); } + // + // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [ + // ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo), + // ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar), + // ::kernel::kunit::kunit_case_null(), + // ]; + // + // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); + // ``` + let mut kunit_macros = "".to_owned(); + let mut test_cases = "".to_owned(); + let mut assert_macros = "".to_owned(); + let path = crate::helpers::file(); + for test in &tests { + let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); + // An extra `use` is used here to reduce the length of the message. + let kunit_wrapper = format!( + "unsafe extern \"C\" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) {{ use ::kernel::kunit::is_test_result_ok; assert!(is_test_result_ok({test}())); }}", + ); + writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); + writeln!( + test_cases, + " ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name})," + ) + .unwrap(); + writeln!( + assert_macros, + r#" +/// Overrides the usual [`assert!`] macro with one that calls KUnit instead. +#[allow(unused)] +macro_rules! assert {{ + ($cond:expr $(,)?) => {{{{ + kernel::kunit_assert!("{test}", "{path}", 0, $cond); + }}}} +}} + +/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. +#[allow(unused)] +macro_rules! assert_eq {{ + ($left:expr, $right:expr $(,)?) => {{{{ + kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right); + }}}} +}} + "# + ) + .unwrap(); + } + + writeln!(kunit_macros).unwrap(); + writeln!( + kunit_macros, + "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", + tests.len() + 1 + ) + .unwrap(); + + writeln!( + kunit_macros, + "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" + ) + .unwrap(); + + // Remove the `#[test]` macros. + // We do this at a token level, in order to preserve span information. + let mut new_body = vec![]; + let mut body_it = body.stream().into_iter(); + + while let Some(token) = body_it.next() { + match token { + TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { + Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), + Some(next) => { + new_body.extend([token, next]); + } + _ => { + new_body.push(token); + } + }, + _ => { + new_body.push(token); + } + } + } + + let mut final_body = TokenStream::new(); + final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); + final_body.extend(new_body); + final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); + + tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); + + tokens.into_iter().collect() +} diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index 4ab94e44adfe..fa847cf3a9b5 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -6,16 +6,20 @@ // and thus add a dependency on `include/config/RUSTC_VERSION_TEXT`, which is // touched by Kconfig when the version string from the compiler changes. +// Stable since Rust 1.88.0 under a different name, `proc_macro_span_file`, +// which was added in Rust 1.88.0. This is why `cfg_attr` is used here, i.e. +// to avoid depending on the full `proc_macro_span` on Rust >= 1.88.0. +#![cfg_attr(not(CONFIG_RUSTC_HAS_SPAN_FILE), feature(proc_macro_span))] + #[macro_use] mod quote; mod concat_idents; +mod export; mod helpers; +mod kunit; mod module; mod paste; -mod pin_data; -mod pinned_drop; mod vtable; -mod zeroable; use proc_macro::TokenStream; @@ -36,7 +40,7 @@ use proc_macro::TokenStream; /// module!{ /// type: MyModule, /// name: "my_kernel_module", -/// author: "Rust for Linux Contributors", +/// authors: ["Rust for Linux Contributors"], /// description: "My very own kernel module!", /// license: "GPL", /// alias: ["alternate_module_name"], @@ -69,7 +73,7 @@ use proc_macro::TokenStream; /// module!{ /// type: MyDeviceDriverModule, /// name: "my_device_driver_module", -/// author: "Rust for Linux Contributors", +/// authors: ["Rust for Linux Contributors"], /// description: "My device driver requires firmware", /// license: "GPL", /// firmware: ["my_device_firmware1.bin", "my_device_firmware2.bin"], @@ -88,7 +92,7 @@ use proc_macro::TokenStream; /// # Supported argument types /// - `type`: type which implements the [`Module`] trait (required). /// - `name`: ASCII string literal of the name of the kernel module (required). -/// - `author`: string literal of the author of the kernel module. +/// - `authors`: array of ASCII string literals of the authors of the kernel module. /// - `description`: string literal of the description of the kernel module. /// - `license`: ASCII string literal of the license of the kernel module (required). /// - `alias`: array of ASCII string literals of the alias names of the kernel module. @@ -123,12 +127,12 @@ pub fn module(ts: TokenStream) -> TokenStream { /// used on the Rust side, it should not be possible to call the default /// implementation. This is done to ensure that we call the vtable methods /// through the C vtable, and not through the Rust vtable. Therefore, the -/// default implementation should call `kernel::build_error`, which prevents +/// default implementation should call `build_error!`, which prevents /// calls to this function at compile time: /// /// ```compile_fail /// # // Intentionally missing `use`s to simplify `rusttest`. -/// kernel::build_error(VTABLE_DEFAULT_ERROR) +/// build_error!(VTABLE_DEFAULT_ERROR) /// ``` /// /// Note that you might need to import [`kernel::error::VTABLE_DEFAULT_ERROR`]. @@ -145,11 +149,11 @@ pub fn module(ts: TokenStream) -> TokenStream { /// #[vtable] /// pub trait Operations: Send + Sync + Sized { /// fn foo(&self) -> Result<()> { -/// kernel::build_error(VTABLE_DEFAULT_ERROR) +/// build_error!(VTABLE_DEFAULT_ERROR) /// } /// /// fn bar(&self) -> Result<()> { -/// kernel::build_error(VTABLE_DEFAULT_ERROR) +/// build_error!(VTABLE_DEFAULT_ERROR) /// } /// } /// @@ -174,6 +178,29 @@ pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { vtable::vtable(attr, ts) } +/// Export a function so that C code can call it via a header file. +/// +/// Functions exported using this macro can be called from C code using the declaration in the +/// appropriate header file. It should only be used in cases where C calls the function through a +/// header file; cases where C calls into Rust via a function pointer in a vtable (such as +/// `file_operations`) should not use this macro. +/// +/// This macro has the following effect: +/// +/// * Disables name mangling for this function. +/// * Verifies at compile-time that the function signature matches the declaration in the header +/// file. +/// +/// You must declare the signature of the Rust function in a header file that is included by +/// `rust/bindings/bindings_helper.h`. +/// +/// This macro is *not* the same as the C macros `EXPORT_SYMBOL_*`. All Rust symbols are currently +/// automatically exported with `EXPORT_SYMBOL_GPL`. +#[proc_macro_attribute] +pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream { + export::export(attr, ts) +} + /// Concatenate two identifiers. /// /// This is useful in macros that need to declare or reference items with names @@ -232,106 +259,6 @@ pub fn concat_idents(ts: TokenStream) -> TokenStream { concat_idents::concat_idents(ts) } -/// Used to specify the pinning information of the fields of a struct. -/// -/// This is somewhat similar in purpose as -/// [pin-project-lite](https://crates.io/crates/pin-project-lite). -/// Place this macro on a struct definition and then `#[pin]` in front of the attributes of each -/// field you want to structurally pin. -/// -/// This macro enables the use of the [`pin_init!`] macro. When pin-initializing a `struct`, -/// then `#[pin]` directs the type of initializer that is required. -/// -/// If your `struct` implements `Drop`, then you need to add `PinnedDrop` as arguments to this -/// macro, and change your `Drop` implementation to `PinnedDrop` annotated with -/// `#[`[`macro@pinned_drop`]`]`, since dropping pinned values requires extra care. -/// -/// # Examples -/// -/// ``` -/// # #![feature(lint_reasons)] -/// # use kernel::prelude::*; -/// # use std::{sync::Mutex, process::Command}; -/// # use kernel::macros::pin_data; -/// #[pin_data] -/// struct DriverData { -/// #[pin] -/// queue: Mutex<KVec<Command>>, -/// buf: KBox<[u8; 1024 * 1024]>, -/// } -/// ``` -/// -/// ``` -/// # #![feature(lint_reasons)] -/// # use kernel::prelude::*; -/// # use std::{sync::Mutex, process::Command}; -/// # use core::pin::Pin; -/// # pub struct Info; -/// # mod bindings { -/// # pub unsafe fn destroy_info(_ptr: *mut super::Info) {} -/// # } -/// use kernel::macros::{pin_data, pinned_drop}; -/// -/// #[pin_data(PinnedDrop)] -/// struct DriverData { -/// #[pin] -/// queue: Mutex<KVec<Command>>, -/// buf: KBox<[u8; 1024 * 1024]>, -/// raw_info: *mut Info, -/// } -/// -/// #[pinned_drop] -/// impl PinnedDrop for DriverData { -/// fn drop(self: Pin<&mut Self>) { -/// unsafe { bindings::destroy_info(self.raw_info) }; -/// } -/// } -/// # fn main() {} -/// ``` -/// -/// [`pin_init!`]: ../kernel/macro.pin_init.html -// ^ cannot use direct link, since `kernel` is not a dependency of `macros`. -#[proc_macro_attribute] -pub fn pin_data(inner: TokenStream, item: TokenStream) -> TokenStream { - pin_data::pin_data(inner, item) -} - -/// Used to implement `PinnedDrop` safely. -/// -/// Only works on structs that are annotated via `#[`[`macro@pin_data`]`]`. -/// -/// # Examples -/// -/// ``` -/// # #![feature(lint_reasons)] -/// # use kernel::prelude::*; -/// # use macros::{pin_data, pinned_drop}; -/// # use std::{sync::Mutex, process::Command}; -/// # use core::pin::Pin; -/// # mod bindings { -/// # pub struct Info; -/// # pub unsafe fn destroy_info(_ptr: *mut Info) {} -/// # } -/// #[pin_data(PinnedDrop)] -/// struct DriverData { -/// #[pin] -/// queue: Mutex<KVec<Command>>, -/// buf: KBox<[u8; 1024 * 1024]>, -/// raw_info: *mut bindings::Info, -/// } -/// -/// #[pinned_drop] -/// impl PinnedDrop for DriverData { -/// fn drop(self: Pin<&mut Self>) { -/// unsafe { bindings::destroy_info(self.raw_info) }; -/// } -/// } -/// ``` -#[proc_macro_attribute] -pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { - pinned_drop::pinned_drop(args, input) -} - /// Paste identifiers together. /// /// Within the `paste!` macro, identifiers inside `[<` and `>]` are concatenated together to form a @@ -341,7 +268,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// literals (lifetimes and documentation strings are not supported). There is a difference in /// supported modifiers as well. /// -/// # Example +/// # Examples /// /// ``` /// # const binder_driver_return_protocol_BR_OK: u32 = 0; @@ -361,7 +288,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// # const binder_driver_return_protocol_BR_FAILED_REPLY: u32 = 14; /// macro_rules! pub_no_prefix { /// ($prefix:ident, $($newname:ident),+) => { -/// kernel::macros::paste! { +/// ::kernel::macros::paste! { /// $(pub(crate) const $newname: u32 = [<$prefix $newname>];)+ /// } /// }; @@ -418,7 +345,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// # const binder_driver_return_protocol_BR_FAILED_REPLY: u32 = 14; /// macro_rules! pub_no_prefix { /// ($prefix:ident, $($newname:ident),+) => { -/// kernel::macros::paste! { +/// ::kernel::macros::paste! { /// $(pub(crate) const fn [<$newname:lower:span>]() -> u32 { [<$prefix $newname:span>] })+ /// } /// }; @@ -453,7 +380,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// ``` /// macro_rules! create_numbered_fn { /// ($name:literal, $val:literal) => { -/// kernel::macros::paste! { +/// ::kernel::macros::paste! { /// fn [<some_ $name _fn $val>]() -> u32 { $val } /// } /// }; @@ -472,23 +399,29 @@ pub fn paste(input: TokenStream) -> TokenStream { tokens.into_iter().collect() } -/// Derives the [`Zeroable`] trait for the given struct. +/// Registers a KUnit test suite and its test cases using a user-space like syntax. /// -/// This can only be used for structs where every field implements the [`Zeroable`] trait. +/// This macro should be used on modules. If `CONFIG_KUNIT` (in `.config`) is `n`, the target module +/// is ignored. /// /// # Examples /// -/// ``` -/// use kernel::macros::Zeroable; +/// ```ignore +/// # use kernel::prelude::*; +/// #[kunit_tests(kunit_test_suit_name)] +/// mod tests { +/// #[test] +/// fn foo() { +/// assert_eq!(1, 1); +/// } /// -/// #[derive(Zeroable)] -/// pub struct DriverData { -/// id: i64, -/// buf_ptr: *mut u8, -/// len: usize, +/// #[test] +/// fn bar() { +/// assert_eq!(2, 2); +/// } /// } /// ``` -#[proc_macro_derive(Zeroable)] -pub fn derive_zeroable(input: TokenStream) -> TokenStream { - zeroable::derive(input) +#[proc_macro_attribute] +pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { + kunit::kunit_tests(attr, ts) } diff --git a/rust/macros/module.rs b/rust/macros/module.rs index 2587f41b0d39..5ee54a00c0b6 100644 --- a/rust/macros/module.rs +++ b/rust/macros/module.rs @@ -48,7 +48,7 @@ impl<'a> ModInfoBuilder<'a> { ) } else { // Loadable modules' modinfo strings go as-is. - format!("{field}={content}\0", field = field, content = content) + format!("{field}={content}\0") }; write!( @@ -56,8 +56,8 @@ impl<'a> ModInfoBuilder<'a> { " {cfg} #[doc(hidden)] - #[link_section = \".modinfo\"] - #[used] + #[cfg_attr(not(target_os = \"macos\"), link_section = \".modinfo\")] + #[used(compiler)] pub static __{module}_{counter}: [u8; {length}] = *{string}; ", cfg = if builtin { @@ -94,7 +94,7 @@ struct ModuleInfo { type_: String, license: String, name: String, - author: Option<String>, + authors: Option<Vec<String>>, description: Option<String>, alias: Option<Vec<String>>, firmware: Option<Vec<String>>, @@ -107,7 +107,7 @@ impl ModuleInfo { const EXPECTED_KEYS: &[&str] = &[ "type", "name", - "author", + "authors", "description", "license", "alias", @@ -124,10 +124,7 @@ impl ModuleInfo { }; if seen_keys.contains(&key) { - panic!( - "Duplicated key \"{}\". Keys can only be specified once.", - key - ); + panic!("Duplicated key \"{key}\". Keys can only be specified once."); } assert_eq!(expect_punct(it), ':'); @@ -135,15 +132,12 @@ impl ModuleInfo { match key.as_str() { "type" => info.type_ = expect_ident(it), "name" => info.name = expect_string_ascii(it), - "author" => info.author = Some(expect_string(it)), + "authors" => info.authors = Some(expect_string_array(it)), "description" => info.description = Some(expect_string(it)), "license" => info.license = expect_string_ascii(it), "alias" => info.alias = Some(expect_string_array(it)), "firmware" => info.firmware = Some(expect_string_array(it)), - _ => panic!( - "Unknown key \"{}\". Valid keys are: {:?}.", - key, EXPECTED_KEYS - ), + _ => panic!("Unknown key \"{key}\". Valid keys are: {EXPECTED_KEYS:?}."), } assert_eq!(expect_punct(it), ','); @@ -155,7 +149,7 @@ impl ModuleInfo { for key in REQUIRED_KEYS { if !seen_keys.iter().any(|e| e == key) { - panic!("Missing required key \"{}\".", key); + panic!("Missing required key \"{key}\"."); } } @@ -167,10 +161,7 @@ impl ModuleInfo { } if seen_keys != ordered_keys { - panic!( - "Keys are not ordered as expected. Order them like: {:?}.", - ordered_keys - ); + panic!("Keys are not ordered as expected. Order them like: {ordered_keys:?}."); } info @@ -182,9 +173,13 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { let info = ModuleInfo::parse(&mut it); - let mut modinfo = ModInfoBuilder::new(info.name.as_ref()); - if let Some(author) = info.author { - modinfo.emit("author", &author); + // Rust does not allow hyphens in identifiers, use underscore instead. + let ident = info.name.replace('-', "_"); + let mut modinfo = ModInfoBuilder::new(ident.as_ref()); + if let Some(authors) = info.authors { + for author in authors { + modinfo.emit("author", &author); + } } if let Some(description) = info.description { modinfo.emit("description", &description); @@ -216,23 +211,31 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { // SAFETY: `__this_module` is constructed by the kernel at load time and will not be // freed until the module is unloaded. #[cfg(MODULE)] - static THIS_MODULE: kernel::ThisModule = unsafe {{ + static THIS_MODULE: ::kernel::ThisModule = unsafe {{ extern \"C\" {{ - static __this_module: kernel::types::Opaque<kernel::bindings::module>; + static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>; }} - kernel::ThisModule::from_ptr(__this_module.get()) + ::kernel::ThisModule::from_ptr(__this_module.get()) }}; #[cfg(not(MODULE))] - static THIS_MODULE: kernel::ThisModule = unsafe {{ - kernel::ThisModule::from_ptr(core::ptr::null_mut()) + static THIS_MODULE: ::kernel::ThisModule = unsafe {{ + ::kernel::ThisModule::from_ptr(::core::ptr::null_mut()) }}; + /// The `LocalModule` type is the type of the module created by `module!`, + /// `module_pci_driver!`, `module_platform_driver!`, etc. + type LocalModule = {type_}; + + impl ::kernel::ModuleMetadata for {type_} {{ + const NAME: &'static ::kernel::str::CStr = ::kernel::c_str!(\"{name}\"); + }} + // Double nested modules, since then nobody can access the public items inside. mod __module_init {{ mod __module_init {{ use super::super::{type_}; - use kernel::init::PinInit; + use pin_init::PinInit; /// The \"Rust loadable module\" mark. // @@ -240,11 +243,11 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { // key or a new section. For the moment, keep it simple. #[cfg(MODULE)] #[doc(hidden)] - #[used] + #[used(compiler)] static __IS_RUST_MODULE: () = (); - static mut __MOD: core::mem::MaybeUninit<{type_}> = - core::mem::MaybeUninit::uninit(); + static mut __MOD: ::core::mem::MaybeUninit<{type_}> = + ::core::mem::MaybeUninit::uninit(); // Loadable modules need to export the `{{init,cleanup}}_module` identifiers. /// # Safety @@ -255,7 +258,7 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { #[doc(hidden)] #[no_mangle] #[link_section = \".init.text\"] - pub unsafe extern \"C\" fn init_module() -> kernel::ffi::c_int {{ + pub unsafe extern \"C\" fn init_module() -> ::kernel::ffi::c_int {{ // SAFETY: This function is inaccessible to the outside due to the double // module wrapping it. It is called exactly once by the C side via its // unique name. @@ -264,13 +267,14 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { #[cfg(MODULE)] #[doc(hidden)] - #[used] + #[used(compiler)] #[link_section = \".init.data\"] static __UNIQUE_ID___addressable_init_module: unsafe extern \"C\" fn() -> i32 = init_module; #[cfg(MODULE)] #[doc(hidden)] #[no_mangle] + #[link_section = \".exit.text\"] pub extern \"C\" fn cleanup_module() {{ // SAFETY: // - This function is inaccessible to the outside due to the double @@ -283,7 +287,7 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { #[cfg(MODULE)] #[doc(hidden)] - #[used] + #[used(compiler)] #[link_section = \".exit.data\"] static __UNIQUE_ID___addressable_cleanup_module: extern \"C\" fn() = cleanup_module; @@ -293,15 +297,16 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))] #[doc(hidden)] #[link_section = \"{initcall_section}\"] - #[used] - pub static __{name}_initcall: extern \"C\" fn() -> kernel::ffi::c_int = __{name}_init; + #[used(compiler)] + pub static __{ident}_initcall: extern \"C\" fn() -> + ::kernel::ffi::c_int = __{ident}_init; #[cfg(not(MODULE))] #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)] - core::arch::global_asm!( + ::core::arch::global_asm!( r#\".section \"{initcall_section}\", \"a\" - __{name}_initcall: - .long __{name}_init - . + __{ident}_initcall: + .long __{ident}_init - . .previous \"# ); @@ -309,7 +314,7 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { #[cfg(not(MODULE))] #[doc(hidden)] #[no_mangle] - pub extern \"C\" fn __{name}_init() -> kernel::ffi::c_int {{ + pub extern \"C\" fn __{ident}_init() -> ::kernel::ffi::c_int {{ // SAFETY: This function is inaccessible to the outside due to the double // module wrapping it. It is called exactly once by the C side via its // placement above in the initcall section. @@ -319,22 +324,22 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { #[cfg(not(MODULE))] #[doc(hidden)] #[no_mangle] - pub extern \"C\" fn __{name}_exit() {{ + pub extern \"C\" fn __{ident}_exit() {{ // SAFETY: // - This function is inaccessible to the outside due to the double // module wrapping it. It is called exactly once by the C side via its // unique name, - // - furthermore it is only called after `__{name}_init` has returned `0` - // (which delegates to `__init`). + // - furthermore it is only called after `__{ident}_init` has + // returned `0` (which delegates to `__init`). unsafe {{ __exit() }} }} /// # Safety /// /// This function must only be called once. - unsafe fn __init() -> kernel::ffi::c_int {{ + unsafe fn __init() -> ::kernel::ffi::c_int {{ let initer = - <{type_} as kernel::InPlaceModule>::init(&super::super::THIS_MODULE); + <{type_} as ::kernel::InPlaceModule>::init(&super::super::THIS_MODULE); // SAFETY: No data race, since `__MOD` can only be accessed by this module // and there only `__init` and `__exit` access it. These functions are only // called once and `__exit` cannot be called before or during `__init`. @@ -365,6 +370,7 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { ", type_ = info.type_, name = info.name, + ident = ident, modinfo = modinfo.buffer, initcall_section = ".initcall6.init" ) diff --git a/rust/macros/paste.rs b/rust/macros/paste.rs index 6529a387673f..cce712d19855 100644 --- a/rust/macros/paste.rs +++ b/rust/macros/paste.rs @@ -50,7 +50,7 @@ fn concat_helper(tokens: &[TokenTree]) -> Vec<(String, Span)> { let tokens = group.stream().into_iter().collect::<Vec<TokenTree>>(); segments.append(&mut concat_helper(tokens.as_slice())); } - token => panic!("unexpected token in paste segments: {:?}", token), + token => panic!("unexpected token in paste segments: {token:?}"), }; } diff --git a/rust/macros/quote.rs b/rust/macros/quote.rs index 33a199e4f176..92cacc4067c9 100644 --- a/rust/macros/quote.rs +++ b/rust/macros/quote.rs @@ -2,6 +2,7 @@ use proc_macro::{TokenStream, TokenTree}; +#[allow(dead_code)] pub(crate) trait ToTokens { fn to_tokens(&self, tokens: &mut TokenStream); } @@ -20,6 +21,12 @@ impl ToTokens for proc_macro::Group { } } +impl ToTokens for proc_macro::Ident { + fn to_tokens(&self, tokens: &mut TokenStream) { + tokens.extend([TokenTree::from(self.clone())]); + } +} + impl ToTokens for TokenTree { fn to_tokens(&self, tokens: &mut TokenStream) { tokens.extend([self.clone()]); @@ -40,7 +47,7 @@ impl ToTokens for TokenStream { /// `quote` crate but provides only just enough functionality needed by the current `macros` crate. macro_rules! quote_spanned { ($span:expr => $($tt:tt)*) => {{ - let mut tokens; + let mut tokens: ::std::vec::Vec<::proc_macro::TokenTree>; #[allow(clippy::vec_init_then_push)] { tokens = ::std::vec::Vec::new(); @@ -65,7 +72,8 @@ macro_rules! quote_spanned { quote_spanned!(@proc $v $span $($tt)*); }; (@proc $v:ident $span:ident ( $($inner:tt)* ) $($tt:tt)*) => { - let mut tokens = ::std::vec::Vec::new(); + #[allow(unused_mut)] + let mut tokens = ::std::vec::Vec::<::proc_macro::TokenTree>::new(); quote_spanned!(@proc tokens $span $($inner)*); $v.push(::proc_macro::TokenTree::Group(::proc_macro::Group::new( ::proc_macro::Delimiter::Parenthesis, @@ -136,6 +144,22 @@ macro_rules! quote_spanned { )); quote_spanned!(@proc $v $span $($tt)*); }; + (@proc $v:ident $span:ident = $($tt:tt)*) => { + $v.push(::proc_macro::TokenTree::Punct( + ::proc_macro::Punct::new('=', ::proc_macro::Spacing::Alone) + )); + quote_spanned!(@proc $v $span $($tt)*); + }; + (@proc $v:ident $span:ident # $($tt:tt)*) => { + $v.push(::proc_macro::TokenTree::Punct( + ::proc_macro::Punct::new('#', ::proc_macro::Spacing::Alone) + )); + quote_spanned!(@proc $v $span $($tt)*); + }; + (@proc $v:ident $span:ident _ $($tt:tt)*) => { + $v.push(::proc_macro::TokenTree::Ident(::proc_macro::Ident::new("_", $span))); + quote_spanned!(@proc $v $span $($tt)*); + }; (@proc $v:ident $span:ident $id:ident $($tt:tt)*) => { $v.push(::proc_macro::TokenTree::Ident(::proc_macro::Ident::new(stringify!($id), $span))); quote_spanned!(@proc $v $span $($tt)*); diff --git a/rust/pin-init/CONTRIBUTING.md b/rust/pin-init/CONTRIBUTING.md new file mode 100644 index 000000000000..16c899a7ae0b --- /dev/null +++ b/rust/pin-init/CONTRIBUTING.md @@ -0,0 +1,72 @@ +# Contributing to `pin-init` + +Thanks for showing interest in contributing to `pin-init`! This document outlines the guidelines for +contributing to `pin-init`. + +All contributions are double-licensed under Apache 2.0 and MIT. You can find the respective licenses +in the `LICENSE-APACHE` and `LICENSE-MIT` files. + +## Non-Code Contributions + +### Bug Reports + +For any type of bug report, please submit an issue using the bug report issue template. + +If the issue is a soundness issue, please privately report it as a security vulnerability via the +GitHub web interface. + +### Feature Requests + +If you have any feature requests, please submit an issue using the feature request issue template. + +### Questions and Getting Help + +You can ask questions in the Discussions page of the GitHub repository. If you're encountering +problems or just have questions related to `pin-init` in the Linux kernel, you can also ask your +questions in the [Rust-for-Linux Zulip](https://rust-for-linux.zulipchat.com/) or see +<https://rust-for-linux.com/contact>. + +## Contributing Code + +### Linux Kernel + +`pin-init` is used by the Linux kernel and all commits are synchronized to it. For this reason, the +same requirements for commits apply to `pin-init`. See [the kernel's documentation] for details. The +rest of this document will also cover some of the rules listed there and additional ones. + +[the kernel's documentation]: https://docs.kernel.org/process/submitting-patches.html + +Contributions to `pin-init` ideally go through the [GitHub repository], because that repository runs +a CI with lots of tests not present in the kernel. However, patches are also accepted (though not +preferred). Do note that there are some files that are only present in the GitHub repository such as +tests, licenses and cargo related files. Making changes to them can only happen via GitHub. + +[GitHub repository]: https://github.com/Rust-for-Linux/pin-init + +### Commit Style + +Everything must compile without errors or warnings and all tests must pass after **every commit**. +This is important for bisection and also required by the kernel. + +Each commit should be a single, logically cohesive change. Of course it's best to keep the changes +small and digestible, but logically linked changes should be made in the same commit. For example, +when fixing typos, create a single commit that fixes all of them instead of one commit per typo. + +Commits must have a meaningful commit title. Commits with changes to files in the `internal` +directory should have a title prefixed with `internal:`. The commit message should explain the +change and its rationale. You also have to add your `Signed-off-by` tag, see [Developer's +Certificate of Origin]. This has to be done for both mailing list submissions as well as GitHub +submissions. + +[Developer's Certificate of Origin]: https://docs.kernel.org/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin + +Any changes made to public APIs must be documented not only in the commit message, but also in the +`CHANGELOG.md` file. This is especially important for breaking changes, as those warrant a major +version bump. + +If you make changes to the top-level crate documentation, you also need to update the `README.md` +via `cargo rdme`. + +Some of these rules can be ignored if the change is done solely to files that are not present in the +kernel version of this library. Those files are documented in the `sync-kernel.sh` script at the +very bottom in the `--exclude` flag given to the `git am` command. diff --git a/rust/pin-init/README.md b/rust/pin-init/README.md new file mode 100644 index 000000000000..a4c01a8d78b2 --- /dev/null +++ b/rust/pin-init/README.md @@ -0,0 +1,236 @@ +[](https://crates.io/crates/pin-init) +[](https://docs.rs/pin-init/) +[](https://deps.rs/repo/github/Rust-for-Linux/pin-init) + +[](#nightly-only) + +# `pin-init` + +<!-- cargo-rdme start --> + +Library to safely and fallibly initialize pinned `struct`s using in-place constructors. + +[Pinning][pinning] is Rust's way of ensuring data does not move. + +It also allows in-place initialization of big `struct`s that would otherwise produce a stack +overflow. + +This library's main use-case is in [Rust-for-Linux]. Although this version can be used +standalone. + +There are cases when you want to in-place initialize a struct. For example when it is very big +and moving it from the stack is not an option, because it is bigger than the stack itself. +Another reason would be that you need the address of the object to initialize it. This stands +in direct conflict with Rust's normal process of first initializing an object and then moving +it into it's final memory location. For more information, see +<https://rust-for-linux.com/the-safe-pinned-initialization-problem>. + +This library allows you to do in-place initialization safely. + +### Nightly Needed for `alloc` feature + +This library requires the [`allocator_api` unstable feature] when the `alloc` feature is +enabled and thus this feature can only be used with a nightly compiler. When enabling the +`alloc` feature, the user will be required to activate `allocator_api` as well. + +[`allocator_api` unstable feature]: https://doc.rust-lang.org/nightly/unstable-book/library-features/allocator-api.html + +The feature is enabled by default, thus by default `pin-init` will require a nightly compiler. +However, using the crate on stable compilers is possible by disabling `alloc`. In practice this +will require the `std` feature, because stable compilers have neither `Box` nor `Arc` in no-std +mode. + +### Nightly needed for `unsafe-pinned` feature + +This feature enables the `Wrapper` implementation on the unstable `core::pin::UnsafePinned` type. +This requires the [`unsafe_pinned` unstable feature](https://github.com/rust-lang/rust/issues/125735) +and therefore a nightly compiler. Note that this feature is not enabled by default. + +## Overview + +To initialize a `struct` with an in-place constructor you will need two things: +- an in-place constructor, +- a memory location that can hold your `struct` (this can be the [stack], an [`Arc<T>`], + [`Box<T>`] or any other smart pointer that supports this library). + +To get an in-place constructor there are generally three options: +- directly creating an in-place constructor using the [`pin_init!`] macro, +- a custom function/macro returning an in-place constructor provided by someone else, +- using the unsafe function [`pin_init_from_closure()`] to manually create an initializer. + +Aside from pinned initialization, this library also supports in-place construction without +pinning, the macros/types/functions are generally named like the pinned variants without the +`pin_` prefix. + +## Examples + +Throughout the examples we will often make use of the `CMutex` type which can be found in +`../examples/mutex.rs`. It is essentially a userland rebuild of the `struct mutex` type from +the Linux kernel. It also uses a wait list and a basic spinlock. Importantly the wait list +requires it to be pinned to be locked and thus is a prime candidate for using this library. + +### Using the [`pin_init!`] macro + +If you want to use [`PinInit`], then you will have to annotate your `struct` with +`#[`[`pin_data`]`]`. It is a macro that uses `#[pin]` as a marker for +[structurally pinned fields]. After doing this, you can then create an in-place constructor via +[`pin_init!`]. The syntax is almost the same as normal `struct` initializers. The difference is +that you need to write `<-` instead of `:` for fields that you want to initialize in-place. + +```rust +use pin_init::{pin_data, pin_init, InPlaceInit}; + +#[pin_data] +struct Foo { + #[pin] + a: CMutex<usize>, + b: u32, +} + +let foo = pin_init!(Foo { + a <- CMutex::new(42), + b: 24, +}); +``` + +`foo` now is of the type [`impl PinInit<Foo>`]. We can now use any smart pointer that we like +(or just the stack) to actually initialize a `Foo`: + +```rust +let foo: Result<Pin<Box<Foo>>, AllocError> = Box::pin_init(foo); +``` + +For more information see the [`pin_init!`] macro. + +### Using a custom function/macro that returns an initializer + +Many types that use this library supply a function/macro that returns an initializer, because +the above method only works for types where you can access the fields. + +```rust +let mtx: Result<Pin<Arc<CMutex<usize>>>, _> = Arc::pin_init(CMutex::new(42)); +``` + +To declare an init macro/function you just return an [`impl PinInit<T, E>`]: + +```rust +#[pin_data] +struct DriverData { + #[pin] + status: CMutex<i32>, + buffer: Box<[u8; 1_000_000]>, +} + +impl DriverData { + fn new() -> impl PinInit<Self, Error> { + try_pin_init!(Self { + status <- CMutex::new(0), + buffer: Box::init(pin_init::init_zeroed())?, + }? Error) + } +} +``` + +### Manual creation of an initializer + +Often when working with primitives the previous approaches are not sufficient. That is where +[`pin_init_from_closure()`] comes in. This `unsafe` function allows you to create a +[`impl PinInit<T, E>`] directly from a closure. Of course you have to ensure that the closure +actually does the initialization in the correct way. Here are the things to look out for +(we are calling the parameter to the closure `slot`): +- when the closure returns `Ok(())`, then it has completed the initialization successfully, so + `slot` now contains a valid bit pattern for the type `T`, +- when the closure returns `Err(e)`, then the caller may deallocate the memory at `slot`, so + you need to take care to clean up anything if your initialization fails mid-way, +- you may assume that `slot` will stay pinned even after the closure returns until `drop` of + `slot` gets called. + +```rust +use pin_init::{pin_data, pinned_drop, PinInit, PinnedDrop, pin_init_from_closure}; +use core::{ + ptr::addr_of_mut, + marker::PhantomPinned, + cell::UnsafeCell, + pin::Pin, + mem::MaybeUninit, +}; +mod bindings { + #[repr(C)] + pub struct foo { + /* fields from C ... */ + } + extern "C" { + pub fn init_foo(ptr: *mut foo); + pub fn destroy_foo(ptr: *mut foo); + #[must_use = "you must check the error return code"] + pub fn enable_foo(ptr: *mut foo, flags: u32) -> i32; + } +} + +/// # Invariants +/// +/// `foo` is always initialized +#[pin_data(PinnedDrop)] +pub struct RawFoo { + #[pin] + _p: PhantomPinned, + #[pin] + foo: UnsafeCell<MaybeUninit<bindings::foo>>, +} + +impl RawFoo { + pub fn new(flags: u32) -> impl PinInit<Self, i32> { + // SAFETY: + // - when the closure returns `Ok(())`, then it has successfully initialized and + // enabled `foo`, + // - when it returns `Err(e)`, then it has cleaned up before + unsafe { + pin_init_from_closure(move |slot: *mut Self| { + // `slot` contains uninit memory, avoid creating a reference. + let foo = addr_of_mut!((*slot).foo); + let foo = UnsafeCell::raw_get(foo).cast::<bindings::foo>(); + + // Initialize the `foo` + bindings::init_foo(foo); + + // Try to enable it. + let err = bindings::enable_foo(foo, flags); + if err != 0 { + // Enabling has failed, first clean up the foo and then return the error. + bindings::destroy_foo(foo); + Err(err) + } else { + // All fields of `RawFoo` have been initialized, since `_p` is a ZST. + Ok(()) + } + }) + } + } +} + +#[pinned_drop] +impl PinnedDrop for RawFoo { + fn drop(self: Pin<&mut Self>) { + // SAFETY: Since `foo` is initialized, destroying is safe. + unsafe { bindings::destroy_foo(self.foo.get().cast::<bindings::foo>()) }; + } +} +``` + +For more information on how to use [`pin_init_from_closure()`], take a look at the uses inside +the `kernel` crate. The [`sync`] module is a good starting point. + +[`sync`]: https://rust.docs.kernel.org/kernel/sync/index.html +[pinning]: https://doc.rust-lang.org/std/pin/index.html +[structurally pinned fields]: https://doc.rust-lang.org/std/pin/index.html#projections-and-structural-pinning +[stack]: https://docs.rs/pin-init/latest/pin_init/macro.stack_pin_init.html +[`impl PinInit<Foo>`]: https://docs.rs/pin-init/latest/pin_init/trait.PinInit.html +[`impl PinInit<T, E>`]: https://docs.rs/pin-init/latest/pin_init/trait.PinInit.html +[`impl Init<T, E>`]: https://docs.rs/pin-init/latest/pin_init/trait.Init.html +[Rust-for-Linux]: https://rust-for-linux.com/ + +<!-- cargo-rdme end --> + +<!-- These links are not picked up by cargo-rdme, since they are behind cfgs... --> +[`Arc<T>`]: https://doc.rust-lang.org/stable/alloc/sync/struct.Arc.html +[`Box<T>`]: https://doc.rust-lang.org/stable/alloc/boxed/struct.Box.html diff --git a/rust/pin-init/examples/big_struct_in_place.rs b/rust/pin-init/examples/big_struct_in_place.rs new file mode 100644 index 000000000000..c05139927486 --- /dev/null +++ b/rust/pin-init/examples/big_struct_in_place.rs @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use pin_init::*; + +// Struct with size over 1GiB +#[derive(Debug)] +#[allow(dead_code)] +pub struct BigStruct { + buf: [u8; 1024 * 1024 * 1024], + a: u64, + b: u64, + c: u64, + d: u64, + managed_buf: ManagedBuf, +} + +#[derive(Debug)] +pub struct ManagedBuf { + buf: [u8; 1024 * 1024], +} + +impl ManagedBuf { + pub fn new() -> impl Init<Self> { + init!(ManagedBuf { buf <- init_zeroed() }) + } +} + +fn main() { + #[cfg(any(feature = "std", feature = "alloc"))] + { + // we want to initialize the struct in-place, otherwise we would get a stackoverflow + let buf: Box<BigStruct> = Box::init(init!(BigStruct { + buf <- init_zeroed(), + a: 7, + b: 186, + c: 7789, + d: 34, + managed_buf <- ManagedBuf::new(), + })) + .unwrap(); + println!("{}", core::mem::size_of_val(&*buf)); + } +} diff --git a/rust/pin-init/examples/error.rs b/rust/pin-init/examples/error.rs new file mode 100644 index 000000000000..e0cc258746ce --- /dev/null +++ b/rust/pin-init/examples/error.rs @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#![cfg_attr(feature = "alloc", feature(allocator_api))] + +use core::convert::Infallible; + +#[cfg(feature = "alloc")] +use std::alloc::AllocError; + +#[derive(Debug)] +pub struct Error; + +impl From<Infallible> for Error { + fn from(e: Infallible) -> Self { + match e {} + } +} + +#[cfg(feature = "alloc")] +impl From<AllocError> for Error { + fn from(_: AllocError) -> Self { + Self + } +} + +#[allow(dead_code)] +fn main() {} diff --git a/rust/pin-init/examples/linked_list.rs b/rust/pin-init/examples/linked_list.rs new file mode 100644 index 000000000000..f9e117c7dfe0 --- /dev/null +++ b/rust/pin-init/examples/linked_list.rs @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#![allow(clippy::undocumented_unsafe_blocks)] +#![cfg_attr(feature = "alloc", feature(allocator_api))] +#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] + +use core::{ + cell::Cell, + convert::Infallible, + marker::PhantomPinned, + pin::Pin, + ptr::{self, NonNull}, +}; + +use pin_init::*; + +#[allow(unused_attributes)] +mod error; +#[allow(unused_imports)] +use error::Error; + +#[pin_data(PinnedDrop)] +#[repr(C)] +#[derive(Debug)] +pub struct ListHead { + next: Link, + prev: Link, + #[pin] + pin: PhantomPinned, +} + +impl ListHead { + #[inline] + pub fn new() -> impl PinInit<Self, Infallible> { + try_pin_init!(&this in Self { + next: unsafe { Link::new_unchecked(this) }, + prev: unsafe { Link::new_unchecked(this) }, + pin: PhantomPinned, + }? Infallible) + } + + #[inline] + #[allow(dead_code)] + pub fn insert_next(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ { + try_pin_init!(&this in Self { + prev: list.next.prev().replace(unsafe { Link::new_unchecked(this)}), + next: list.next.replace(unsafe { Link::new_unchecked(this)}), + pin: PhantomPinned, + }? Infallible) + } + + #[inline] + pub fn insert_prev(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ { + try_pin_init!(&this in Self { + next: list.prev.next().replace(unsafe { Link::new_unchecked(this)}), + prev: list.prev.replace(unsafe { Link::new_unchecked(this)}), + pin: PhantomPinned, + }? Infallible) + } + + #[inline] + pub fn next(&self) -> Option<NonNull<Self>> { + if ptr::eq(self.next.as_ptr(), self) { + None + } else { + Some(unsafe { NonNull::new_unchecked(self.next.as_ptr() as *mut Self) }) + } + } + + #[allow(dead_code)] + pub fn size(&self) -> usize { + let mut size = 1; + let mut cur = self.next.clone(); + while !ptr::eq(self, cur.cur()) { + cur = cur.next().clone(); + size += 1; + } + size + } +} + +#[pinned_drop] +impl PinnedDrop for ListHead { + //#[inline] + fn drop(self: Pin<&mut Self>) { + if !ptr::eq(self.next.as_ptr(), &*self) { + let next = unsafe { &*self.next.as_ptr() }; + let prev = unsafe { &*self.prev.as_ptr() }; + next.prev.set(&self.prev); + prev.next.set(&self.next); + } + } +} + +#[repr(transparent)] +#[derive(Clone, Debug)] +struct Link(Cell<NonNull<ListHead>>); + +impl Link { + /// # Safety + /// + /// The contents of the pointer should form a consistent circular + /// linked list; for example, a "next" link should be pointed back + /// by the target `ListHead`'s "prev" link and a "prev" link should be + /// pointed back by the target `ListHead`'s "next" link. + #[inline] + unsafe fn new_unchecked(ptr: NonNull<ListHead>) -> Self { + Self(Cell::new(ptr)) + } + + #[inline] + fn next(&self) -> &Link { + unsafe { &(*self.0.get().as_ptr()).next } + } + + #[inline] + #[allow(dead_code)] + fn prev(&self) -> &Link { + unsafe { &(*self.0.get().as_ptr()).prev } + } + + #[allow(dead_code)] + fn cur(&self) -> &ListHead { + unsafe { &*self.0.get().as_ptr() } + } + + #[inline] + fn replace(&self, other: Link) -> Link { + unsafe { Link::new_unchecked(self.0.replace(other.0.get())) } + } + + #[inline] + fn as_ptr(&self) -> *const ListHead { + self.0.get().as_ptr() + } + + #[inline] + fn set(&self, val: &Link) { + self.0.set(val.0.get()); + } +} + +#[allow(dead_code)] +#[cfg(not(any(feature = "std", feature = "alloc")))] +fn main() {} + +#[allow(dead_code)] +#[cfg_attr(test, test)] +#[cfg(any(feature = "std", feature = "alloc"))] +fn main() -> Result<(), Error> { + let a = Box::pin_init(ListHead::new())?; + stack_pin_init!(let b = ListHead::insert_next(&a)); + stack_pin_init!(let c = ListHead::insert_next(&a)); + stack_pin_init!(let d = ListHead::insert_next(&b)); + let e = Box::pin_init(ListHead::insert_next(&b))?; + println!("a ({a:p}): {a:?}"); + println!("b ({b:p}): {b:?}"); + println!("c ({c:p}): {c:?}"); + println!("d ({d:p}): {d:?}"); + println!("e ({e:p}): {e:?}"); + let mut inspect = &*a; + while let Some(next) = inspect.next() { + println!("({inspect:p}): {inspect:?}"); + inspect = unsafe { &*next.as_ptr() }; + if core::ptr::eq(inspect, &*a) { + break; + } + } + Ok(()) +} diff --git a/rust/pin-init/examples/mutex.rs b/rust/pin-init/examples/mutex.rs new file mode 100644 index 000000000000..9f295226cd64 --- /dev/null +++ b/rust/pin-init/examples/mutex.rs @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#![allow(clippy::undocumented_unsafe_blocks)] +#![cfg_attr(feature = "alloc", feature(allocator_api))] +#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] +#![allow(clippy::missing_safety_doc)] + +use core::{ + cell::{Cell, UnsafeCell}, + marker::PhantomPinned, + ops::{Deref, DerefMut}, + pin::Pin, + sync::atomic::{AtomicBool, Ordering}, +}; +#[cfg(feature = "std")] +use std::{ + sync::Arc, + thread::{self, sleep, Builder, Thread}, + time::Duration, +}; + +use pin_init::*; +#[allow(unused_attributes)] +#[path = "./linked_list.rs"] +pub mod linked_list; +use linked_list::*; + +pub struct SpinLock { + inner: AtomicBool, +} + +impl SpinLock { + #[inline] + pub fn acquire(&self) -> SpinLockGuard<'_> { + while self + .inner + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_err() + { + #[cfg(feature = "std")] + while self.inner.load(Ordering::Relaxed) { + thread::yield_now(); + } + } + SpinLockGuard(self) + } + + #[inline] + #[allow(clippy::new_without_default)] + pub const fn new() -> Self { + Self { + inner: AtomicBool::new(false), + } + } +} + +pub struct SpinLockGuard<'a>(&'a SpinLock); + +impl Drop for SpinLockGuard<'_> { + #[inline] + fn drop(&mut self) { + self.0.inner.store(false, Ordering::Release); + } +} + +#[pin_data] +pub struct CMutex<T> { + #[pin] + wait_list: ListHead, + spin_lock: SpinLock, + locked: Cell<bool>, + #[pin] + data: UnsafeCell<T>, +} + +impl<T> CMutex<T> { + #[inline] + pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> { + pin_init!(CMutex { + wait_list <- ListHead::new(), + spin_lock: SpinLock::new(), + locked: Cell::new(false), + data <- unsafe { + pin_init_from_closure(|slot: *mut UnsafeCell<T>| { + val.__pinned_init(slot.cast::<T>()) + }) + }, + }) + } + + #[inline] + pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> { + let mut sguard = self.spin_lock.acquire(); + if self.locked.get() { + stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list)); + // println!("wait list length: {}", self.wait_list.size()); + while self.locked.get() { + drop(sguard); + #[cfg(feature = "std")] + thread::park(); + sguard = self.spin_lock.acquire(); + } + // This does have an effect, as the ListHead inside wait_entry implements Drop! + #[expect(clippy::drop_non_drop)] + drop(wait_entry); + } + self.locked.set(true); + unsafe { + Pin::new_unchecked(CMutexGuard { + mtx: self, + _pin: PhantomPinned, + }) + } + } + + #[allow(dead_code)] + pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T { + // SAFETY: we have an exclusive reference and thus nobody has access to data. + unsafe { &mut *self.data.get() } + } +} + +unsafe impl<T: Send> Send for CMutex<T> {} +unsafe impl<T: Send> Sync for CMutex<T> {} + +pub struct CMutexGuard<'a, T> { + mtx: &'a CMutex<T>, + _pin: PhantomPinned, +} + +impl<T> Drop for CMutexGuard<'_, T> { + #[inline] + fn drop(&mut self) { + let sguard = self.mtx.spin_lock.acquire(); + self.mtx.locked.set(false); + if let Some(list_field) = self.mtx.wait_list.next() { + let _wait_entry = list_field.as_ptr().cast::<WaitEntry>(); + #[cfg(feature = "std")] + unsafe { + (*_wait_entry).thread.unpark() + }; + } + drop(sguard); + } +} + +impl<T> Deref for CMutexGuard<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + unsafe { &*self.mtx.data.get() } + } +} + +impl<T> DerefMut for CMutexGuard<'_, T> { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.mtx.data.get() } + } +} + +#[pin_data] +#[repr(C)] +struct WaitEntry { + #[pin] + wait_list: ListHead, + #[cfg(feature = "std")] + thread: Thread, +} + +impl WaitEntry { + #[inline] + fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ { + #[cfg(feature = "std")] + { + pin_init!(Self { + thread: thread::current(), + wait_list <- ListHead::insert_prev(list), + }) + } + #[cfg(not(feature = "std"))] + { + pin_init!(Self { + wait_list <- ListHead::insert_prev(list), + }) + } + } +} + +#[cfg_attr(test, test)] +#[allow(dead_code)] +fn main() { + #[cfg(feature = "std")] + { + let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); + let mut handles = vec![]; + let thread_count = 20; + let workload = if cfg!(miri) { 100 } else { 1_000 }; + for i in 0..thread_count { + let mtx = mtx.clone(); + handles.push( + Builder::new() + .name(format!("worker #{i}")) + .spawn(move || { + for _ in 0..workload { + *mtx.lock() += 1; + } + println!("{i} halfway"); + sleep(Duration::from_millis((i as u64) * 10)); + for _ in 0..workload { + *mtx.lock() += 1; + } + println!("{i} finished"); + }) + .expect("should not fail"), + ); + } + for h in handles { + h.join().expect("thread panicked"); + } + println!("{:?}", &*mtx.lock()); + assert_eq!(*mtx.lock(), workload * thread_count * 2); + } +} diff --git a/rust/pin-init/examples/pthread_mutex.rs b/rust/pin-init/examples/pthread_mutex.rs new file mode 100644 index 000000000000..49b004c8c137 --- /dev/null +++ b/rust/pin-init/examples/pthread_mutex.rs @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +// inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs> +#![allow(clippy::undocumented_unsafe_blocks)] +#![cfg_attr(feature = "alloc", feature(allocator_api))] +#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] + +#[cfg(not(windows))] +mod pthread_mtx { + #[cfg(feature = "alloc")] + use core::alloc::AllocError; + use core::{ + cell::UnsafeCell, + marker::PhantomPinned, + mem::MaybeUninit, + ops::{Deref, DerefMut}, + pin::Pin, + }; + use pin_init::*; + use std::convert::Infallible; + + #[pin_data(PinnedDrop)] + pub struct PThreadMutex<T> { + #[pin] + raw: UnsafeCell<libc::pthread_mutex_t>, + data: UnsafeCell<T>, + #[pin] + pin: PhantomPinned, + } + + unsafe impl<T: Send> Send for PThreadMutex<T> {} + unsafe impl<T: Send> Sync for PThreadMutex<T> {} + + #[pinned_drop] + impl<T> PinnedDrop for PThreadMutex<T> { + fn drop(self: Pin<&mut Self>) { + unsafe { + libc::pthread_mutex_destroy(self.raw.get()); + } + } + } + + #[derive(Debug)] + pub enum Error { + #[allow(dead_code)] + IO(std::io::Error), + #[allow(dead_code)] + Alloc, + } + + impl From<Infallible> for Error { + fn from(e: Infallible) -> Self { + match e {} + } + } + + #[cfg(feature = "alloc")] + impl From<AllocError> for Error { + fn from(_: AllocError) -> Self { + Self::Alloc + } + } + + impl<T> PThreadMutex<T> { + #[allow(dead_code)] + pub fn new(data: T) -> impl PinInit<Self, Error> { + fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> { + let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| { + // we can cast, because `UnsafeCell` has the same layout as T. + let slot: *mut libc::pthread_mutex_t = slot.cast(); + let mut attr = MaybeUninit::uninit(); + let attr = attr.as_mut_ptr(); + // SAFETY: ptr is valid + let ret = unsafe { libc::pthread_mutexattr_init(attr) }; + if ret != 0 { + return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); + } + // SAFETY: attr is initialized + let ret = unsafe { + libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL) + }; + if ret != 0 { + // SAFETY: attr is initialized + unsafe { libc::pthread_mutexattr_destroy(attr) }; + return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); + } + // SAFETY: slot is valid + unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) }; + // SAFETY: attr and slot are valid ptrs and attr is initialized + let ret = unsafe { libc::pthread_mutex_init(slot, attr) }; + // SAFETY: attr was initialized + unsafe { libc::pthread_mutexattr_destroy(attr) }; + if ret != 0 { + return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); + } + Ok(()) + }; + // SAFETY: mutex has been initialized + unsafe { pin_init_from_closure(init) } + } + try_pin_init!(Self { + data: UnsafeCell::new(data), + raw <- init_raw(), + pin: PhantomPinned, + }? Error) + } + + #[allow(dead_code)] + pub fn lock(&self) -> PThreadMutexGuard<'_, T> { + // SAFETY: raw is always initialized + unsafe { libc::pthread_mutex_lock(self.raw.get()) }; + PThreadMutexGuard { mtx: self } + } + } + + pub struct PThreadMutexGuard<'a, T> { + mtx: &'a PThreadMutex<T>, + } + + impl<T> Drop for PThreadMutexGuard<'_, T> { + fn drop(&mut self) { + // SAFETY: raw is always initialized + unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) }; + } + } + + impl<T> Deref for PThreadMutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.mtx.data.get() } + } + } + + impl<T> DerefMut for PThreadMutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.mtx.data.get() } + } + } +} + +#[cfg_attr(test, test)] +#[cfg_attr(all(test, miri), ignore)] +fn main() { + #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))] + { + use core::pin::Pin; + use pin_init::*; + use pthread_mtx::*; + use std::{ + sync::Arc, + thread::{sleep, Builder}, + time::Duration, + }; + let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap(); + let mut handles = vec![]; + let thread_count = 20; + let workload = 1_000_000; + for i in 0..thread_count { + let mtx = mtx.clone(); + handles.push( + Builder::new() + .name(format!("worker #{i}")) + .spawn(move || { + for _ in 0..workload { + *mtx.lock() += 1; + } + println!("{i} halfway"); + sleep(Duration::from_millis((i as u64) * 10)); + for _ in 0..workload { + *mtx.lock() += 1; + } + println!("{i} finished"); + }) + .expect("should not fail"), + ); + } + for h in handles { + h.join().expect("thread panicked"); + } + println!("{:?}", &*mtx.lock()); + assert_eq!(*mtx.lock(), workload * thread_count * 2); + } +} diff --git a/rust/pin-init/examples/static_init.rs b/rust/pin-init/examples/static_init.rs new file mode 100644 index 000000000000..0e165daa9798 --- /dev/null +++ b/rust/pin-init/examples/static_init.rs @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#![allow(clippy::undocumented_unsafe_blocks)] +#![cfg_attr(feature = "alloc", feature(allocator_api))] +#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] +#![allow(unused_imports)] + +use core::{ + cell::{Cell, UnsafeCell}, + mem::MaybeUninit, + ops, + pin::Pin, + time::Duration, +}; +use pin_init::*; +#[cfg(feature = "std")] +use std::{ + sync::Arc, + thread::{sleep, Builder}, +}; + +#[allow(unused_attributes)] +mod mutex; +use mutex::*; + +pub struct StaticInit<T, I> { + cell: UnsafeCell<MaybeUninit<T>>, + init: Cell<Option<I>>, + lock: SpinLock, + present: Cell<bool>, +} + +unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {} +unsafe impl<T: Send, I> Send for StaticInit<T, I> {} + +impl<T, I: PinInit<T>> StaticInit<T, I> { + pub const fn new(init: I) -> Self { + Self { + cell: UnsafeCell::new(MaybeUninit::uninit()), + init: Cell::new(Some(init)), + lock: SpinLock::new(), + present: Cell::new(false), + } + } +} + +impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> { + type Target = T; + fn deref(&self) -> &Self::Target { + if self.present.get() { + unsafe { (*self.cell.get()).assume_init_ref() } + } else { + println!("acquire spinlock on static init"); + let _guard = self.lock.acquire(); + println!("rechecking present..."); + std::thread::sleep(std::time::Duration::from_millis(200)); + if self.present.get() { + return unsafe { (*self.cell.get()).assume_init_ref() }; + } + println!("doing init"); + let ptr = self.cell.get().cast::<T>(); + match self.init.take() { + Some(f) => unsafe { f.__pinned_init(ptr).unwrap() }, + None => unsafe { core::hint::unreachable_unchecked() }, + } + self.present.set(true); + unsafe { (*self.cell.get()).assume_init_ref() } + } + } +} + +pub struct CountInit; + +unsafe impl PinInit<CMutex<usize>> for CountInit { + unsafe fn __pinned_init( + self, + slot: *mut CMutex<usize>, + ) -> Result<(), core::convert::Infallible> { + let init = CMutex::new(0); + std::thread::sleep(std::time::Duration::from_millis(1000)); + unsafe { init.__pinned_init(slot) } + } +} + +pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit); + +fn main() { + #[cfg(feature = "std")] + { + let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); + let mut handles = vec![]; + let thread_count = 20; + let workload = 1_000; + for i in 0..thread_count { + let mtx = mtx.clone(); + handles.push( + Builder::new() + .name(format!("worker #{i}")) + .spawn(move || { + for _ in 0..workload { + *COUNT.lock() += 1; + std::thread::sleep(std::time::Duration::from_millis(10)); + *mtx.lock() += 1; + std::thread::sleep(std::time::Duration::from_millis(10)); + *COUNT.lock() += 1; + } + println!("{i} halfway"); + sleep(Duration::from_millis((i as u64) * 10)); + for _ in 0..workload { + std::thread::sleep(std::time::Duration::from_millis(10)); + *mtx.lock() += 1; + } + println!("{i} finished"); + }) + .expect("should not fail"), + ); + } + for h in handles { + h.join().expect("thread panicked"); + } + println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock()); + assert_eq!(*mtx.lock(), workload * thread_count * 2); + } +} diff --git a/rust/pin-init/internal/src/helpers.rs b/rust/pin-init/internal/src/helpers.rs new file mode 100644 index 000000000000..236f989a50f2 --- /dev/null +++ b/rust/pin-init/internal/src/helpers.rs @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + +use proc_macro::{TokenStream, TokenTree}; + +/// Parsed generics. +/// +/// See the field documentation for an explanation what each of the fields represents. +/// +/// # Examples +/// +/// ```rust,ignore +/// # let input = todo!(); +/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input); +/// quote! { +/// struct Foo<$($decl_generics)*> { +/// // ... +/// } +/// +/// impl<$impl_generics> Foo<$ty_generics> { +/// fn foo() { +/// // ... +/// } +/// } +/// } +/// ``` +pub(crate) struct Generics { + /// The generics with bounds and default values (e.g. `T: Clone, const N: usize = 0`). + /// + /// Use this on type definitions e.g. `struct Foo<$decl_generics> ...` (or `union`/`enum`). + pub(crate) decl_generics: Vec<TokenTree>, + /// The generics with bounds (e.g. `T: Clone, const N: usize`). + /// + /// Use this on `impl` blocks e.g. `impl<$impl_generics> Trait for ...`. + pub(crate) impl_generics: Vec<TokenTree>, + /// The generics without bounds and without default values (e.g. `T, N`). + /// + /// Use this when you use the type that is declared with these generics e.g. + /// `Foo<$ty_generics>`. + pub(crate) ty_generics: Vec<TokenTree>, +} + +/// Parses the given `TokenStream` into `Generics` and the rest. +/// +/// The generics are not present in the rest, but a where clause might remain. +pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) { + // The generics with bounds and default values. + let mut decl_generics = vec![]; + // `impl_generics`, the declared generics with their bounds. + let mut impl_generics = vec![]; + // Only the names of the generics, without any bounds. + let mut ty_generics = vec![]; + // Tokens not related to the generics e.g. the `where` token and definition. + let mut rest = vec![]; + // The current level of `<`. + let mut nesting = 0; + let mut toks = input.into_iter(); + // If we are at the beginning of a generic parameter. + let mut at_start = true; + let mut skip_until_comma = false; + while let Some(tt) = toks.next() { + if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') { + // Found the end of the generics. + break; + } else if nesting >= 1 { + decl_generics.push(tt.clone()); + } + match tt.clone() { + TokenTree::Punct(p) if p.as_char() == '<' => { + if nesting >= 1 && !skip_until_comma { + // This is inside of the generics and part of some bound. + impl_generics.push(tt); + } + nesting += 1; + } + TokenTree::Punct(p) if p.as_char() == '>' => { + // This is a parsing error, so we just end it here. + if nesting == 0 { + break; + } else { + nesting -= 1; + if nesting >= 1 && !skip_until_comma { + // We are still inside of the generics and part of some bound. + impl_generics.push(tt); + } + } + } + TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => { + if nesting == 1 { + impl_generics.push(tt.clone()); + impl_generics.push(tt); + skip_until_comma = false; + } + } + _ if !skip_until_comma => { + match nesting { + // If we haven't entered the generics yet, we still want to keep these tokens. + 0 => rest.push(tt), + 1 => { + // Here depending on the token, it might be a generic variable name. + match tt.clone() { + TokenTree::Ident(i) if at_start && i.to_string() == "const" => { + let Some(name) = toks.next() else { + // Parsing error. + break; + }; + impl_generics.push(tt); + impl_generics.push(name.clone()); + ty_generics.push(name.clone()); + decl_generics.push(name); + at_start = false; + } + TokenTree::Ident(_) if at_start => { + impl_generics.push(tt.clone()); + ty_generics.push(tt); + at_start = false; + } + TokenTree::Punct(p) if p.as_char() == ',' => { + impl_generics.push(tt.clone()); + ty_generics.push(tt); + at_start = true; + } + // Lifetimes begin with `'`. + TokenTree::Punct(p) if p.as_char() == '\'' && at_start => { + impl_generics.push(tt.clone()); + ty_generics.push(tt); + } + // Generics can have default values, we skip these. + TokenTree::Punct(p) if p.as_char() == '=' => { + skip_until_comma = true; + } + _ => impl_generics.push(tt), + } + } + _ => impl_generics.push(tt), + } + } + _ => {} + } + } + rest.extend(toks); + ( + Generics { + impl_generics, + decl_generics, + ty_generics, + }, + rest, + ) +} diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs new file mode 100644 index 000000000000..297b0129a5bf --- /dev/null +++ b/rust/pin-init/internal/src/lib.rs @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +// When fixdep scans this, it will find this string `CONFIG_RUSTC_VERSION_TEXT` +// and thus add a dependency on `include/config/RUSTC_VERSION_TEXT`, which is +// touched by Kconfig when the version string from the compiler changes. + +//! `pin-init` proc macros. + +#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] +// Allow `.into()` to convert +// - `proc_macro2::TokenStream` into `proc_macro::TokenStream` in the user-space version. +// - `proc_macro::TokenStream` into `proc_macro::TokenStream` in the kernel version. +// Clippy warns on this conversion, but it's required by the user-space version. +// +// Remove once we have `proc_macro2` in the kernel. +#![allow(clippy::useless_conversion)] +// Documentation is done in the pin-init crate instead. +#![allow(missing_docs)] + +use proc_macro::TokenStream; + +#[cfg(kernel)] +#[path = "../../../macros/quote.rs"] +#[macro_use] +#[cfg_attr(not(kernel), rustfmt::skip)] +mod quote; +#[cfg(not(kernel))] +#[macro_use] +extern crate quote; + +mod helpers; +mod pin_data; +mod pinned_drop; +mod zeroable; + +#[proc_macro_attribute] +pub fn pin_data(inner: TokenStream, item: TokenStream) -> TokenStream { + pin_data::pin_data(inner.into(), item.into()).into() +} + +#[proc_macro_attribute] +pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { + pinned_drop::pinned_drop(args.into(), input.into()).into() +} + +#[proc_macro_derive(Zeroable)] +pub fn derive_zeroable(input: TokenStream) -> TokenStream { + zeroable::derive(input.into()).into() +} + +#[proc_macro_derive(MaybeZeroable)] +pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream { + zeroable::maybe_derive(input.into()).into() +} diff --git a/rust/macros/pin_data.rs b/rust/pin-init/internal/src/pin_data.rs index 1d4a3547c684..87d4a7eb1d35 100644 --- a/rust/macros/pin_data.rs +++ b/rust/pin-init/internal/src/pin_data.rs @@ -1,11 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + use crate::helpers::{parse_generics, Generics}; use proc_macro::{Group, Punct, Spacing, TokenStream, TokenTree}; pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream { // This proc-macro only does some pre-parsing and then delegates the actual parsing to - // `kernel::__pin_data!`. + // `pin_init::__pin_data!`. let ( Generics { @@ -71,7 +74,7 @@ pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream { .collect::<Vec<_>>(); // This should be the body of the struct `{...}`. let last = rest.pop(); - let mut quoted = quote!(::kernel::__pin_data! { + let mut quoted = quote!(::pin_init::__pin_data! { parse_input: @args(#args), @sig(#(#rest)*), diff --git a/rust/macros/pinned_drop.rs b/rust/pin-init/internal/src/pinned_drop.rs index 88fb72b20660..c4ca7a70b726 100644 --- a/rust/macros/pinned_drop.rs +++ b/rust/pin-init/internal/src/pinned_drop.rs @@ -1,5 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + use proc_macro::{TokenStream, TokenTree}; pub(crate) fn pinned_drop(_args: TokenStream, input: TokenStream) -> TokenStream { @@ -25,8 +28,7 @@ pub(crate) fn pinned_drop(_args: TokenStream, input: TokenStream) -> TokenStream // Found the end of the generics, this should be `PinnedDrop`. assert!( matches!(tt, TokenTree::Ident(i) if i.to_string() == "PinnedDrop"), - "expected 'PinnedDrop', found: '{:?}'", - tt + "expected 'PinnedDrop', found: '{tt:?}'" ); pinned_drop_idx = Some(i); break; @@ -35,11 +37,11 @@ pub(crate) fn pinned_drop(_args: TokenStream, input: TokenStream) -> TokenStream let idx = pinned_drop_idx .unwrap_or_else(|| panic!("Expected an `impl` block implementing `PinnedDrop`.")); // Fully qualify the `PinnedDrop`, as to avoid any tampering. - toks.splice(idx..idx, quote!(::kernel::init::)); + toks.splice(idx..idx, quote!(::pin_init::)); // Take the `{}` body and call the declarative macro. if let Some(TokenTree::Group(last)) = toks.pop() { let last = last.stream(); - quote!(::kernel::__pinned_drop! { + quote!(::pin_init::__pinned_drop! { @impl_sig(#(#toks)*), @impl_body(#last), }) diff --git a/rust/macros/zeroable.rs b/rust/pin-init/internal/src/zeroable.rs index cfee2cec18d5..e0ed3998445c 100644 --- a/rust/macros/zeroable.rs +++ b/rust/pin-init/internal/src/zeroable.rs @@ -1,9 +1,19 @@ // SPDX-License-Identifier: GPL-2.0 +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + use crate::helpers::{parse_generics, Generics}; use proc_macro::{TokenStream, TokenTree}; -pub(crate) fn derive(input: TokenStream) -> TokenStream { +pub(crate) fn parse_zeroable_derive_input( + input: TokenStream, +) -> ( + Vec<TokenTree>, + Vec<TokenTree>, + Vec<TokenTree>, + Option<TokenTree>, +) { let ( Generics { impl_generics, @@ -27,7 +37,7 @@ pub(crate) fn derive(input: TokenStream) -> TokenStream { // If we find a `,`, then we have finished a generic/constant/lifetime parameter. TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => { if in_generic && !inserted { - new_impl_generics.extend(quote! { : ::kernel::init::Zeroable }); + new_impl_generics.extend(quote! { : ::pin_init::Zeroable }); } in_generic = true; inserted = false; @@ -41,7 +51,7 @@ pub(crate) fn derive(input: TokenStream) -> TokenStream { TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => { new_impl_generics.push(tt); if in_generic { - new_impl_generics.extend(quote! { ::kernel::init::Zeroable + }); + new_impl_generics.extend(quote! { ::pin_init::Zeroable + }); inserted = true; } } @@ -59,10 +69,28 @@ pub(crate) fn derive(input: TokenStream) -> TokenStream { } assert_eq!(nested, 0); if in_generic && !inserted { - new_impl_generics.extend(quote! { : ::kernel::init::Zeroable }); + new_impl_generics.extend(quote! { : ::pin_init::Zeroable }); } + (rest, new_impl_generics, ty_generics, last) +} + +pub(crate) fn derive(input: TokenStream) -> TokenStream { + let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input); + quote! { + ::pin_init::__derive_zeroable!( + parse_input: + @sig(#(#rest)*), + @impl_generics(#(#new_impl_generics)*), + @ty_generics(#(#ty_generics)*), + @body(#last), + ); + } +} + +pub(crate) fn maybe_derive(input: TokenStream) -> TokenStream { + let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input); quote! { - ::kernel::__derive_zeroable!( + ::pin_init::__maybe_derive_zeroable!( parse_input: @sig(#(#rest)*), @impl_generics(#(#new_impl_generics)*), diff --git a/rust/kernel/init/__internal.rs b/rust/pin-init/src/__internal.rs index 74329cc3262c..90f18e9a2912 100644 --- a/rust/kernel/init/__internal.rs +++ b/rust/pin-init/src/__internal.rs @@ -1,19 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT -//! This module contains API-internal items for pin-init. +//! This module contains library internal items. //! -//! These items must not be used outside of -//! - `kernel/init.rs` -//! - `macros/pin_data.rs` -//! - `macros/pinned_drop.rs` +//! These items must not be used outside of this crate and the pin-init-internal crate located at +//! `../internal`. use super::*; /// See the [nomicon] for what subtyping is. See also [this table]. /// +/// The reason for not using `PhantomData<*mut T>` is that that type never implements [`Send`] and +/// [`Sync`]. Hence `fn(*mut T) -> *mut T` is used, as that type always implements them. +/// /// [nomicon]: https://doc.rust-lang.org/nomicon/subtyping.html /// [this table]: https://doc.rust-lang.org/nomicon/phantom-data.html#table-of-phantomdata-patterns -pub(super) type Invariant<T> = PhantomData<fn(*mut T) -> *mut T>; +pub(crate) type Invariant<T> = PhantomData<fn(*mut T) -> *mut T>; /// Module-internal type implementing `PinInit` and `Init`. /// @@ -105,7 +106,7 @@ pub unsafe trait InitData: Copy { } } -pub struct AllData<T: ?Sized>(PhantomData<fn(KBox<T>) -> KBox<T>>); +pub struct AllData<T: ?Sized>(Invariant<T>); impl<T: ?Sized> Clone for AllData<T> { fn clone(&self) -> Self { @@ -135,7 +136,7 @@ unsafe impl<T: ?Sized> HasInitData for T { /// /// If `self.is_init` is true, then `self.value` is initialized. /// -/// [`stack_pin_init`]: kernel::stack_pin_init +/// [`stack_pin_init`]: crate::stack_pin_init pub struct StackInit<T> { value: MaybeUninit<T>, is_init: bool, @@ -156,7 +157,7 @@ impl<T> StackInit<T> { /// Creates a new [`StackInit<T>`] that is uninitialized. Use [`stack_pin_init`] instead of this /// primitive. /// - /// [`stack_pin_init`]: kernel::stack_pin_init + /// [`stack_pin_init`]: crate::stack_pin_init #[inline] pub fn uninit() -> Self { Self { @@ -186,6 +187,34 @@ impl<T> StackInit<T> { } } +#[test] +#[cfg(feature = "std")] +fn stack_init_reuse() { + use ::std::{borrow::ToOwned, println, string::String}; + use core::pin::pin; + + #[derive(Debug)] + struct Foo { + a: usize, + b: String, + } + let mut slot: Pin<&mut StackInit<Foo>> = pin!(StackInit::uninit()); + let value: Result<Pin<&mut Foo>, core::convert::Infallible> = + slot.as_mut().init(crate::init!(Foo { + a: 42, + b: "Hello".to_owned(), + })); + let value = value.unwrap(); + println!("{value:?}"); + let value: Result<Pin<&mut Foo>, core::convert::Infallible> = + slot.as_mut().init(crate::init!(Foo { + a: 24, + b: "world!".to_owned(), + })); + let value = value.unwrap(); + println!("{value:?}"); +} + /// When a value of this type is dropped, it drops a `T`. /// /// Can be forgotten to prevent the drop. diff --git a/rust/pin-init/src/alloc.rs b/rust/pin-init/src/alloc.rs new file mode 100644 index 000000000000..5017f57442d8 --- /dev/null +++ b/rust/pin-init/src/alloc.rs @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#[cfg(all(feature = "alloc", not(feature = "std")))] +use alloc::{boxed::Box, sync::Arc}; +#[cfg(feature = "alloc")] +use core::alloc::AllocError; +use core::{mem::MaybeUninit, pin::Pin}; +#[cfg(feature = "std")] +use std::sync::Arc; + +#[cfg(not(feature = "alloc"))] +type AllocError = core::convert::Infallible; + +use crate::{ + init_from_closure, pin_init_from_closure, InPlaceWrite, Init, PinInit, ZeroableOption, +}; + +pub extern crate alloc; + +// SAFETY: All zeros is equivalent to `None` (option layout optimization guarantee: +// <https://doc.rust-lang.org/stable/std/option/index.html#representation>). +unsafe impl<T> ZeroableOption for Box<T> {} + +/// Smart pointer that can initialize memory in-place. +pub trait InPlaceInit<T>: Sized { + /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this + /// type. + /// + /// If `T: !Unpin` it will not be able to move afterwards. + fn try_pin_init<E>(init: impl PinInit<T, E>) -> Result<Pin<Self>, E> + where + E: From<AllocError>; + + /// Use the given pin-initializer to pin-initialize a `T` inside of a new smart pointer of this + /// type. + /// + /// If `T: !Unpin` it will not be able to move afterwards. + fn pin_init(init: impl PinInit<T>) -> Result<Pin<Self>, AllocError> { + // SAFETY: We delegate to `init` and only change the error type. + let init = unsafe { + pin_init_from_closure(|slot| match init.__pinned_init(slot) { + Ok(()) => Ok(()), + Err(i) => match i {}, + }) + }; + Self::try_pin_init(init) + } + + /// Use the given initializer to in-place initialize a `T`. + fn try_init<E>(init: impl Init<T, E>) -> Result<Self, E> + where + E: From<AllocError>; + + /// Use the given initializer to in-place initialize a `T`. + fn init(init: impl Init<T>) -> Result<Self, AllocError> { + // SAFETY: We delegate to `init` and only change the error type. + let init = unsafe { + init_from_closure(|slot| match init.__init(slot) { + Ok(()) => Ok(()), + Err(i) => match i {}, + }) + }; + Self::try_init(init) + } +} + +#[cfg(feature = "alloc")] +macro_rules! try_new_uninit { + ($type:ident) => { + $type::try_new_uninit()? + }; +} +#[cfg(all(feature = "std", not(feature = "alloc")))] +macro_rules! try_new_uninit { + ($type:ident) => { + $type::new_uninit() + }; +} + +impl<T> InPlaceInit<T> for Box<T> { + #[inline] + fn try_pin_init<E>(init: impl PinInit<T, E>) -> Result<Pin<Self>, E> + where + E: From<AllocError>, + { + try_new_uninit!(Box).write_pin_init(init) + } + + #[inline] + fn try_init<E>(init: impl Init<T, E>) -> Result<Self, E> + where + E: From<AllocError>, + { + try_new_uninit!(Box).write_init(init) + } +} + +impl<T> InPlaceInit<T> for Arc<T> { + #[inline] + fn try_pin_init<E>(init: impl PinInit<T, E>) -> Result<Pin<Self>, E> + where + E: From<AllocError>, + { + let mut this = try_new_uninit!(Arc); + let Some(slot) = Arc::get_mut(&mut this) else { + // SAFETY: the Arc has just been created and has no external references + unsafe { core::hint::unreachable_unchecked() } + }; + let slot = slot.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid and will not be moved, because we pin it later. + unsafe { init.__pinned_init(slot)? }; + // SAFETY: All fields have been initialized and this is the only `Arc` to that data. + Ok(unsafe { Pin::new_unchecked(this.assume_init()) }) + } + + #[inline] + fn try_init<E>(init: impl Init<T, E>) -> Result<Self, E> + where + E: From<AllocError>, + { + let mut this = try_new_uninit!(Arc); + let Some(slot) = Arc::get_mut(&mut this) else { + // SAFETY: the Arc has just been created and has no external references + unsafe { core::hint::unreachable_unchecked() } + }; + let slot = slot.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid. + unsafe { init.__init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { this.assume_init() }) + } +} + +impl<T> InPlaceWrite<T> for Box<MaybeUninit<T>> { + type Initialized = Box<T>; + + fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid. + unsafe { init.__init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }) + } + + fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { + let slot = self.as_mut_ptr(); + // SAFETY: When init errors/panics, slot will get deallocated but not dropped, + // slot is valid and will not be moved, because we pin it later. + unsafe { init.__pinned_init(slot)? }; + // SAFETY: All fields have been initialized. + Ok(unsafe { self.assume_init() }.into()) + } +} diff --git a/rust/pin-init/src/lib.rs b/rust/pin-init/src/lib.rs new file mode 100644 index 000000000000..62e013a5cc20 --- /dev/null +++ b/rust/pin-init/src/lib.rs @@ -0,0 +1,1731 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Library to safely and fallibly initialize pinned `struct`s using in-place constructors. +//! +//! [Pinning][pinning] is Rust's way of ensuring data does not move. +//! +//! It also allows in-place initialization of big `struct`s that would otherwise produce a stack +//! overflow. +//! +//! This library's main use-case is in [Rust-for-Linux]. Although this version can be used +//! standalone. +//! +//! There are cases when you want to in-place initialize a struct. For example when it is very big +//! and moving it from the stack is not an option, because it is bigger than the stack itself. +//! Another reason would be that you need the address of the object to initialize it. This stands +//! in direct conflict with Rust's normal process of first initializing an object and then moving +//! it into it's final memory location. For more information, see +//! <https://rust-for-linux.com/the-safe-pinned-initialization-problem>. +//! +//! This library allows you to do in-place initialization safely. +//! +//! ## Nightly Needed for `alloc` feature +//! +//! This library requires the [`allocator_api` unstable feature] when the `alloc` feature is +//! enabled and thus this feature can only be used with a nightly compiler. When enabling the +//! `alloc` feature, the user will be required to activate `allocator_api` as well. +//! +//! [`allocator_api` unstable feature]: https://doc.rust-lang.org/nightly/unstable-book/library-features/allocator-api.html +//! +//! The feature is enabled by default, thus by default `pin-init` will require a nightly compiler. +//! However, using the crate on stable compilers is possible by disabling `alloc`. In practice this +//! will require the `std` feature, because stable compilers have neither `Box` nor `Arc` in no-std +//! mode. +//! +//! ## Nightly needed for `unsafe-pinned` feature +//! +//! This feature enables the `Wrapper` implementation on the unstable `core::pin::UnsafePinned` type. +//! This requires the [`unsafe_pinned` unstable feature](https://github.com/rust-lang/rust/issues/125735) +//! and therefore a nightly compiler. Note that this feature is not enabled by default. +//! +//! # Overview +//! +//! To initialize a `struct` with an in-place constructor you will need two things: +//! - an in-place constructor, +//! - a memory location that can hold your `struct` (this can be the [stack], an [`Arc<T>`], +//! [`Box<T>`] or any other smart pointer that supports this library). +//! +//! To get an in-place constructor there are generally three options: +//! - directly creating an in-place constructor using the [`pin_init!`] macro, +//! - a custom function/macro returning an in-place constructor provided by someone else, +//! - using the unsafe function [`pin_init_from_closure()`] to manually create an initializer. +//! +//! Aside from pinned initialization, this library also supports in-place construction without +//! pinning, the macros/types/functions are generally named like the pinned variants without the +//! `pin_` prefix. +//! +//! # Examples +//! +//! Throughout the examples we will often make use of the `CMutex` type which can be found in +//! `../examples/mutex.rs`. It is essentially a userland rebuild of the `struct mutex` type from +//! the Linux kernel. It also uses a wait list and a basic spinlock. Importantly the wait list +//! requires it to be pinned to be locked and thus is a prime candidate for using this library. +//! +//! ## Using the [`pin_init!`] macro +//! +//! If you want to use [`PinInit`], then you will have to annotate your `struct` with +//! `#[`[`pin_data`]`]`. It is a macro that uses `#[pin]` as a marker for +//! [structurally pinned fields]. After doing this, you can then create an in-place constructor via +//! [`pin_init!`]. The syntax is almost the same as normal `struct` initializers. The difference is +//! that you need to write `<-` instead of `:` for fields that you want to initialize in-place. +//! +//! ```rust +//! # #![expect(clippy::disallowed_names)] +//! # #![feature(allocator_api)] +//! # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +//! # use core::pin::Pin; +//! use pin_init::{pin_data, pin_init, InPlaceInit}; +//! +//! #[pin_data] +//! struct Foo { +//! #[pin] +//! a: CMutex<usize>, +//! b: u32, +//! } +//! +//! let foo = pin_init!(Foo { +//! a <- CMutex::new(42), +//! b: 24, +//! }); +//! # let _ = Box::pin_init(foo); +//! ``` +//! +//! `foo` now is of the type [`impl PinInit<Foo>`]. We can now use any smart pointer that we like +//! (or just the stack) to actually initialize a `Foo`: +//! +//! ```rust +//! # #![expect(clippy::disallowed_names)] +//! # #![feature(allocator_api)] +//! # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +//! # use core::{alloc::AllocError, pin::Pin}; +//! # use pin_init::*; +//! # +//! # #[pin_data] +//! # struct Foo { +//! # #[pin] +//! # a: CMutex<usize>, +//! # b: u32, +//! # } +//! # +//! # let foo = pin_init!(Foo { +//! # a <- CMutex::new(42), +//! # b: 24, +//! # }); +//! let foo: Result<Pin<Box<Foo>>, AllocError> = Box::pin_init(foo); +//! ``` +//! +//! For more information see the [`pin_init!`] macro. +//! +//! ## Using a custom function/macro that returns an initializer +//! +//! Many types that use this library supply a function/macro that returns an initializer, because +//! the above method only works for types where you can access the fields. +//! +//! ```rust +//! # #![feature(allocator_api)] +//! # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +//! # use pin_init::*; +//! # use std::sync::Arc; +//! # use core::pin::Pin; +//! let mtx: Result<Pin<Arc<CMutex<usize>>>, _> = Arc::pin_init(CMutex::new(42)); +//! ``` +//! +//! To declare an init macro/function you just return an [`impl PinInit<T, E>`]: +//! +//! ```rust +//! # #![feature(allocator_api)] +//! # use pin_init::*; +//! # #[path = "../examples/error.rs"] mod error; use error::Error; +//! # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +//! #[pin_data] +//! struct DriverData { +//! #[pin] +//! status: CMutex<i32>, +//! buffer: Box<[u8; 1_000_000]>, +//! } +//! +//! impl DriverData { +//! fn new() -> impl PinInit<Self, Error> { +//! try_pin_init!(Self { +//! status <- CMutex::new(0), +//! buffer: Box::init(pin_init::init_zeroed())?, +//! }? Error) +//! } +//! } +//! ``` +//! +//! ## Manual creation of an initializer +//! +//! Often when working with primitives the previous approaches are not sufficient. That is where +//! [`pin_init_from_closure()`] comes in. This `unsafe` function allows you to create a +//! [`impl PinInit<T, E>`] directly from a closure. Of course you have to ensure that the closure +//! actually does the initialization in the correct way. Here are the things to look out for +//! (we are calling the parameter to the closure `slot`): +//! - when the closure returns `Ok(())`, then it has completed the initialization successfully, so +//! `slot` now contains a valid bit pattern for the type `T`, +//! - when the closure returns `Err(e)`, then the caller may deallocate the memory at `slot`, so +//! you need to take care to clean up anything if your initialization fails mid-way, +//! - you may assume that `slot` will stay pinned even after the closure returns until `drop` of +//! `slot` gets called. +//! +//! ```rust +//! # #![feature(extern_types)] +//! use pin_init::{pin_data, pinned_drop, PinInit, PinnedDrop, pin_init_from_closure}; +//! use core::{ +//! ptr::addr_of_mut, +//! marker::PhantomPinned, +//! cell::UnsafeCell, +//! pin::Pin, +//! mem::MaybeUninit, +//! }; +//! mod bindings { +//! #[repr(C)] +//! pub struct foo { +//! /* fields from C ... */ +//! } +//! extern "C" { +//! pub fn init_foo(ptr: *mut foo); +//! pub fn destroy_foo(ptr: *mut foo); +//! #[must_use = "you must check the error return code"] +//! pub fn enable_foo(ptr: *mut foo, flags: u32) -> i32; +//! } +//! } +//! +//! /// # Invariants +//! /// +//! /// `foo` is always initialized +//! #[pin_data(PinnedDrop)] +//! pub struct RawFoo { +//! #[pin] +//! _p: PhantomPinned, +//! #[pin] +//! foo: UnsafeCell<MaybeUninit<bindings::foo>>, +//! } +//! +//! impl RawFoo { +//! pub fn new(flags: u32) -> impl PinInit<Self, i32> { +//! // SAFETY: +//! // - when the closure returns `Ok(())`, then it has successfully initialized and +//! // enabled `foo`, +//! // - when it returns `Err(e)`, then it has cleaned up before +//! unsafe { +//! pin_init_from_closure(move |slot: *mut Self| { +//! // `slot` contains uninit memory, avoid creating a reference. +//! let foo = addr_of_mut!((*slot).foo); +//! let foo = UnsafeCell::raw_get(foo).cast::<bindings::foo>(); +//! +//! // Initialize the `foo` +//! bindings::init_foo(foo); +//! +//! // Try to enable it. +//! let err = bindings::enable_foo(foo, flags); +//! if err != 0 { +//! // Enabling has failed, first clean up the foo and then return the error. +//! bindings::destroy_foo(foo); +//! Err(err) +//! } else { +//! // All fields of `RawFoo` have been initialized, since `_p` is a ZST. +//! Ok(()) +//! } +//! }) +//! } +//! } +//! } +//! +//! #[pinned_drop] +//! impl PinnedDrop for RawFoo { +//! fn drop(self: Pin<&mut Self>) { +//! // SAFETY: Since `foo` is initialized, destroying is safe. +//! unsafe { bindings::destroy_foo(self.foo.get().cast::<bindings::foo>()) }; +//! } +//! } +//! ``` +//! +//! For more information on how to use [`pin_init_from_closure()`], take a look at the uses inside +//! the `kernel` crate. The [`sync`] module is a good starting point. +//! +//! [`sync`]: https://rust.docs.kernel.org/kernel/sync/index.html +//! [pinning]: https://doc.rust-lang.org/std/pin/index.html +//! [structurally pinned fields]: +//! https://doc.rust-lang.org/std/pin/index.html#projections-and-structural-pinning +//! [stack]: crate::stack_pin_init +#![cfg_attr( + kernel, + doc = "[`Arc<T>`]: https://rust.docs.kernel.org/kernel/sync/struct.Arc.html" +)] +#![cfg_attr( + kernel, + doc = "[`Box<T>`]: https://rust.docs.kernel.org/kernel/alloc/kbox/struct.Box.html" +)] +#![cfg_attr(not(kernel), doc = "[`Arc<T>`]: alloc::alloc::sync::Arc")] +#![cfg_attr(not(kernel), doc = "[`Box<T>`]: alloc::alloc::boxed::Box")] +//! [`impl PinInit<Foo>`]: crate::PinInit +//! [`impl PinInit<T, E>`]: crate::PinInit +//! [`impl Init<T, E>`]: crate::Init +//! [Rust-for-Linux]: https://rust-for-linux.com/ + +#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] +#![cfg_attr( + all( + any(feature = "alloc", feature = "std"), + not(RUSTC_NEW_UNINIT_IS_STABLE) + ), + feature(new_uninit) +)] +#![forbid(missing_docs, unsafe_op_in_unsafe_fn)] +#![cfg_attr(not(feature = "std"), no_std)] +#![cfg_attr(feature = "alloc", feature(allocator_api))] +#![cfg_attr( + all(feature = "unsafe-pinned", CONFIG_RUSTC_HAS_UNSAFE_PINNED), + feature(unsafe_pinned) +)] + +use core::{ + cell::UnsafeCell, + convert::Infallible, + marker::PhantomData, + mem::MaybeUninit, + num::*, + pin::Pin, + ptr::{self, NonNull}, +}; + +#[doc(hidden)] +pub mod __internal; +#[doc(hidden)] +pub mod macros; + +#[cfg(any(feature = "std", feature = "alloc"))] +mod alloc; +#[cfg(any(feature = "std", feature = "alloc"))] +pub use alloc::InPlaceInit; + +/// Used to specify the pinning information of the fields of a struct. +/// +/// This is somewhat similar in purpose as +/// [pin-project-lite](https://crates.io/crates/pin-project-lite). +/// Place this macro on a struct definition and then `#[pin]` in front of the attributes of each +/// field you want to structurally pin. +/// +/// This macro enables the use of the [`pin_init!`] macro. When pin-initializing a `struct`, +/// then `#[pin]` directs the type of initializer that is required. +/// +/// If your `struct` implements `Drop`, then you need to add `PinnedDrop` as arguments to this +/// macro, and change your `Drop` implementation to `PinnedDrop` annotated with +/// `#[`[`macro@pinned_drop`]`]`, since dropping pinned values requires extra care. +/// +/// # Examples +/// +/// ``` +/// # #![feature(allocator_api)] +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// use pin_init::pin_data; +/// +/// enum Command { +/// /* ... */ +/// } +/// +/// #[pin_data] +/// struct DriverData { +/// #[pin] +/// queue: CMutex<Vec<Command>>, +/// buf: Box<[u8; 1024 * 1024]>, +/// } +/// ``` +/// +/// ``` +/// # #![feature(allocator_api)] +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// # mod bindings { pub struct info; pub unsafe fn destroy_info(_: *mut info) {} } +/// use core::pin::Pin; +/// use pin_init::{pin_data, pinned_drop, PinnedDrop}; +/// +/// enum Command { +/// /* ... */ +/// } +/// +/// #[pin_data(PinnedDrop)] +/// struct DriverData { +/// #[pin] +/// queue: CMutex<Vec<Command>>, +/// buf: Box<[u8; 1024 * 1024]>, +/// raw_info: *mut bindings::info, +/// } +/// +/// #[pinned_drop] +/// impl PinnedDrop for DriverData { +/// fn drop(self: Pin<&mut Self>) { +/// unsafe { bindings::destroy_info(self.raw_info) }; +/// } +/// } +/// ``` +pub use ::pin_init_internal::pin_data; + +/// Used to implement `PinnedDrop` safely. +/// +/// Only works on structs that are annotated via `#[`[`macro@pin_data`]`]`. +/// +/// # Examples +/// +/// ``` +/// # #![feature(allocator_api)] +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// # mod bindings { pub struct info; pub unsafe fn destroy_info(_: *mut info) {} } +/// use core::pin::Pin; +/// use pin_init::{pin_data, pinned_drop, PinnedDrop}; +/// +/// enum Command { +/// /* ... */ +/// } +/// +/// #[pin_data(PinnedDrop)] +/// struct DriverData { +/// #[pin] +/// queue: CMutex<Vec<Command>>, +/// buf: Box<[u8; 1024 * 1024]>, +/// raw_info: *mut bindings::info, +/// } +/// +/// #[pinned_drop] +/// impl PinnedDrop for DriverData { +/// fn drop(self: Pin<&mut Self>) { +/// unsafe { bindings::destroy_info(self.raw_info) }; +/// } +/// } +/// ``` +pub use ::pin_init_internal::pinned_drop; + +/// Derives the [`Zeroable`] trait for the given `struct` or `union`. +/// +/// This can only be used for `struct`s/`union`s where every field implements the [`Zeroable`] +/// trait. +/// +/// # Examples +/// +/// ``` +/// use pin_init::Zeroable; +/// +/// #[derive(Zeroable)] +/// pub struct DriverData { +/// pub(crate) id: i64, +/// buf_ptr: *mut u8, +/// len: usize, +/// } +/// ``` +/// +/// ``` +/// use pin_init::Zeroable; +/// +/// #[derive(Zeroable)] +/// pub union SignCast { +/// signed: i64, +/// unsigned: u64, +/// } +/// ``` +pub use ::pin_init_internal::Zeroable; + +/// Derives the [`Zeroable`] trait for the given `struct` or `union` if all fields implement +/// [`Zeroable`]. +/// +/// Contrary to the derive macro named [`macro@Zeroable`], this one silently fails when a field +/// doesn't implement [`Zeroable`]. +/// +/// # Examples +/// +/// ``` +/// use pin_init::MaybeZeroable; +/// +/// // implmements `Zeroable` +/// #[derive(MaybeZeroable)] +/// pub struct DriverData { +/// pub(crate) id: i64, +/// buf_ptr: *mut u8, +/// len: usize, +/// } +/// +/// // does not implmement `Zeroable` +/// #[derive(MaybeZeroable)] +/// pub struct DriverData2 { +/// pub(crate) id: i64, +/// buf_ptr: *mut u8, +/// len: usize, +/// // this field doesn't implement `Zeroable` +/// other_data: &'static i32, +/// } +/// ``` +pub use ::pin_init_internal::MaybeZeroable; + +/// Initialize and pin a type directly on the stack. +/// +/// # Examples +/// +/// ```rust +/// # #![expect(clippy::disallowed_names)] +/// # #![feature(allocator_api)] +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// # use pin_init::*; +/// # use core::pin::Pin; +/// #[pin_data] +/// struct Foo { +/// #[pin] +/// a: CMutex<usize>, +/// b: Bar, +/// } +/// +/// #[pin_data] +/// struct Bar { +/// x: u32, +/// } +/// +/// stack_pin_init!(let foo = pin_init!(Foo { +/// a <- CMutex::new(42), +/// b: Bar { +/// x: 64, +/// }, +/// })); +/// let foo: Pin<&mut Foo> = foo; +/// println!("a: {}", &*foo.a.lock()); +/// ``` +/// +/// # Syntax +/// +/// A normal `let` binding with optional type annotation. The expression is expected to implement +/// [`PinInit`]/[`Init`] with the error type [`Infallible`]. If you want to use a different error +/// type, then use [`stack_try_pin_init!`]. +#[macro_export] +macro_rules! stack_pin_init { + (let $var:ident $(: $t:ty)? = $val:expr) => { + let val = $val; + let mut $var = ::core::pin::pin!($crate::__internal::StackInit$(::<$t>)?::uninit()); + let mut $var = match $crate::__internal::StackInit::init($var, val) { + Ok(res) => res, + Err(x) => { + let x: ::core::convert::Infallible = x; + match x {} + } + }; + }; +} + +/// Initialize and pin a type directly on the stack. +/// +/// # Examples +/// +/// ```rust +/// # #![expect(clippy::disallowed_names)] +/// # #![feature(allocator_api)] +/// # #[path = "../examples/error.rs"] mod error; use error::Error; +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// # use pin_init::*; +/// #[pin_data] +/// struct Foo { +/// #[pin] +/// a: CMutex<usize>, +/// b: Box<Bar>, +/// } +/// +/// struct Bar { +/// x: u32, +/// } +/// +/// stack_try_pin_init!(let foo: Foo = try_pin_init!(Foo { +/// a <- CMutex::new(42), +/// b: Box::try_new(Bar { +/// x: 64, +/// })?, +/// }? Error)); +/// let foo = foo.unwrap(); +/// println!("a: {}", &*foo.a.lock()); +/// ``` +/// +/// ```rust +/// # #![expect(clippy::disallowed_names)] +/// # #![feature(allocator_api)] +/// # #[path = "../examples/error.rs"] mod error; use error::Error; +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// # use pin_init::*; +/// #[pin_data] +/// struct Foo { +/// #[pin] +/// a: CMutex<usize>, +/// b: Box<Bar>, +/// } +/// +/// struct Bar { +/// x: u32, +/// } +/// +/// stack_try_pin_init!(let foo: Foo =? try_pin_init!(Foo { +/// a <- CMutex::new(42), +/// b: Box::try_new(Bar { +/// x: 64, +/// })?, +/// }? Error)); +/// println!("a: {}", &*foo.a.lock()); +/// # Ok::<_, Error>(()) +/// ``` +/// +/// # Syntax +/// +/// A normal `let` binding with optional type annotation. The expression is expected to implement +/// [`PinInit`]/[`Init`]. This macro assigns a result to the given variable, adding a `?` after the +/// `=` will propagate this error. +#[macro_export] +macro_rules! stack_try_pin_init { + (let $var:ident $(: $t:ty)? = $val:expr) => { + let val = $val; + let mut $var = ::core::pin::pin!($crate::__internal::StackInit$(::<$t>)?::uninit()); + let mut $var = $crate::__internal::StackInit::init($var, val); + }; + (let $var:ident $(: $t:ty)? =? $val:expr) => { + let val = $val; + let mut $var = ::core::pin::pin!($crate::__internal::StackInit$(::<$t>)?::uninit()); + let mut $var = $crate::__internal::StackInit::init($var, val)?; + }; +} + +/// Construct an in-place, pinned initializer for `struct`s. +/// +/// This macro defaults the error to [`Infallible`]. If you need a different error, then use +/// [`try_pin_init!`]. +/// +/// The syntax is almost identical to that of a normal `struct` initializer: +/// +/// ```rust +/// # use pin_init::*; +/// # use core::pin::Pin; +/// #[pin_data] +/// struct Foo { +/// a: usize, +/// b: Bar, +/// } +/// +/// #[pin_data] +/// struct Bar { +/// x: u32, +/// } +/// +/// # fn demo() -> impl PinInit<Foo> { +/// let a = 42; +/// +/// let initializer = pin_init!(Foo { +/// a, +/// b: Bar { +/// x: 64, +/// }, +/// }); +/// # initializer } +/// # Box::pin_init(demo()).unwrap(); +/// ``` +/// +/// Arbitrary Rust expressions can be used to set the value of a variable. +/// +/// The fields are initialized in the order that they appear in the initializer. So it is possible +/// to read already initialized fields using raw pointers. +/// +/// IMPORTANT: You are not allowed to create references to fields of the struct inside of the +/// initializer. +/// +/// # Init-functions +/// +/// When working with this library it is often desired to let others construct your types without +/// giving access to all fields. This is where you would normally write a plain function `new` that +/// would return a new instance of your type. With this library that is also possible. However, +/// there are a few extra things to keep in mind. +/// +/// To create an initializer function, simply declare it like this: +/// +/// ```rust +/// # use pin_init::*; +/// # use core::pin::Pin; +/// # #[pin_data] +/// # struct Foo { +/// # a: usize, +/// # b: Bar, +/// # } +/// # #[pin_data] +/// # struct Bar { +/// # x: u32, +/// # } +/// impl Foo { +/// fn new() -> impl PinInit<Self> { +/// pin_init!(Self { +/// a: 42, +/// b: Bar { +/// x: 64, +/// }, +/// }) +/// } +/// } +/// ``` +/// +/// Users of `Foo` can now create it like this: +/// +/// ```rust +/// # #![expect(clippy::disallowed_names)] +/// # use pin_init::*; +/// # use core::pin::Pin; +/// # #[pin_data] +/// # struct Foo { +/// # a: usize, +/// # b: Bar, +/// # } +/// # #[pin_data] +/// # struct Bar { +/// # x: u32, +/// # } +/// # impl Foo { +/// # fn new() -> impl PinInit<Self> { +/// # pin_init!(Self { +/// # a: 42, +/// # b: Bar { +/// # x: 64, +/// # }, +/// # }) +/// # } +/// # } +/// let foo = Box::pin_init(Foo::new()); +/// ``` +/// +/// They can also easily embed it into their own `struct`s: +/// +/// ```rust +/// # use pin_init::*; +/// # use core::pin::Pin; +/// # #[pin_data] +/// # struct Foo { +/// # a: usize, +/// # b: Bar, +/// # } +/// # #[pin_data] +/// # struct Bar { +/// # x: u32, +/// # } +/// # impl Foo { +/// # fn new() -> impl PinInit<Self> { +/// # pin_init!(Self { +/// # a: 42, +/// # b: Bar { +/// # x: 64, +/// # }, +/// # }) +/// # } +/// # } +/// #[pin_data] +/// struct FooContainer { +/// #[pin] +/// foo1: Foo, +/// #[pin] +/// foo2: Foo, +/// other: u32, +/// } +/// +/// impl FooContainer { +/// fn new(other: u32) -> impl PinInit<Self> { +/// pin_init!(Self { +/// foo1 <- Foo::new(), +/// foo2 <- Foo::new(), +/// other, +/// }) +/// } +/// } +/// ``` +/// +/// Here we see that when using `pin_init!` with `PinInit`, one needs to write `<-` instead of `:`. +/// This signifies that the given field is initialized in-place. As with `struct` initializers, just +/// writing the field (in this case `other`) without `:` or `<-` means `other: other,`. +/// +/// # Syntax +/// +/// As already mentioned in the examples above, inside of `pin_init!` a `struct` initializer with +/// the following modifications is expected: +/// - Fields that you want to initialize in-place have to use `<-` instead of `:`. +/// - In front of the initializer you can write `&this in` to have access to a [`NonNull<Self>`] +/// pointer named `this` inside of the initializer. +/// - Using struct update syntax one can place `..Zeroable::init_zeroed()` at the very end of the +/// struct, this initializes every field with 0 and then runs all initializers specified in the +/// body. This can only be done if [`Zeroable`] is implemented for the struct. +/// +/// For instance: +/// +/// ```rust +/// # use pin_init::*; +/// # use core::{ptr::addr_of_mut, marker::PhantomPinned}; +/// #[pin_data] +/// #[derive(Zeroable)] +/// struct Buf { +/// // `ptr` points into `buf`. +/// ptr: *mut u8, +/// buf: [u8; 64], +/// #[pin] +/// pin: PhantomPinned, +/// } +/// +/// let init = pin_init!(&this in Buf { +/// buf: [0; 64], +/// // SAFETY: TODO. +/// ptr: unsafe { addr_of_mut!((*this.as_ptr()).buf).cast() }, +/// pin: PhantomPinned, +/// }); +/// let init = pin_init!(Buf { +/// buf: [1; 64], +/// ..Zeroable::init_zeroed() +/// }); +/// ``` +/// +/// [`NonNull<Self>`]: core::ptr::NonNull +// For a detailed example of how this macro works, see the module documentation of the hidden +// module `macros` inside of `macros.rs`. +#[macro_export] +macro_rules! pin_init { + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }) => { + $crate::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? ::core::convert::Infallible) + }; +} + +/// Construct an in-place, fallible pinned initializer for `struct`s. +/// +/// If the initialization can complete without error (or [`Infallible`]), then use [`pin_init!`]. +/// +/// You can use the `?` operator or use `return Err(err)` inside the initializer to stop +/// initialization and return the error. +/// +/// IMPORTANT: if you have `unsafe` code inside of the initializer you have to ensure that when +/// initialization fails, the memory can be safely deallocated without any further modifications. +/// +/// The syntax is identical to [`pin_init!`] with the following exception: you must append `? $type` +/// after the `struct` initializer to specify the error type you want to use. +/// +/// # Examples +/// +/// ```rust +/// # #![feature(allocator_api)] +/// # #[path = "../examples/error.rs"] mod error; use error::Error; +/// use pin_init::{pin_data, try_pin_init, PinInit, InPlaceInit, init_zeroed}; +/// +/// #[pin_data] +/// struct BigBuf { +/// big: Box<[u8; 1024 * 1024 * 1024]>, +/// small: [u8; 1024 * 1024], +/// ptr: *mut u8, +/// } +/// +/// impl BigBuf { +/// fn new() -> impl PinInit<Self, Error> { +/// try_pin_init!(Self { +/// big: Box::init(init_zeroed())?, +/// small: [0; 1024 * 1024], +/// ptr: core::ptr::null_mut(), +/// }? Error) +/// } +/// } +/// # let _ = Box::pin_init(BigBuf::new()); +/// ``` +// For a detailed example of how this macro works, see the module documentation of the hidden +// module `macros` inside of `macros.rs`. +#[macro_export] +macro_rules! try_pin_init { + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }? $err:ty) => { + $crate::__init_internal!( + @this($($this)?), + @typ($t $(::<$($generics),*>)? ), + @fields($($fields)*), + @error($err), + @data(PinData, use_data), + @has_data(HasPinData, __pin_data), + @construct_closure(pin_init_from_closure), + @munch_fields($($fields)*), + ) + } +} + +/// Construct an in-place initializer for `struct`s. +/// +/// This macro defaults the error to [`Infallible`]. If you need a different error, then use +/// [`try_init!`]. +/// +/// The syntax is identical to [`pin_init!`] and its safety caveats also apply: +/// - `unsafe` code must guarantee either full initialization or return an error and allow +/// deallocation of the memory. +/// - the fields are initialized in the order given in the initializer. +/// - no references to fields are allowed to be created inside of the initializer. +/// +/// This initializer is for initializing data in-place that might later be moved. If you want to +/// pin-initialize, use [`pin_init!`]. +/// +/// # Examples +/// +/// ```rust +/// # #![feature(allocator_api)] +/// # #[path = "../examples/error.rs"] mod error; use error::Error; +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// # use pin_init::InPlaceInit; +/// use pin_init::{init, Init, init_zeroed}; +/// +/// struct BigBuf { +/// small: [u8; 1024 * 1024], +/// } +/// +/// impl BigBuf { +/// fn new() -> impl Init<Self> { +/// init!(Self { +/// small <- init_zeroed(), +/// }) +/// } +/// } +/// # let _ = Box::init(BigBuf::new()); +/// ``` +// For a detailed example of how this macro works, see the module documentation of the hidden +// module `macros` inside of `macros.rs`. +#[macro_export] +macro_rules! init { + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }) => { + $crate::try_init!($(&$this in)? $t $(::<$($generics),*>)? { + $($fields)* + }? ::core::convert::Infallible) + } +} + +/// Construct an in-place fallible initializer for `struct`s. +/// +/// If the initialization can complete without error (or [`Infallible`]), then use +/// [`init!`]. +/// +/// The syntax is identical to [`try_pin_init!`]. You need to specify a custom error +/// via `? $type` after the `struct` initializer. +/// The safety caveats from [`try_pin_init!`] also apply: +/// - `unsafe` code must guarantee either full initialization or return an error and allow +/// deallocation of the memory. +/// - the fields are initialized in the order given in the initializer. +/// - no references to fields are allowed to be created inside of the initializer. +/// +/// # Examples +/// +/// ```rust +/// # #![feature(allocator_api)] +/// # use core::alloc::AllocError; +/// # use pin_init::InPlaceInit; +/// use pin_init::{try_init, Init, init_zeroed}; +/// +/// struct BigBuf { +/// big: Box<[u8; 1024 * 1024 * 1024]>, +/// small: [u8; 1024 * 1024], +/// } +/// +/// impl BigBuf { +/// fn new() -> impl Init<Self, AllocError> { +/// try_init!(Self { +/// big: Box::init(init_zeroed())?, +/// small: [0; 1024 * 1024], +/// }? AllocError) +/// } +/// } +/// # let _ = Box::init(BigBuf::new()); +/// ``` +// For a detailed example of how this macro works, see the module documentation of the hidden +// module `macros` inside of `macros.rs`. +#[macro_export] +macro_rules! try_init { + ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { + $($fields:tt)* + }? $err:ty) => { + $crate::__init_internal!( + @this($($this)?), + @typ($t $(::<$($generics),*>)?), + @fields($($fields)*), + @error($err), + @data(InitData, /*no use_data*/), + @has_data(HasInitData, __init_data), + @construct_closure(init_from_closure), + @munch_fields($($fields)*), + ) + } +} + +/// Asserts that a field on a struct using `#[pin_data]` is marked with `#[pin]` ie. that it is +/// structurally pinned. +/// +/// # Examples +/// +/// This will succeed: +/// ``` +/// use pin_init::{pin_data, assert_pinned}; +/// +/// #[pin_data] +/// struct MyStruct { +/// #[pin] +/// some_field: u64, +/// } +/// +/// assert_pinned!(MyStruct, some_field, u64); +/// ``` +/// +/// This will fail: +/// ```compile_fail +/// use pin_init::{pin_data, assert_pinned}; +/// +/// #[pin_data] +/// struct MyStruct { +/// some_field: u64, +/// } +/// +/// assert_pinned!(MyStruct, some_field, u64); +/// ``` +/// +/// Some uses of the macro may trigger the `can't use generic parameters from outer item` error. To +/// work around this, you may pass the `inline` parameter to the macro. The `inline` parameter can +/// only be used when the macro is invoked from a function body. +/// ``` +/// # use core::pin::Pin; +/// use pin_init::{pin_data, assert_pinned}; +/// +/// #[pin_data] +/// struct Foo<T> { +/// #[pin] +/// elem: T, +/// } +/// +/// impl<T> Foo<T> { +/// fn project(self: Pin<&mut Self>) -> Pin<&mut T> { +/// assert_pinned!(Foo<T>, elem, T, inline); +/// +/// // SAFETY: The field is structurally pinned. +/// unsafe { self.map_unchecked_mut(|me| &mut me.elem) } +/// } +/// } +/// ``` +#[macro_export] +macro_rules! assert_pinned { + ($ty:ty, $field:ident, $field_ty:ty, inline) => { + let _ = move |ptr: *mut $field_ty| { + // SAFETY: This code is unreachable. + let data = unsafe { <$ty as $crate::__internal::HasPinData>::__pin_data() }; + let init = $crate::__internal::AlwaysFail::<$field_ty>::new(); + // SAFETY: This code is unreachable. + unsafe { data.$field(ptr, init) }.ok(); + }; + }; + + ($ty:ty, $field:ident, $field_ty:ty) => { + const _: () = { + $crate::assert_pinned!($ty, $field, $field_ty, inline); + }; + }; +} + +/// A pin-initializer for the type `T`. +/// +/// To use this initializer, you will need a suitable memory location that can hold a `T`. This can +/// be [`Box<T>`], [`Arc<T>`] or even the stack (see [`stack_pin_init!`]). +/// +/// Also see the [module description](self). +/// +/// # Safety +/// +/// When implementing this trait you will need to take great care. Also there are probably very few +/// cases where a manual implementation is necessary. Use [`pin_init_from_closure`] where possible. +/// +/// The [`PinInit::__pinned_init`] function: +/// - returns `Ok(())` if it initialized every field of `slot`, +/// - returns `Err(err)` if it encountered an error and then cleaned `slot`, this means: +/// - `slot` can be deallocated without UB occurring, +/// - `slot` does not need to be dropped, +/// - `slot` is not partially initialized. +/// - while constructing the `T` at `slot` it upholds the pinning invariants of `T`. +/// +#[cfg_attr( + kernel, + doc = "[`Arc<T>`]: https://rust.docs.kernel.org/kernel/sync/struct.Arc.html" +)] +#[cfg_attr( + kernel, + doc = "[`Box<T>`]: https://rust.docs.kernel.org/kernel/alloc/kbox/struct.Box.html" +)] +#[cfg_attr(not(kernel), doc = "[`Arc<T>`]: alloc::alloc::sync::Arc")] +#[cfg_attr(not(kernel), doc = "[`Box<T>`]: alloc::alloc::boxed::Box")] +#[must_use = "An initializer must be used in order to create its value."] +pub unsafe trait PinInit<T: ?Sized, E = Infallible>: Sized { + /// Initializes `slot`. + /// + /// # Safety + /// + /// - `slot` is a valid pointer to uninitialized memory. + /// - the caller does not touch `slot` when `Err` is returned, they are only permitted to + /// deallocate. + /// - `slot` will not move until it is dropped, i.e. it will be pinned. + unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E>; + + /// First initializes the value using `self` then calls the function `f` with the initialized + /// value. + /// + /// If `f` returns an error the value is dropped and the initializer will forward the error. + /// + /// # Examples + /// + /// ```rust + /// # #![feature(allocator_api)] + /// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; + /// # use pin_init::*; + /// let mtx_init = CMutex::new(42); + /// // Make the initializer print the value. + /// let mtx_init = mtx_init.pin_chain(|mtx| { + /// println!("{:?}", mtx.get_data_mut()); + /// Ok(()) + /// }); + /// ``` + fn pin_chain<F>(self, f: F) -> ChainPinInit<Self, F, T, E> + where + F: FnOnce(Pin<&mut T>) -> Result<(), E>, + { + ChainPinInit(self, f, PhantomData) + } +} + +/// An initializer returned by [`PinInit::pin_chain`]. +pub struct ChainPinInit<I, F, T: ?Sized, E>(I, F, __internal::Invariant<(E, T)>); + +// SAFETY: The `__pinned_init` function is implemented such that it +// - returns `Ok(())` on successful initialization, +// - returns `Err(err)` on error and in this case `slot` will be dropped. +// - considers `slot` pinned. +unsafe impl<T: ?Sized, E, I, F> PinInit<T, E> for ChainPinInit<I, F, T, E> +where + I: PinInit<T, E>, + F: FnOnce(Pin<&mut T>) -> Result<(), E>, +{ + unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> { + // SAFETY: All requirements fulfilled since this function is `__pinned_init`. + unsafe { self.0.__pinned_init(slot)? }; + // SAFETY: The above call initialized `slot` and we still have unique access. + let val = unsafe { &mut *slot }; + // SAFETY: `slot` is considered pinned. + let val = unsafe { Pin::new_unchecked(val) }; + // SAFETY: `slot` was initialized above. + (self.1)(val).inspect_err(|_| unsafe { core::ptr::drop_in_place(slot) }) + } +} + +/// An initializer for `T`. +/// +/// To use this initializer, you will need a suitable memory location that can hold a `T`. This can +/// be [`Box<T>`], [`Arc<T>`] or even the stack (see [`stack_pin_init!`]). Because +/// [`PinInit<T, E>`] is a super trait, you can use every function that takes it as well. +/// +/// Also see the [module description](self). +/// +/// # Safety +/// +/// When implementing this trait you will need to take great care. Also there are probably very few +/// cases where a manual implementation is necessary. Use [`init_from_closure`] where possible. +/// +/// The [`Init::__init`] function: +/// - returns `Ok(())` if it initialized every field of `slot`, +/// - returns `Err(err)` if it encountered an error and then cleaned `slot`, this means: +/// - `slot` can be deallocated without UB occurring, +/// - `slot` does not need to be dropped, +/// - `slot` is not partially initialized. +/// - while constructing the `T` at `slot` it upholds the pinning invariants of `T`. +/// +/// The `__pinned_init` function from the supertrait [`PinInit`] needs to execute the exact same +/// code as `__init`. +/// +/// Contrary to its supertype [`PinInit<T, E>`] the caller is allowed to +/// move the pointee after initialization. +/// +#[cfg_attr( + kernel, + doc = "[`Arc<T>`]: https://rust.docs.kernel.org/kernel/sync/struct.Arc.html" +)] +#[cfg_attr( + kernel, + doc = "[`Box<T>`]: https://rust.docs.kernel.org/kernel/alloc/kbox/struct.Box.html" +)] +#[cfg_attr(not(kernel), doc = "[`Arc<T>`]: alloc::alloc::sync::Arc")] +#[cfg_attr(not(kernel), doc = "[`Box<T>`]: alloc::alloc::boxed::Box")] +#[must_use = "An initializer must be used in order to create its value."] +pub unsafe trait Init<T: ?Sized, E = Infallible>: PinInit<T, E> { + /// Initializes `slot`. + /// + /// # Safety + /// + /// - `slot` is a valid pointer to uninitialized memory. + /// - the caller does not touch `slot` when `Err` is returned, they are only permitted to + /// deallocate. + unsafe fn __init(self, slot: *mut T) -> Result<(), E>; + + /// First initializes the value using `self` then calls the function `f` with the initialized + /// value. + /// + /// If `f` returns an error the value is dropped and the initializer will forward the error. + /// + /// # Examples + /// + /// ```rust + /// # #![expect(clippy::disallowed_names)] + /// use pin_init::{init, init_zeroed, Init}; + /// + /// struct Foo { + /// buf: [u8; 1_000_000], + /// } + /// + /// impl Foo { + /// fn setup(&mut self) { + /// println!("Setting up foo"); + /// } + /// } + /// + /// let foo = init!(Foo { + /// buf <- init_zeroed() + /// }).chain(|foo| { + /// foo.setup(); + /// Ok(()) + /// }); + /// ``` + fn chain<F>(self, f: F) -> ChainInit<Self, F, T, E> + where + F: FnOnce(&mut T) -> Result<(), E>, + { + ChainInit(self, f, PhantomData) + } +} + +/// An initializer returned by [`Init::chain`]. +pub struct ChainInit<I, F, T: ?Sized, E>(I, F, __internal::Invariant<(E, T)>); + +// SAFETY: The `__init` function is implemented such that it +// - returns `Ok(())` on successful initialization, +// - returns `Err(err)` on error and in this case `slot` will be dropped. +unsafe impl<T: ?Sized, E, I, F> Init<T, E> for ChainInit<I, F, T, E> +where + I: Init<T, E>, + F: FnOnce(&mut T) -> Result<(), E>, +{ + unsafe fn __init(self, slot: *mut T) -> Result<(), E> { + // SAFETY: All requirements fulfilled since this function is `__init`. + unsafe { self.0.__pinned_init(slot)? }; + // SAFETY: The above call initialized `slot` and we still have unique access. + (self.1)(unsafe { &mut *slot }).inspect_err(|_| + // SAFETY: `slot` was initialized above. + unsafe { core::ptr::drop_in_place(slot) }) + } +} + +// SAFETY: `__pinned_init` behaves exactly the same as `__init`. +unsafe impl<T: ?Sized, E, I, F> PinInit<T, E> for ChainInit<I, F, T, E> +where + I: Init<T, E>, + F: FnOnce(&mut T) -> Result<(), E>, +{ + unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> { + // SAFETY: `__init` has less strict requirements compared to `__pinned_init`. + unsafe { self.__init(slot) } + } +} + +/// Creates a new [`PinInit<T, E>`] from the given closure. +/// +/// # Safety +/// +/// The closure: +/// - returns `Ok(())` if it initialized every field of `slot`, +/// - returns `Err(err)` if it encountered an error and then cleaned `slot`, this means: +/// - `slot` can be deallocated without UB occurring, +/// - `slot` does not need to be dropped, +/// - `slot` is not partially initialized. +/// - may assume that the `slot` does not move if `T: !Unpin`, +/// - while constructing the `T` at `slot` it upholds the pinning invariants of `T`. +#[inline] +pub const unsafe fn pin_init_from_closure<T: ?Sized, E>( + f: impl FnOnce(*mut T) -> Result<(), E>, +) -> impl PinInit<T, E> { + __internal::InitClosure(f, PhantomData) +} + +/// Creates a new [`Init<T, E>`] from the given closure. +/// +/// # Safety +/// +/// The closure: +/// - returns `Ok(())` if it initialized every field of `slot`, +/// - returns `Err(err)` if it encountered an error and then cleaned `slot`, this means: +/// - `slot` can be deallocated without UB occurring, +/// - `slot` does not need to be dropped, +/// - `slot` is not partially initialized. +/// - the `slot` may move after initialization. +/// - while constructing the `T` at `slot` it upholds the pinning invariants of `T`. +#[inline] +pub const unsafe fn init_from_closure<T: ?Sized, E>( + f: impl FnOnce(*mut T) -> Result<(), E>, +) -> impl Init<T, E> { + __internal::InitClosure(f, PhantomData) +} + +/// Changes the to be initialized type. +/// +/// # Safety +/// +/// - `*mut U` must be castable to `*mut T` and any value of type `T` written through such a +/// pointer must result in a valid `U`. +#[expect(clippy::let_and_return)] +pub const unsafe fn cast_pin_init<T, U, E>(init: impl PinInit<T, E>) -> impl PinInit<U, E> { + // SAFETY: initialization delegated to a valid initializer. Cast is valid by function safety + // requirements. + let res = unsafe { pin_init_from_closure(|ptr: *mut U| init.__pinned_init(ptr.cast::<T>())) }; + // FIXME: remove the let statement once the nightly-MSRV allows it (1.78 otherwise encounters a + // cycle when computing the type returned by this function) + res +} + +/// Changes the to be initialized type. +/// +/// # Safety +/// +/// - `*mut U` must be castable to `*mut T` and any value of type `T` written through such a +/// pointer must result in a valid `U`. +#[expect(clippy::let_and_return)] +pub const unsafe fn cast_init<T, U, E>(init: impl Init<T, E>) -> impl Init<U, E> { + // SAFETY: initialization delegated to a valid initializer. Cast is valid by function safety + // requirements. + let res = unsafe { init_from_closure(|ptr: *mut U| init.__init(ptr.cast::<T>())) }; + // FIXME: remove the let statement once the nightly-MSRV allows it (1.78 otherwise encounters a + // cycle when computing the type returned by this function) + res +} + +/// An initializer that leaves the memory uninitialized. +/// +/// The initializer is a no-op. The `slot` memory is not changed. +#[inline] +pub fn uninit<T, E>() -> impl Init<MaybeUninit<T>, E> { + // SAFETY: The memory is allowed to be uninitialized. + unsafe { init_from_closure(|_| Ok(())) } +} + +/// Initializes an array by initializing each element via the provided initializer. +/// +/// # Examples +/// +/// ```rust +/// # use pin_init::*; +/// use pin_init::init_array_from_fn; +/// let array: Box<[usize; 1_000]> = Box::init(init_array_from_fn(|i| i)).unwrap(); +/// assert_eq!(array.len(), 1_000); +/// ``` +pub fn init_array_from_fn<I, const N: usize, T, E>( + mut make_init: impl FnMut(usize) -> I, +) -> impl Init<[T; N], E> +where + I: Init<T, E>, +{ + let init = move |slot: *mut [T; N]| { + let slot = slot.cast::<T>(); + for i in 0..N { + let init = make_init(i); + // SAFETY: Since 0 <= `i` < N, it is still in bounds of `[T; N]`. + let ptr = unsafe { slot.add(i) }; + // SAFETY: The pointer is derived from `slot` and thus satisfies the `__init` + // requirements. + if let Err(e) = unsafe { init.__init(ptr) } { + // SAFETY: The loop has initialized the elements `slot[0..i]` and since we return + // `Err` below, `slot` will be considered uninitialized memory. + unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(slot, i)) }; + return Err(e); + } + } + Ok(()) + }; + // SAFETY: The initializer above initializes every element of the array. On failure it drops + // any initialized elements and returns `Err`. + unsafe { init_from_closure(init) } +} + +/// Initializes an array by initializing each element via the provided initializer. +/// +/// # Examples +/// +/// ```rust +/// # #![feature(allocator_api)] +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// # use pin_init::*; +/// # use core::pin::Pin; +/// use pin_init::pin_init_array_from_fn; +/// use std::sync::Arc; +/// let array: Pin<Arc<[CMutex<usize>; 1_000]>> = +/// Arc::pin_init(pin_init_array_from_fn(|i| CMutex::new(i))).unwrap(); +/// assert_eq!(array.len(), 1_000); +/// ``` +pub fn pin_init_array_from_fn<I, const N: usize, T, E>( + mut make_init: impl FnMut(usize) -> I, +) -> impl PinInit<[T; N], E> +where + I: PinInit<T, E>, +{ + let init = move |slot: *mut [T; N]| { + let slot = slot.cast::<T>(); + for i in 0..N { + let init = make_init(i); + // SAFETY: Since 0 <= `i` < N, it is still in bounds of `[T; N]`. + let ptr = unsafe { slot.add(i) }; + // SAFETY: The pointer is derived from `slot` and thus satisfies the `__init` + // requirements. + if let Err(e) = unsafe { init.__pinned_init(ptr) } { + // SAFETY: The loop has initialized the elements `slot[0..i]` and since we return + // `Err` below, `slot` will be considered uninitialized memory. + unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(slot, i)) }; + return Err(e); + } + } + Ok(()) + }; + // SAFETY: The initializer above initializes every element of the array. On failure it drops + // any initialized elements and returns `Err`. + unsafe { pin_init_from_closure(init) } +} + +// SAFETY: the `__init` function always returns `Ok(())` and initializes every field of `slot`. +unsafe impl<T> Init<T> for T { + unsafe fn __init(self, slot: *mut T) -> Result<(), Infallible> { + // SAFETY: `slot` is valid for writes by the safety requirements of this function. + unsafe { slot.write(self) }; + Ok(()) + } +} + +// SAFETY: the `__pinned_init` function always returns `Ok(())` and initializes every field of +// `slot`. Additionally, all pinning invariants of `T` are upheld. +unsafe impl<T> PinInit<T> for T { + unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), Infallible> { + // SAFETY: `slot` is valid for writes by the safety requirements of this function. + unsafe { slot.write(self) }; + Ok(()) + } +} + +// SAFETY: when the `__init` function returns with +// - `Ok(())`, `slot` was initialized and all pinned invariants of `T` are upheld. +// - `Err(err)`, slot was not written to. +unsafe impl<T, E> Init<T, E> for Result<T, E> { + unsafe fn __init(self, slot: *mut T) -> Result<(), E> { + // SAFETY: `slot` is valid for writes by the safety requirements of this function. + unsafe { slot.write(self?) }; + Ok(()) + } +} + +// SAFETY: when the `__pinned_init` function returns with +// - `Ok(())`, `slot` was initialized and all pinned invariants of `T` are upheld. +// - `Err(err)`, slot was not written to. +unsafe impl<T, E> PinInit<T, E> for Result<T, E> { + unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> { + // SAFETY: `slot` is valid for writes by the safety requirements of this function. + unsafe { slot.write(self?) }; + Ok(()) + } +} + +/// Smart pointer containing uninitialized memory and that can write a value. +pub trait InPlaceWrite<T> { + /// The type `Self` turns into when the contents are initialized. + type Initialized; + + /// Use the given initializer to write a value into `self`. + /// + /// Does not drop the current value and considers it as uninitialized memory. + fn write_init<E>(self, init: impl Init<T, E>) -> Result<Self::Initialized, E>; + + /// Use the given pin-initializer to write a value into `self`. + /// + /// Does not drop the current value and considers it as uninitialized memory. + fn write_pin_init<E>(self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E>; +} + +/// Trait facilitating pinned destruction. +/// +/// Use [`pinned_drop`] to implement this trait safely: +/// +/// ```rust +/// # #![feature(allocator_api)] +/// # #[path = "../examples/mutex.rs"] mod mutex; use mutex::*; +/// # use pin_init::*; +/// use core::pin::Pin; +/// #[pin_data(PinnedDrop)] +/// struct Foo { +/// #[pin] +/// mtx: CMutex<usize>, +/// } +/// +/// #[pinned_drop] +/// impl PinnedDrop for Foo { +/// fn drop(self: Pin<&mut Self>) { +/// println!("Foo is being dropped!"); +/// } +/// } +/// ``` +/// +/// # Safety +/// +/// This trait must be implemented via the [`pinned_drop`] proc-macro attribute on the impl. +pub unsafe trait PinnedDrop: __internal::HasPinData { + /// Executes the pinned destructor of this type. + /// + /// While this function is marked safe, it is actually unsafe to call it manually. For this + /// reason it takes an additional parameter. This type can only be constructed by `unsafe` code + /// and thus prevents this function from being called where it should not. + /// + /// This extra parameter will be generated by the `#[pinned_drop]` proc-macro attribute + /// automatically. + fn drop(self: Pin<&mut Self>, only_call_from_drop: __internal::OnlyCallFromDrop); +} + +/// Marker trait for types that can be initialized by writing just zeroes. +/// +/// # Safety +/// +/// The bit pattern consisting of only zeroes is a valid bit pattern for this type. In other words, +/// this is not UB: +/// +/// ```rust,ignore +/// let val: Self = unsafe { core::mem::zeroed() }; +/// ``` +pub unsafe trait Zeroable { + /// Create a new zeroed `Self`. + /// + /// The returned initializer will write `0x00` to every byte of the given `slot`. + #[inline] + fn init_zeroed() -> impl Init<Self> + where + Self: Sized, + { + init_zeroed() + } + + /// Create a `Self` consisting of all zeroes. + /// + /// Whenever a type implements [`Zeroable`], this function should be preferred over + /// [`core::mem::zeroed()`] or using `MaybeUninit<T>::zeroed().assume_init()`. + /// + /// # Examples + /// + /// ``` + /// use pin_init::{Zeroable, zeroed}; + /// + /// #[derive(Zeroable)] + /// struct Point { + /// x: u32, + /// y: u32, + /// } + /// + /// let point: Point = zeroed(); + /// assert_eq!(point.x, 0); + /// assert_eq!(point.y, 0); + /// ``` + fn zeroed() -> Self + where + Self: Sized, + { + zeroed() + } +} + +/// Marker trait for types that allow `Option<Self>` to be set to all zeroes in order to write +/// `None` to that location. +/// +/// # Safety +/// +/// The implementer needs to ensure that `unsafe impl Zeroable for Option<Self> {}` is sound. +pub unsafe trait ZeroableOption {} + +// SAFETY: by the safety requirement of `ZeroableOption`, this is valid. +unsafe impl<T: ZeroableOption> Zeroable for Option<T> {} + +// SAFETY: `Option<&T>` is part of the option layout optimization guarantee: +// <https://doc.rust-lang.org/stable/std/option/index.html#representation>. +unsafe impl<T> ZeroableOption for &T {} +// SAFETY: `Option<&mut T>` is part of the option layout optimization guarantee: +// <https://doc.rust-lang.org/stable/std/option/index.html#representation>. +unsafe impl<T> ZeroableOption for &mut T {} +// SAFETY: `Option<NonNull<T>>` is part of the option layout optimization guarantee: +// <https://doc.rust-lang.org/stable/std/option/index.html#representation>. +unsafe impl<T> ZeroableOption for NonNull<T> {} + +/// Create an initializer for a zeroed `T`. +/// +/// The returned initializer will write `0x00` to every byte of the given `slot`. +#[inline] +pub fn init_zeroed<T: Zeroable>() -> impl Init<T> { + // SAFETY: Because `T: Zeroable`, all bytes zero is a valid bit pattern for `T` + // and because we write all zeroes, the memory is initialized. + unsafe { + init_from_closure(|slot: *mut T| { + slot.write_bytes(0, 1); + Ok(()) + }) + } +} + +/// Create a `T` consisting of all zeroes. +/// +/// Whenever a type implements [`Zeroable`], this function should be preferred over +/// [`core::mem::zeroed()`] or using `MaybeUninit<T>::zeroed().assume_init()`. +/// +/// # Examples +/// +/// ``` +/// use pin_init::{Zeroable, zeroed}; +/// +/// #[derive(Zeroable)] +/// struct Point { +/// x: u32, +/// y: u32, +/// } +/// +/// let point: Point = zeroed(); +/// assert_eq!(point.x, 0); +/// assert_eq!(point.y, 0); +/// ``` +pub const fn zeroed<T: Zeroable>() -> T { + // SAFETY:By the type invariants of `Zeroable`, all zeroes is a valid bit pattern for `T`. + unsafe { core::mem::zeroed() } +} + +macro_rules! impl_zeroable { + ($($({$($generics:tt)*})? $t:ty, )*) => { + // SAFETY: Safety comments written in the macro invocation. + $(unsafe impl$($($generics)*)? Zeroable for $t {})* + }; +} + +impl_zeroable! { + // SAFETY: All primitives that are allowed to be zero. + bool, + char, + u8, u16, u32, u64, u128, usize, + i8, i16, i32, i64, i128, isize, + f32, f64, + + // Note: do not add uninhabited types (such as `!` or `core::convert::Infallible`) to this list; + // creating an instance of an uninhabited type is immediate undefined behavior. For more on + // uninhabited/empty types, consult The Rustonomicon: + // <https://doc.rust-lang.org/stable/nomicon/exotic-sizes.html#empty-types>. The Rust Reference + // also has information on undefined behavior: + // <https://doc.rust-lang.org/stable/reference/behavior-considered-undefined.html>. + // + // SAFETY: These are inhabited ZSTs; there is nothing to zero and a valid value exists. + {<T: ?Sized>} PhantomData<T>, core::marker::PhantomPinned, (), + + // SAFETY: Type is allowed to take any value, including all zeros. + {<T>} MaybeUninit<T>, + + // SAFETY: `T: Zeroable` and `UnsafeCell` is `repr(transparent)`. + {<T: ?Sized + Zeroable>} UnsafeCell<T>, + + // SAFETY: All zeros is equivalent to `None` (option layout optimization guarantee: + // <https://doc.rust-lang.org/stable/std/option/index.html#representation>). + Option<NonZeroU8>, Option<NonZeroU16>, Option<NonZeroU32>, Option<NonZeroU64>, + Option<NonZeroU128>, Option<NonZeroUsize>, + Option<NonZeroI8>, Option<NonZeroI16>, Option<NonZeroI32>, Option<NonZeroI64>, + Option<NonZeroI128>, Option<NonZeroIsize>, + + // SAFETY: `null` pointer is valid. + // + // We cannot use `T: ?Sized`, since the VTABLE pointer part of fat pointers is not allowed to be + // null. + // + // When `Pointee` gets stabilized, we could use + // `T: ?Sized where <T as Pointee>::Metadata: Zeroable` + {<T>} *mut T, {<T>} *const T, + + // SAFETY: `null` pointer is valid and the metadata part of these fat pointers is allowed to be + // zero. + {<T>} *mut [T], {<T>} *const [T], *mut str, *const str, + + // SAFETY: `T` is `Zeroable`. + {<const N: usize, T: Zeroable>} [T; N], {<T: Zeroable>} Wrapping<T>, +} + +macro_rules! impl_tuple_zeroable { + ($(,)?) => {}; + ($first:ident, $($t:ident),* $(,)?) => { + // SAFETY: All elements are zeroable and padding can be zero. + unsafe impl<$first: Zeroable, $($t: Zeroable),*> Zeroable for ($first, $($t),*) {} + impl_tuple_zeroable!($($t),* ,); + } +} + +impl_tuple_zeroable!(A, B, C, D, E, F, G, H, I, J); + +macro_rules! impl_fn_zeroable_option { + ([$($abi:literal),* $(,)?] $args:tt) => { + $(impl_fn_zeroable_option!({extern $abi} $args);)* + $(impl_fn_zeroable_option!({unsafe extern $abi} $args);)* + }; + ({$($prefix:tt)*} {$(,)?}) => {}; + ({$($prefix:tt)*} {$ret:ident, $($rest:ident),* $(,)?}) => { + // SAFETY: function pointers are part of the option layout optimization: + // <https://doc.rust-lang.org/stable/std/option/index.html#representation>. + unsafe impl<$ret, $($rest),*> ZeroableOption for $($prefix)* fn($($rest),*) -> $ret {} + impl_fn_zeroable_option!({$($prefix)*} {$($rest),*,}); + }; +} + +impl_fn_zeroable_option!(["Rust", "C"] { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U }); + +/// This trait allows creating an instance of `Self` which contains exactly one +/// [structurally pinned value](https://doc.rust-lang.org/std/pin/index.html#projections-and-structural-pinning). +/// +/// This is useful when using wrapper `struct`s like [`UnsafeCell`] or with new-type `struct`s. +/// +/// # Examples +/// +/// ``` +/// # use core::cell::UnsafeCell; +/// # use pin_init::{pin_data, pin_init, Wrapper}; +/// +/// #[pin_data] +/// struct Foo {} +/// +/// #[pin_data] +/// struct Bar { +/// #[pin] +/// content: UnsafeCell<Foo> +/// }; +/// +/// let foo_initializer = pin_init!(Foo{}); +/// let initializer = pin_init!(Bar { +/// content <- UnsafeCell::pin_init(foo_initializer) +/// }); +/// ``` +pub trait Wrapper<T> { + /// Creates an pin-initializer for a [`Self`] containing `T` from the `value_init` initializer. + fn pin_init<E>(value_init: impl PinInit<T, E>) -> impl PinInit<Self, E>; +} + +impl<T> Wrapper<T> for UnsafeCell<T> { + fn pin_init<E>(value_init: impl PinInit<T, E>) -> impl PinInit<Self, E> { + // SAFETY: `UnsafeCell<T>` has a compatible layout to `T`. + unsafe { cast_pin_init(value_init) } + } +} + +impl<T> Wrapper<T> for MaybeUninit<T> { + fn pin_init<E>(value_init: impl PinInit<T, E>) -> impl PinInit<Self, E> { + // SAFETY: `MaybeUninit<T>` has a compatible layout to `T`. + unsafe { cast_pin_init(value_init) } + } +} + +#[cfg(all(feature = "unsafe-pinned", CONFIG_RUSTC_HAS_UNSAFE_PINNED))] +impl<T> Wrapper<T> for core::pin::UnsafePinned<T> { + fn pin_init<E>(init: impl PinInit<T, E>) -> impl PinInit<Self, E> { + // SAFETY: `UnsafePinned<T>` has a compatible layout to `T`. + unsafe { cast_pin_init(init) } + } +} diff --git a/rust/kernel/init/macros.rs b/rust/pin-init/src/macros.rs index 1fd146a83241..9ced630737b8 100644 --- a/rust/kernel/init/macros.rs +++ b/rust/pin-init/src/macros.rs @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT //! This module provides the macros that actually implement the proc-macros `pin_data` and -//! `pinned_drop`. It also contains `__init_internal` the implementation of the `{try_}{pin_}init!` -//! macros. +//! `pinned_drop`. It also contains `__init_internal`, the implementation of the +//! `{try_}{pin_}init!` macros. //! //! These macros should never be called directly, since they expect their input to be //! in a certain format which is internal. If used incorrectly, these macros can lead to UB even in @@ -11,16 +11,17 @@ //! This architecture has been chosen because the kernel does not yet have access to `syn` which //! would make matters a lot easier for implementing these as proc-macros. //! +//! Since this library and the kernel implementation should diverge as little as possible, the same +//! approach has been taken here. +//! //! # Macro expansion example //! //! This section is intended for readers trying to understand the macros in this module and the -//! `pin_init!` macros from `init.rs`. +//! `[try_][pin_]init!` macros from `lib.rs`. //! //! We will look at the following example: //! //! ```rust,ignore -//! # use kernel::init::*; -//! # use core::pin::Pin; //! #[pin_data] //! #[repr(C)] //! struct Bar<T> { @@ -45,7 +46,7 @@ //! #[pinned_drop] //! impl PinnedDrop for Foo { //! fn drop(self: Pin<&mut Self>) { -//! pr_info!("{self:p} is getting dropped."); +//! println!("{self:p} is getting dropped."); //! } //! } //! @@ -75,7 +76,6 @@ //! Here is the definition of `Bar` from our example: //! //! ```rust,ignore -//! # use kernel::init::*; //! #[pin_data] //! #[repr(C)] //! struct Bar<T> { @@ -121,22 +121,22 @@ //! self, //! slot: *mut T, //! // Since `t` is `#[pin]`, this is `PinInit`. -//! init: impl ::kernel::init::PinInit<T, E>, +//! init: impl ::pin_init::PinInit<T, E>, //! ) -> ::core::result::Result<(), E> { -//! unsafe { ::kernel::init::PinInit::__pinned_init(init, slot) } +//! unsafe { ::pin_init::PinInit::__pinned_init(init, slot) } //! } //! pub unsafe fn x<E>( //! self, //! slot: *mut usize, //! // Since `x` is not `#[pin]`, this is `Init`. -//! init: impl ::kernel::init::Init<usize, E>, +//! init: impl ::pin_init::Init<usize, E>, //! ) -> ::core::result::Result<(), E> { -//! unsafe { ::kernel::init::Init::__init(init, slot) } +//! unsafe { ::pin_init::Init::__init(init, slot) } //! } //! } //! // Implement the internal `HasPinData` trait that associates `Bar` with the pin-data struct //! // that we constructed above. -//! unsafe impl<T> ::kernel::init::__internal::HasPinData for Bar<T> { +//! unsafe impl<T> ::pin_init::__internal::HasPinData for Bar<T> { //! type PinData = __ThePinData<T>; //! unsafe fn __pin_data() -> Self::PinData { //! __ThePinData { @@ -147,7 +147,7 @@ //! // Implement the internal `PinData` trait that marks the pin-data struct as a pin-data //! // struct. This is important to ensure that no user can implement a rogue `__pin_data` //! // function without using `unsafe`. -//! unsafe impl<T> ::kernel::init::__internal::PinData for __ThePinData<T> { +//! unsafe impl<T> ::pin_init::__internal::PinData for __ThePinData<T> { //! type Datee = Bar<T>; //! } //! // Now we only want to implement `Unpin` for `Bar` when every structurally pinned field is @@ -191,7 +191,7 @@ //! #[expect(non_camel_case_types)] //! trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {} //! impl< -//! T: ::kernel::init::PinnedDrop, +//! T: ::pin_init::PinnedDrop, //! > UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {} //! impl<T> UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for Bar<T> {} //! }; @@ -227,11 +227,11 @@ //! // - we `use` the `HasPinData` trait in the block, it is only available in that //! // scope. //! let data = unsafe { -//! use ::kernel::init::__internal::HasPinData; +//! use ::pin_init::__internal::HasPinData; //! Self::__pin_data() //! }; //! // Ensure that `data` really is of type `PinData` and help with type inference: -//! let init = ::kernel::init::__internal::PinData::make_closure::< +//! let init = ::pin_init::__internal::PinData::make_closure::< //! _, //! __InitOk, //! ::core::convert::Infallible, @@ -251,7 +251,7 @@ //! // is an error later. This `DropGuard` will drop the field when it gets //! // dropped and has not yet been forgotten. //! let __t_guard = unsafe { -//! ::pinned_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).t)) +//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).t)) //! }; //! // Expansion of `x: 0,`: //! // Since this can be an arbitrary expression we cannot place it inside @@ -262,7 +262,7 @@ //! } //! // We again create a `DropGuard`. //! let __x_guard = unsafe { -//! ::kernel::init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).x)) +//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).x)) //! }; //! // Since initialization has successfully completed, we can now forget //! // the guards. This is not `mem::forget`, since we only have @@ -303,7 +303,7 @@ //! }; //! // Construct the initializer. //! let init = unsafe { -//! ::kernel::init::pin_init_from_closure::< +//! ::pin_init::pin_init_from_closure::< //! _, //! ::core::convert::Infallible, //! >(init) @@ -350,19 +350,19 @@ //! unsafe fn b<E>( //! self, //! slot: *mut Bar<u32>, -//! init: impl ::kernel::init::PinInit<Bar<u32>, E>, +//! init: impl ::pin_init::PinInit<Bar<u32>, E>, //! ) -> ::core::result::Result<(), E> { -//! unsafe { ::kernel::init::PinInit::__pinned_init(init, slot) } +//! unsafe { ::pin_init::PinInit::__pinned_init(init, slot) } //! } //! unsafe fn a<E>( //! self, //! slot: *mut usize, -//! init: impl ::kernel::init::Init<usize, E>, +//! init: impl ::pin_init::Init<usize, E>, //! ) -> ::core::result::Result<(), E> { -//! unsafe { ::kernel::init::Init::__init(init, slot) } +//! unsafe { ::pin_init::Init::__init(init, slot) } //! } //! } -//! unsafe impl ::kernel::init::__internal::HasPinData for Foo { +//! unsafe impl ::pin_init::__internal::HasPinData for Foo { //! type PinData = __ThePinData; //! unsafe fn __pin_data() -> Self::PinData { //! __ThePinData { @@ -370,7 +370,7 @@ //! } //! } //! } -//! unsafe impl ::kernel::init::__internal::PinData for __ThePinData { +//! unsafe impl ::pin_init::__internal::PinData for __ThePinData { //! type Datee = Foo; //! } //! #[allow(dead_code)] @@ -394,8 +394,8 @@ //! let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) }; //! // Create the unsafe token that proves that we are inside of a destructor, this //! // type is only allowed to be created in a destructor. -//! let token = unsafe { ::kernel::init::__internal::OnlyCallFromDrop::new() }; -//! ::kernel::init::PinnedDrop::drop(pinned, token); +//! let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() }; +//! ::pin_init::PinnedDrop::drop(pinned, token); //! } //! } //! }; @@ -412,7 +412,7 @@ //! #[pinned_drop] //! impl PinnedDrop for Foo { //! fn drop(self: Pin<&mut Self>) { -//! pr_info!("{self:p} is getting dropped."); +//! println!("{self:p} is getting dropped."); //! } //! } //! ``` @@ -421,9 +421,9 @@ //! //! ```rust,ignore //! // `unsafe`, full path and the token parameter are added, everything else stays the same. -//! unsafe impl ::kernel::init::PinnedDrop for Foo { -//! fn drop(self: Pin<&mut Self>, _: ::kernel::init::__internal::OnlyCallFromDrop) { -//! pr_info!("{self:p} is getting dropped."); +//! unsafe impl ::pin_init::PinnedDrop for Foo { +//! fn drop(self: Pin<&mut Self>, _: ::pin_init::__internal::OnlyCallFromDrop) { +//! println!("{self:p} is getting dropped."); //! } //! } //! ``` @@ -448,10 +448,10 @@ //! let initializer = { //! struct __InitOk; //! let data = unsafe { -//! use ::kernel::init::__internal::HasPinData; +//! use ::pin_init::__internal::HasPinData; //! Foo::__pin_data() //! }; -//! let init = ::kernel::init::__internal::PinData::make_closure::< +//! let init = ::pin_init::__internal::PinData::make_closure::< //! _, //! __InitOk, //! ::core::convert::Infallible, @@ -462,12 +462,12 @@ //! unsafe { ::core::ptr::write(::core::addr_of_mut!((*slot).a), a) }; //! } //! let __a_guard = unsafe { -//! ::kernel::init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).a)) +//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).a)) //! }; //! let init = Bar::new(36); //! unsafe { data.b(::core::addr_of_mut!((*slot).b), b)? }; //! let __b_guard = unsafe { -//! ::kernel::init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).b)) +//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).b)) //! }; //! ::core::mem::forget(__b_guard); //! ::core::mem::forget(__a_guard); @@ -492,12 +492,17 @@ //! init(slot).map(|__InitOk| ()) //! }; //! let init = unsafe { -//! ::kernel::init::pin_init_from_closure::<_, ::core::convert::Infallible>(init) +//! ::pin_init::pin_init_from_closure::<_, ::core::convert::Infallible>(init) //! }; //! init //! }; //! ``` +#[cfg(kernel)] +pub use ::macros::paste; +#[cfg(not(kernel))] +pub use ::paste::paste; + /// Creates a `unsafe impl<...> PinnedDrop for $type` block. /// /// See [`PinnedDrop`] for more information. @@ -517,7 +522,7 @@ macro_rules! __pinned_drop { unsafe $($impl_sig)* { // Inherit all attributes and the type/ident tokens for the signature. $(#[$($attr)*])* - fn drop($($sig)*, _: $crate::init::__internal::OnlyCallFromDrop) { + fn drop($($sig)*, _: $crate::__internal::OnlyCallFromDrop) { $($inner)* } } @@ -863,7 +868,7 @@ macro_rules! __pin_data { // SAFETY: We have added the correct projection functions above to `__ThePinData` and // we also use the least restrictive generics possible. unsafe impl<$($impl_generics)*> - $crate::init::__internal::HasPinData for $name<$($ty_generics)*> + $crate::__internal::HasPinData for $name<$($ty_generics)*> where $($whr)* { type PinData = __ThePinData<$($ty_generics)*>; @@ -875,7 +880,7 @@ macro_rules! __pin_data { // SAFETY: TODO. unsafe impl<$($impl_generics)*> - $crate::init::__internal::PinData for __ThePinData<$($ty_generics)*> + $crate::__internal::PinData for __ThePinData<$($ty_generics)*> where $($whr)* { type Datee = $name<$($ty_generics)*>; @@ -934,7 +939,7 @@ macro_rules! __pin_data { // `PinnedDrop` as the parameter to `#[pin_data]`. #[expect(non_camel_case_types)] trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {} - impl<T: $crate::init::PinnedDrop> + impl<T: $crate::PinnedDrop> UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {} impl<$($impl_generics)*> UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for $name<$($ty_generics)*> @@ -957,8 +962,8 @@ macro_rules! __pin_data { let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) }; // SAFETY: Since this is a drop function, we can create this token to call the // pinned destructor of this type. - let token = unsafe { $crate::init::__internal::OnlyCallFromDrop::new() }; - $crate::init::PinnedDrop::drop(pinned, token); + let token = unsafe { $crate::__internal::OnlyCallFromDrop::new() }; + $crate::PinnedDrop::drop(pinned, token); } } }; @@ -998,10 +1003,10 @@ macro_rules! __pin_data { $pvis unsafe fn $p_field<E>( self, slot: *mut $p_type, - init: impl $crate::init::PinInit<$p_type, E>, + init: impl $crate::PinInit<$p_type, E>, ) -> ::core::result::Result<(), E> { // SAFETY: TODO. - unsafe { $crate::init::PinInit::__pinned_init(init, slot) } + unsafe { $crate::PinInit::__pinned_init(init, slot) } } )* $( @@ -1009,10 +1014,10 @@ macro_rules! __pin_data { $fvis unsafe fn $field<E>( self, slot: *mut $type, - init: impl $crate::init::Init<$type, E>, + init: impl $crate::Init<$type, E>, ) -> ::core::result::Result<(), E> { // SAFETY: TODO. - unsafe { $crate::init::Init::__init(init, slot) } + unsafe { $crate::Init::__init(init, slot) } } )* } @@ -1025,7 +1030,7 @@ macro_rules! __pin_data { /// /// This macro has multiple internal call configurations, these are always the very first ident: /// - nothing: this is the base case and called by the `{try_}{pin_}init!` macros. -/// - `with_update_parsed`: when the `..Zeroable::zeroed()` syntax has been handled. +/// - `with_update_parsed`: when the `..Zeroable::init_zeroed()` syntax has been handled. /// - `init_slot`: recursively creates the code that initializes all fields in `slot`. /// - `make_initializer`: recursively create the struct initializer that guarantees that every /// field has been initialized exactly once. @@ -1054,7 +1059,7 @@ macro_rules! __init_internal { @data($data, $($use_data)?), @has_data($has_data, $get_data), @construct_closure($construct_closure), - @zeroed(), // Nothing means default behavior. + @init_zeroed(), // Nothing means default behavior. ) }; ( @@ -1069,7 +1074,7 @@ macro_rules! __init_internal { @has_data($has_data:ident, $get_data:ident), // `pin_init_from_closure` or `init_from_closure`. @construct_closure($construct_closure:ident), - @munch_fields(..Zeroable::zeroed()), + @munch_fields(..Zeroable::init_zeroed()), ) => { $crate::__init_internal!(with_update_parsed: @this($($this)?), @@ -1079,7 +1084,7 @@ macro_rules! __init_internal { @data($data, $($use_data)?), @has_data($has_data, $get_data), @construct_closure($construct_closure), - @zeroed(()), // `()` means zero all fields not mentioned. + @init_zeroed(()), // `()` means zero all fields not mentioned. ) }; ( @@ -1119,7 +1124,7 @@ macro_rules! __init_internal { @has_data($has_data:ident, $get_data:ident), // `pin_init_from_closure` or `init_from_closure`. @construct_closure($construct_closure:ident), - @zeroed($($init_zeroed:expr)?), + @init_zeroed($($init_zeroed:expr)?), ) => {{ // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return // type and shadow it later when we insert the arbitrary user code. That way there will be @@ -1129,15 +1134,15 @@ macro_rules! __init_internal { // // SAFETY: TODO. let data = unsafe { - use $crate::init::__internal::$has_data; + use $crate::__internal::$has_data; // Here we abuse `paste!` to retokenize `$t`. Declarative macros have some internal // information that is associated to already parsed fragments, so a path fragment // cannot be used in this position. Doing the retokenization results in valid rust // code. - ::kernel::macros::paste!($t::$get_data()) + $crate::macros::paste!($t::$get_data()) }; // Ensure that `data` really is of type `$data` and help with type inference: - let init = $crate::init::__internal::$data::make_closure::<_, __InitOk, $err>( + let init = $crate::__internal::$data::make_closure::<_, __InitOk, $err>( data, move |slot| { { @@ -1147,7 +1152,7 @@ macro_rules! __init_internal { // error when fields are missing (since they will be zeroed). We also have to // check that the type actually implements `Zeroable`. $({ - fn assert_zeroable<T: $crate::init::Zeroable>(_: *mut T) {} + fn assert_zeroable<T: $crate::Zeroable>(_: *mut T) {} // Ensure that the struct is indeed `Zeroable`. assert_zeroable(slot); // SAFETY: The type implements `Zeroable` by the check above. @@ -1184,14 +1189,14 @@ macro_rules! __init_internal { init(slot).map(|__InitOk| ()) }; // SAFETY: TODO. - let init = unsafe { $crate::init::$construct_closure::<_, $err>(init) }; + let init = unsafe { $crate::$construct_closure::<_, $err>(init) }; init }}; (init_slot($($use_data:ident)?): @data($data:ident), @slot($slot:ident), @guards($($guards:ident,)*), - @munch_fields($(..Zeroable::zeroed())? $(,)?), + @munch_fields($(..Zeroable::init_zeroed())? $(,)?), ) => { // Endpoint of munching, no fields are left. If execution reaches this point, all fields // have been initialized. Therefore we can now dismiss the guards by forgetting them. @@ -1215,10 +1220,10 @@ macro_rules! __init_internal { // // We rely on macro hygiene to make it impossible for users to access this local variable. // We use `paste!` to create new hygiene for `$field`. - ::kernel::macros::paste! { + $crate::macros::paste! { // SAFETY: We forget the guard later when initialization has succeeded. let [< __ $field _guard >] = unsafe { - $crate::init::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) + $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) }; $crate::__init_internal!(init_slot($use_data): @@ -1241,15 +1246,15 @@ macro_rules! __init_internal { // // SAFETY: `slot` is valid, because we are inside of an initializer closure, we // return when an error/panic occurs. - unsafe { $crate::init::Init::__init(init, ::core::ptr::addr_of_mut!((*$slot).$field))? }; + unsafe { $crate::Init::__init(init, ::core::ptr::addr_of_mut!((*$slot).$field))? }; // Create the drop guard: // // We rely on macro hygiene to make it impossible for users to access this local variable. // We use `paste!` to create new hygiene for `$field`. - ::kernel::macros::paste! { + $crate::macros::paste! { // SAFETY: We forget the guard later when initialization has succeeded. let [< __ $field _guard >] = unsafe { - $crate::init::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) + $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) }; $crate::__init_internal!(init_slot(): @@ -1278,10 +1283,10 @@ macro_rules! __init_internal { // // We rely on macro hygiene to make it impossible for users to access this local variable. // We use `paste!` to create new hygiene for `$field`. - ::kernel::macros::paste! { + $crate::macros::paste! { // SAFETY: We forget the guard later when initialization has succeeded. let [< __ $field _guard >] = unsafe { - $crate::init::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) + $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) }; $crate::__init_internal!(init_slot($($use_data)?): @@ -1295,11 +1300,11 @@ macro_rules! __init_internal { (make_initializer: @slot($slot:ident), @type_name($t:path), - @munch_fields(..Zeroable::zeroed() $(,)?), + @munch_fields(..Zeroable::init_zeroed() $(,)?), @acc($($acc:tt)*), ) => { // Endpoint, nothing more to munch, create the initializer. Since the users specified - // `..Zeroable::zeroed()`, the slot will already have been zeroed and all field that have + // `..Zeroable::init_zeroed()`, the slot will already have been zeroed and all field that have // not been overwritten are thus zero and initialized. We still check that all fields are // actually accessible by using the struct update syntax ourselves. // We are inside of a closure that is never executed and thus we can abuse `slot` to @@ -1315,7 +1320,7 @@ macro_rules! __init_internal { // information that is associated to already parsed fragments, so a path fragment // cannot be used in this position. Doing the retokenization results in valid rust // code. - ::kernel::macros::paste!( + $crate::macros::paste!( ::core::ptr::write($slot, $t { $($acc)* ..zeroed @@ -1339,7 +1344,7 @@ macro_rules! __init_internal { // information that is associated to already parsed fragments, so a path fragment // cannot be used in this position. Doing the retokenization results in valid rust // code. - ::kernel::macros::paste!( + $crate::macros::paste!( ::core::ptr::write($slot, $t { $($acc)* }); @@ -1388,18 +1393,48 @@ macro_rules! __derive_zeroable { @body({ $( $(#[$($field_attr:tt)*])* - $field:ident : $field_ty:ty + $field_vis:vis $field:ident : $field_ty:ty + ),* $(,)? + }), + ) => { + // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. + #[automatically_derived] + unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> + where + $($($whr)*)? + {} + const _: () = { + fn assert_zeroable<T: ?::core::marker::Sized + $crate::Zeroable>() {} + fn ensure_zeroable<$($impl_generics)*>() + where $($($whr)*)? + { + $(assert_zeroable::<$field_ty>();)* + } + }; + }; + (parse_input: + @sig( + $(#[$($struct_attr:tt)*])* + $vis:vis union $name:ident + $(where $($whr:tt)*)? + ), + @impl_generics($($impl_generics:tt)*), + @ty_generics($($ty_generics:tt)*), + @body({ + $( + $(#[$($field_attr:tt)*])* + $field_vis:vis $field:ident : $field_ty:ty ),* $(,)? }), ) => { // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::init::Zeroable for $name<$($ty_generics)*> + unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> where $($($whr)*)? {} const _: () = { - fn assert_zeroable<T: ?::core::marker::Sized + $crate::init::Zeroable>() {} + fn assert_zeroable<T: ?::core::marker::Sized + $crate::Zeroable>() {} fn ensure_zeroable<$($impl_generics)*>() where $($($whr)*)? { @@ -1408,3 +1443,62 @@ macro_rules! __derive_zeroable { }; }; } + +#[doc(hidden)] +#[macro_export] +macro_rules! __maybe_derive_zeroable { + (parse_input: + @sig( + $(#[$($struct_attr:tt)*])* + $vis:vis struct $name:ident + $(where $($whr:tt)*)? + ), + @impl_generics($($impl_generics:tt)*), + @ty_generics($($ty_generics:tt)*), + @body({ + $( + $(#[$($field_attr:tt)*])* + $field_vis:vis $field:ident : $field_ty:ty + ),* $(,)? + }), + ) => { + // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. + #[automatically_derived] + unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> + where + $( + // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds` + // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>. + $field_ty: for<'__dummy> $crate::Zeroable, + )* + $($($whr)*)? + {} + }; + (parse_input: + @sig( + $(#[$($struct_attr:tt)*])* + $vis:vis union $name:ident + $(where $($whr:tt)*)? + ), + @impl_generics($($impl_generics:tt)*), + @ty_generics($($ty_generics:tt)*), + @body({ + $( + $(#[$($field_attr:tt)*])* + $field_vis:vis $field:ident : $field_ty:ty + ),* $(,)? + }), + ) => { + // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. + #[automatically_derived] + unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> + where + $( + // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds` + // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>. + $field_ty: for<'__dummy> $crate::Zeroable, + )* + $($($whr)*)? + {} + }; +} diff --git a/rust/uapi/lib.rs b/rust/uapi/lib.rs index 13495910271f..31c2f713313f 100644 --- a/rust/uapi/lib.rs +++ b/rust/uapi/lib.rs @@ -14,6 +14,9 @@ #![cfg_attr(test, allow(unsafe_op_in_unsafe_fn))] #![allow( clippy::all, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::ref_as_ptr, clippy::undocumented_unsafe_blocks, dead_code, missing_docs, @@ -24,6 +27,7 @@ unreachable_pub, unsafe_op_in_unsafe_fn )] +#![cfg_attr(CONFIG_RUSTC_HAS_UNNECESSARY_TRANSMUTES, allow(unnecessary_transmutes))] // Manual definition of blocklisted types. type __kernel_size_t = usize; diff --git a/rust/uapi/uapi_helper.h b/rust/uapi/uapi_helper.h index 76d3f103e764..1409441359f5 100644 --- a/rust/uapi/uapi_helper.h +++ b/rust/uapi/uapi_helper.h @@ -7,6 +7,8 @@ */ #include <uapi/asm-generic/ioctl.h> +#include <uapi/drm/drm.h> +#include <uapi/drm/nova_drm.h> #include <uapi/linux/mdio.h> #include <uapi/linux/mii.h> #include <uapi/linux/ethtool.h> |