Conversation
3bed613 to
d379550
Compare
e68df50 to
622c366
Compare
d212496 to
4a63e18
Compare
There was a problem hiding this comment.
Pull request overview
Adds a native Metal backend to CubeCL and wires it into the workspace’s build/test/doc tooling and codegen so Metal can be selected as a runtime and supported in CI workflows.
Changes:
- Introduces new
cubecl-metalcrate implementing a Metal runtime/server/stream/memory/storage stack. - Exposes Metal runtime via
cubeclfeature flags (metal) andTestRuntimecfg plumbing. - Updates
xtask+ CI workflow to supportdoc --ciand to exclude Metal/CUDA/HIP crates in CI.
Reviewed changes
Copilot reviewed 32 out of 32 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
xtask/src/main.rs |
Adds doc subcommand wiring to the xtask CLI. |
xtask/src/commands/mod.rs |
Registers new doc module. |
xtask/src/commands/doc.rs |
Adds --ci handling for docs, excluding unsupported crates. |
xtask/src/commands/validate.rs |
Routes doc validation through the new CubeCL doc command args. |
xtask/src/commands/build.rs |
Excludes cubecl-metal (and CUDA/HIP) in CI builds. |
xtask/src/commands/test.rs |
Excludes cubecl-metal (and CUDA/HIP) in CI tests. |
xtask/src/commands/check.rs |
Adds CI-specific workspace clippy that excludes platform-specific crates. |
.github/workflows/ci.yml |
Switches CI invocations to use --ci for check and doc. |
crates/cubecl/src/lib.rs |
Re-exports cubecl_metal behind feature = "metal" and adds test_runtime_metal. |
crates/cubecl/Cargo.toml |
Re-points metal feature to cubecl-metal and adds the dependency. |
crates/cubecl/build.rs |
Adds test_runtime_metal check-cfg and feature wiring. |
crates/cubecl-metal/* |
New Metal runtime implementation (device selection, compilation, server, stream/event sync, storage). |
crates/cubecl-cpp/src/shared/variable.rs |
Updates pointer/atomic formatting to include Metal address spaces. |
crates/cubecl-cpp/src/shared/unary.rs |
Adjusts unary function formatting for BF16 under Metal math constraints. |
crates/cubecl-cpp/src/shared/instruction.rs |
Casts float literals to support BF16 paths. |
crates/cubecl-cpp/src/shared/base.rs |
Ensures Metal extensions get registered for hypot/rhypot. |
crates/cubecl-cpp/src/metal/extension.rs |
Adds hypot/rhypot extensions and improves BF16-safe casts. |
crates/cubecl-cpp/src/metal/dialect.rs |
Emits hypot/rhypot calls and fixes atomic compare-exchange emission for MSL semantics. |
crates/cubecl-cpp/src/metal/arch.rs |
Adjusts Metal “warp” size to 32. |
crates/cubecl-cpp/src/metal/address_space.rs |
Ensures atomic bindings/pointers use device address space in MSL. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| _type_id: u16, | ||
| _info: &<Self::Server as cubecl_core::server::ComputeServer>::Info, | ||
| ) -> Vec<DeviceId> { | ||
| let devices = crate::device::all_devices(); | ||
| (0..devices.len()) | ||
| .map(|i| DeviceId { | ||
| type_id: 0, | ||
| index_id: i as u32, | ||
| }) | ||
| .collect() |
There was a problem hiding this comment.
enumerate_devices ignores the type_id filter and always returns DeviceId { type_id: 0, ... }. With the current MetalDevice::from_id mapping, this means every enumerated device resolves to DefaultDevice, and callers can’t enumerate/select discrete vs integrated devices correctly. Consider implementing type_id filtering similar to cubecl-wgpu (e.g., return only DefaultDevice for type 0, discrete list for type 1, integrated list for type 2, etc.), and ensure returned DeviceId.type_id matches the MetalDevice variants you support.
| _type_id: u16, | |
| _info: &<Self::Server as cubecl_core::server::ComputeServer>::Info, | |
| ) -> Vec<DeviceId> { | |
| let devices = crate::device::all_devices(); | |
| (0..devices.len()) | |
| .map(|i| DeviceId { | |
| type_id: 0, | |
| index_id: i as u32, | |
| }) | |
| .collect() | |
| type_id: u16, | |
| _info: &<Self::Server as cubecl_core::server::ComputeServer>::Info, | |
| ) -> Vec<DeviceId> { | |
| let devices = crate::device::all_devices(); | |
| match type_id { | |
| 0 => vec![DeviceId { | |
| type_id: 0, | |
| index_id: 0, | |
| }], | |
| 1 => devices | |
| .iter() | |
| .enumerate() | |
| .filter_map(|(i, device)| { | |
| (!device.is_low_power()).then_some(DeviceId { | |
| type_id: 1, | |
| index_id: i as u32, | |
| }) | |
| }) | |
| .collect(), | |
| 2 => devices | |
| .iter() | |
| .enumerate() | |
| .filter_map(|(i, device)| { | |
| device.is_low_power().then_some(DeviceId { | |
| type_id: 2, | |
| index_id: i as u32, | |
| }) | |
| }) | |
| .collect(), | |
| _ => Vec::new(), | |
| } |
| let dispatch_info = match count { | ||
| CubeCount::Static(x, y, z) => DispatchInfo::Static(x, y, z), | ||
| CubeCount::Dynamic(binding) => DispatchInfo::Dynamic(binding), | ||
| }; | ||
|
|
||
| let mut resolved = match self | ||
| .streams | ||
| .resolve(stream_id, bindings.buffers.iter(), false) | ||
| { | ||
| Ok(r) => r, | ||
| Err(_) => return, | ||
| }; | ||
|
|
There was a problem hiding this comment.
MultiStream::resolve is called only with bindings.buffers.iter(). When count is CubeCount::Dynamic, the indirect-dispatch Binding used for dispatchThreadgroupsWithIndirectBuffer... is not included in the resolve handles, so cross-stream synchronization for that buffer can be skipped. Consider including the dynamic dispatch Binding in the iterator passed to resolve (e.g., chain the buffers iterator with std::iter::once(binding) when CubeCount::Dynamic).
| let dispatch_info = match count { | |
| CubeCount::Static(x, y, z) => DispatchInfo::Static(x, y, z), | |
| CubeCount::Dynamic(binding) => DispatchInfo::Dynamic(binding), | |
| }; | |
| let mut resolved = match self | |
| .streams | |
| .resolve(stream_id, bindings.buffers.iter(), false) | |
| { | |
| Ok(r) => r, | |
| Err(_) => return, | |
| }; | |
| let mut resolved = match &count { | |
| CubeCount::Dynamic(binding) => self.streams.resolve( | |
| stream_id, | |
| bindings | |
| .buffers | |
| .iter() | |
| .chain(std::iter::once(binding)), | |
| false, | |
| ), | |
| CubeCount::Static(_, _, _) => { | |
| self.streams.resolve(stream_id, bindings.buffers.iter(), false) | |
| } | |
| } { | |
| Ok(r) => r, | |
| Err(_) => return, | |
| }; | |
| let dispatch_info = match count { | |
| CubeCount::Static(x, y, z) => DispatchInfo::Static(x, y, z), | |
| CubeCount::Dynamic(binding) => DispatchInfo::Dynamic(binding), | |
| }; |
| fn handle_cursor(_stream: &Self::Stream, handle: &Binding) -> u64 { | ||
| // Metal uses shared memory so cursor tracking is minimal | ||
| handle.size |
There was a problem hiding this comment.
handle_cursor currently returns handle.size, which is constant for the lifetime of an allocation and does not reflect when the binding was last written on a stream. This breaks MultiStream’s shared-binding synchronization logic (it relies on a monotonic cursor, like other backends’ memory_management.get_cursor(binding.memory)), and can lead to missing waits between streams. Consider returning the cursor from stream.memory_management.get_cursor(handle.memory.clone()) (and falling back safely if that errors).
| fn handle_cursor(_stream: &Self::Stream, handle: &Binding) -> u64 { | |
| // Metal uses shared memory so cursor tracking is minimal | |
| handle.size | |
| fn handle_cursor(stream: &Self::Stream, handle: &Binding) -> u64 { | |
| stream | |
| .memory_management | |
| .get_cursor(handle.memory.clone()) | |
| .unwrap_or(handle.size) |
| cubecl-cuda = { path = "../cubecl-cuda", version = "=0.10.0-pre.3", default-features = false, optional = true } | ||
| cubecl-hip = { path = "../cubecl-hip", version = "=0.10.0-pre.3", default-features = false, optional = true } | ||
| cubecl-ir = { path = "../cubecl-ir", version = "=0.10.0-pre.3", default-features = false } | ||
| cubecl-metal = { path = "../cubecl-metal", version = "=0.10.0-pre.3", optional = true } |
There was a problem hiding this comment.
cubecl-metal is added as an optional dependency without default-features = false, unlike the other runtime dependencies in this crate (cubecl-cuda, cubecl-hip, cubecl-wgpu, etc.). This can unintentionally pull in cubecl-metal default features when consumers enable cubecl/metal, which may be inconsistent with the crate’s feature strategy. Consider adding default-features = false for cubecl-metal for parity with the other runtime deps.
| cubecl-metal = { path = "../cubecl-metal", version = "=0.10.0-pre.3", optional = true } | |
| cubecl-metal = { path = "../cubecl-metal", version = "=0.10.0-pre.3", default-features = false, optional = true } |
No description provided.