
feat: metal backend #1175

Open
dcvz wants to merge 28 commits into tracel-ai:main from oxiglade:feat/native-metal-backend

Conversation

dcvz (Contributor) commented Feb 4, 2026

No description provided.

dcvz force-pushed the feat/native-metal-backend branch 7 times, most recently from 3bed613 to d379550 (February 5, 2026 21:33)
dcvz marked this pull request as ready for review (February 6, 2026 14:47)
dcvz force-pushed the feat/native-metal-backend branch 3 times, most recently from e68df50 to 622c366 (March 16, 2026 18:02)
Comment thread crates/cubecl-metal/src/runtime.rs Outdated
dcvz force-pushed the feat/native-metal-backend branch from d212496 to 4a63e18 (April 8, 2026 20:18)
antimora requested a review from Copilot (April 8, 2026 21:39)
Copilot AI left a comment

Pull request overview

Adds a native Metal backend to CubeCL and wires it into the workspace’s build/test/doc tooling and codegen so Metal can be selected as a runtime and supported in CI workflows.

Changes:

  • Introduces new cubecl-metal crate implementing a Metal runtime/server/stream/memory/storage stack.
  • Exposes Metal runtime via cubecl feature flags (metal) and TestRuntime cfg plumbing.
  • Updates xtask + CI workflow to support doc --ci and to exclude Metal/CUDA/HIP crates in CI.

Reviewed changes

Copilot reviewed 32 out of 32 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| xtask/src/main.rs | Adds doc subcommand wiring to the xtask CLI. |
| xtask/src/commands/mod.rs | Registers new doc module. |
| xtask/src/commands/doc.rs | Adds --ci handling for docs, excluding unsupported crates. |
| xtask/src/commands/validate.rs | Routes doc validation through the new CubeCL doc command args. |
| xtask/src/commands/build.rs | Excludes cubecl-metal (and CUDA/HIP) in CI builds. |
| xtask/src/commands/test.rs | Excludes cubecl-metal (and CUDA/HIP) in CI tests. |
| xtask/src/commands/check.rs | Adds CI-specific workspace clippy that excludes platform-specific crates. |
| .github/workflows/ci.yml | Switches CI invocations to use --ci for check and doc. |
| crates/cubecl/src/lib.rs | Re-exports cubecl_metal behind feature = "metal" and adds test_runtime_metal. |
| crates/cubecl/Cargo.toml | Re-points metal feature to cubecl-metal and adds the dependency. |
| crates/cubecl/build.rs | Adds test_runtime_metal check-cfg and feature wiring. |
| crates/cubecl-metal/* | New Metal runtime implementation (device selection, compilation, server, stream/event sync, storage). |
| crates/cubecl-cpp/src/shared/variable.rs | Updates pointer/atomic formatting to include Metal address spaces. |
| crates/cubecl-cpp/src/shared/unary.rs | Adjusts unary function formatting for BF16 under Metal math constraints. |
| crates/cubecl-cpp/src/shared/instruction.rs | Casts float literals to support BF16 paths. |
| crates/cubecl-cpp/src/shared/base.rs | Ensures Metal extensions get registered for hypot/rhypot. |
| crates/cubecl-cpp/src/metal/extension.rs | Adds hypot/rhypot extensions and improves BF16-safe casts. |
| crates/cubecl-cpp/src/metal/dialect.rs | Emits hypot/rhypot calls and fixes atomic compare-exchange emission for MSL semantics. |
| crates/cubecl-cpp/src/metal/arch.rs | Adjusts Metal “warp” size to 32. |
| crates/cubecl-cpp/src/metal/address_space.rs | Ensures atomic bindings/pointers use device address space in MSL. |


Comment on lines +146 to +155
        _type_id: u16,
        _info: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
    ) -> Vec<DeviceId> {
        let devices = crate::device::all_devices();
        (0..devices.len())
            .map(|i| DeviceId {
                type_id: 0,
                index_id: i as u32,
            })
            .collect()
Copilot AI Apr 8, 2026

enumerate_devices ignores the type_id filter and always returns DeviceId { type_id: 0, ... }. With the current MetalDevice::from_id mapping, this means every enumerated device resolves to DefaultDevice, and callers can’t enumerate/select discrete vs integrated devices correctly. Consider implementing type_id filtering similar to cubecl-wgpu (e.g., return only DefaultDevice for type 0, discrete list for type 1, integrated list for type 2, etc.), and ensure returned DeviceId.type_id matches the MetalDevice variants you support.

Suggested change
-        _type_id: u16,
-        _info: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
-    ) -> Vec<DeviceId> {
-        let devices = crate::device::all_devices();
-        (0..devices.len())
-            .map(|i| DeviceId {
-                type_id: 0,
-                index_id: i as u32,
-            })
-            .collect()
+        type_id: u16,
+        _info: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
+    ) -> Vec<DeviceId> {
+        let devices = crate::device::all_devices();
+        match type_id {
+            0 => vec![DeviceId {
+                type_id: 0,
+                index_id: 0,
+            }],
+            1 => devices
+                .iter()
+                .enumerate()
+                .filter_map(|(i, device)| {
+                    (!device.is_low_power()).then_some(DeviceId {
+                        type_id: 1,
+                        index_id: i as u32,
+                    })
+                })
+                .collect(),
+            2 => devices
+                .iter()
+                .enumerate()
+                .filter_map(|(i, device)| {
+                    device.is_low_power().then_some(DeviceId {
+                        type_id: 2,
+                        index_id: i as u32,
+                    })
+                })
+                .collect(),
+            _ => Vec::new(),
+        }
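
For readers who want to try the filtering idea outside the crate, here is a minimal standalone sketch. `FakeDevice` and the local `DeviceId` are hypothetical stand-ins, not the cubecl-metal types, and the mapping 0 = default, 1 = discrete, 2 = integrated is taken from the review comment rather than confirmed against `MetalDevice::from_id`. Unlike the suggestion above, it also returns an empty list for type 0 when no device is present.

```rust
/// Hypothetical stand-in for a Metal device; only the low-power flag matters here.
#[derive(Debug, Clone, Copy)]
pub struct FakeDevice {
    pub low_power: bool,
}

/// Simplified local mirror of the `DeviceId` shape used in the snippet above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceId {
    pub type_id: u16,
    pub index_id: u32,
}

/// Assumed mapping: 0 = default device, 1 = discrete, 2 = integrated (low power).
/// `index_id` is the position in the full device list, as in the suggestion.
pub fn enumerate_devices(type_id: u16, devices: &[FakeDevice]) -> Vec<DeviceId> {
    match type_id {
        // Only report a default device when at least one device exists.
        0 if !devices.is_empty() => vec![DeviceId { type_id: 0, index_id: 0 }],
        1 | 2 => {
            let want_low_power = type_id == 2;
            devices
                .iter()
                .enumerate()
                .filter(|(_, d)| d.low_power == want_low_power)
                .map(|(i, _)| DeviceId { type_id, index_id: i as u32 })
                .collect()
        }
        // Unknown type ids (and type 0 with no devices) enumerate nothing.
        _ => Vec::new(),
    }
}
```

Whether `index_id` should index the full device list or the filtered list depends on how `from_id` interprets it, which this sketch does not settle.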

Comment on lines +326 to +338
        let dispatch_info = match count {
            CubeCount::Static(x, y, z) => DispatchInfo::Static(x, y, z),
            CubeCount::Dynamic(binding) => DispatchInfo::Dynamic(binding),
        };

        let mut resolved = match self
            .streams
            .resolve(stream_id, bindings.buffers.iter(), false)
        {
            Ok(r) => r,
            Err(_) => return,
        };

Copilot AI Apr 8, 2026

MultiStream::resolve is called only with bindings.buffers.iter(). When count is CubeCount::Dynamic, the indirect-dispatch Binding used for dispatchThreadgroupsWithIndirectBuffer... is not included in the resolve handles, so cross-stream synchronization for that buffer can be skipped. Consider including the dynamic dispatch Binding in the iterator passed to resolve (e.g., chain the buffers iterator with std::iter::once(binding) when CubeCount::Dynamic).

Suggested change
-        let dispatch_info = match count {
-            CubeCount::Static(x, y, z) => DispatchInfo::Static(x, y, z),
-            CubeCount::Dynamic(binding) => DispatchInfo::Dynamic(binding),
-        };
-        let mut resolved = match self
-            .streams
-            .resolve(stream_id, bindings.buffers.iter(), false)
-        {
-            Ok(r) => r,
-            Err(_) => return,
-        };
+        // Resolve first, borrowing `count`, so the indirect-dispatch buffer is synchronized too.
+        let resolve_result = match &count {
+            CubeCount::Dynamic(binding) => self.streams.resolve(
+                stream_id,
+                bindings
+                    .buffers
+                    .iter()
+                    .chain(std::iter::once(binding)),
+                false,
+            ),
+            CubeCount::Static(_, _, _) => {
+                self.streams.resolve(stream_id, bindings.buffers.iter(), false)
+            }
+        };
+        let mut resolved = match resolve_result {
+            Ok(r) => r,
+            Err(_) => return,
+        };
+        let dispatch_info = match count {
+            CubeCount::Static(x, y, z) => DispatchInfo::Static(x, y, z),
+            CubeCount::Dynamic(binding) => DispatchInfo::Dynamic(binding),
+        };
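
The iterator chaining at the heart of this suggestion can be shown in isolation. The sketch below uses simplified, hypothetical `Binding` and `CubeCount` types rather than the real CubeCL ones, and the helper name `bindings_to_resolve` is invented for illustration; it only demonstrates how the indirect-dispatch buffer can be appended to the set of handles passed to a resolve step.

```rust
/// Simplified stand-in for a buffer binding; the real `Binding` carries more state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Binding {
    pub id: u32,
}

/// Simplified stand-in for `CubeCount`.
pub enum CubeCount {
    Static(u32, u32, u32),
    Dynamic(Binding),
}

/// Collects every binding a launch touches, including the indirect-dispatch
/// buffer when the cube count is dynamic.
pub fn bindings_to_resolve<'a>(
    buffers: &'a [Binding],
    count: &'a CubeCount,
) -> Vec<&'a Binding> {
    // `Option<&Binding>` is itself iterable, yielding zero or one item.
    let indirect = match count {
        CubeCount::Dynamic(binding) => Some(binding),
        CubeCount::Static(..) => None,
    };
    buffers.iter().chain(indirect).collect()
}
```

Chaining an `Option` keeps both cases on a single code path, which avoids duplicating the resolve call per variant; whether that fits the real `resolve` signature is an assumption.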

Comment on lines +310 to +312
    fn handle_cursor(_stream: &Self::Stream, handle: &Binding) -> u64 {
        // Metal uses shared memory so cursor tracking is minimal
        handle.size
Copilot AI Apr 8, 2026

handle_cursor currently returns handle.size, which is constant for the lifetime of an allocation and does not reflect when the binding was last written on a stream. This breaks MultiStream’s shared-binding synchronization logic (it relies on a monotonic cursor, like other backends’ memory_management.get_cursor(binding.memory)), and can lead to missing waits between streams. Consider returning the cursor from stream.memory_management.get_cursor(handle.memory.clone()) (and falling back safely if that errors).

Suggested change
-    fn handle_cursor(_stream: &Self::Stream, handle: &Binding) -> u64 {
-        // Metal uses shared memory so cursor tracking is minimal
-        handle.size
+    fn handle_cursor(stream: &Self::Stream, handle: &Binding) -> u64 {
+        stream
+            .memory_management
+            .get_cursor(handle.memory.clone())
+            .unwrap_or(handle.size)
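
To illustrate why a constant value is a poor cursor, here is a toy model. It assumes, without confirmation from the CubeCL source, that the synchronization layer decides whether to wait by comparing a buffer's cursor with the last value the consuming stream synchronized to. `CursorTracker` and `needs_wait` are hypothetical names, not part of `MultiStream`.

```rust
use std::collections::HashMap;

/// Hypothetical per-stream tracker that assigns a growing cursor to each write.
#[derive(Default)]
pub struct CursorTracker {
    next: u64,
    last_write: HashMap<u32, u64>,
}

impl CursorTracker {
    /// Records a write to `buffer_id` and returns the new, strictly larger cursor.
    pub fn record_write(&mut self, buffer_id: u32) -> u64 {
        self.next += 1;
        self.last_write.insert(buffer_id, self.next);
        self.next
    }

    /// Cursor of the most recent write to `buffer_id`, or 0 if never written.
    pub fn cursor(&self, buffer_id: u32) -> u64 {
        self.last_write.get(&buffer_id).copied().unwrap_or(0)
    }
}

/// Assumed rule: a consumer must wait when the buffer was written after the
/// point it last synchronized to.
pub fn needs_wait(last_synced: u64, cursor: u64) -> bool {
    cursor > last_synced
}
```

Under this model, a cursor that always equals the allocation size never exceeds a previously observed value, so a second write to the same buffer would not trigger a wait.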

Comment thread crates/cubecl/Cargo.toml
cubecl-cuda = { path = "../cubecl-cuda", version = "=0.10.0-pre.3", default-features = false, optional = true }
cubecl-hip = { path = "../cubecl-hip", version = "=0.10.0-pre.3", default-features = false, optional = true }
cubecl-ir = { path = "../cubecl-ir", version = "=0.10.0-pre.3", default-features = false }
cubecl-metal = { path = "../cubecl-metal", version = "=0.10.0-pre.3", optional = true }
Copilot AI Apr 8, 2026

cubecl-metal is added as an optional dependency without default-features = false, unlike the other runtime dependencies in this crate (cubecl-cuda, cubecl-hip, cubecl-wgpu, etc.). This can unintentionally pull in cubecl-metal default features when consumers enable cubecl/metal, which may be inconsistent with the crate’s feature strategy. Consider adding default-features = false for cubecl-metal for parity with the other runtime deps.

Suggested change
- cubecl-metal = { path = "../cubecl-metal", version = "=0.10.0-pre.3", optional = true }
+ cubecl-metal = { path = "../cubecl-metal", version = "=0.10.0-pre.3", default-features = false, optional = true }

