
Fix GDN/KDA bugs, require CUDA kernels, add cache-aware tests #451

Merged
tscholak merged 9 commits into oo/apriel_modeling_bug from fix/require-cuda-kernels-no-fallbacks on Jan 19, 2026

Conversation

@tscholak (Collaborator)

Summary

Fix critical bugs in GDN and KDA mixer implementations, harden for production use, and add comprehensive cache-aware equivalence tests.

Bug Fixes

  • GDN chunk mode ignoring cache state: chunk_gated_delta_rule was passing initial_state=None, breaking prefill→decode→prefill generation cycles
  • KDA mode selection during training: fused_recurrent was incorrectly used for short sequences during training
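A minimal sketch of the first fix, assuming the FLA `chunk_gated_delta_rule` signature; the surrounding variable names are illustrative:

```python
# Thread the cached recurrent state into chunk mode instead of dropping it.
output, last_recurrent_state = chunk_gated_delta_rule(
    query, key, value, g, beta_gate,
    initial_state=recurrent_state,  # was: initial_state=None (the bug)
    output_final_state=use_cache,
)
```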

Production Hardening

  • Remove all PyTorch fallback implementations - fail loudly if CUDA kernels (causal_conv1d, mamba_ssm, fla) are missing
  • Consolidate cache.py into modeling_apriel2.py for better tooling compatibility
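A minimal sketch of the fail-loudly behavior from the first item, assuming a module-level guard; the error message text is illustrative:

```python
# Fail at import time if the required CUDA kernels are absent; there is
# deliberately no PyTorch fallback path.
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError as e:
    raise ImportError(
        "apriel2 requires the causal_conv1d CUDA kernels and provides no "
        "PyTorch fallback; install with `pip install causal-conv1d`."
    ) from e
```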

Test Improvements

  • Extended equivalence tests verifying all 3 inference phases (prefill→decode→prefill2)
  • Enable bf16 tests - all 1718 tests pass in fp32 and bf16
  • Remove obsolete test_causal_conv1d.py (544 lines)
  • Extract shared fixtures, net reduction of ~90 lines

Test Plan

  • All 1718 mixer equivalence tests pass
  • GDN/KDA outputs match reference implementations through cache cycles
  • Import fails cleanly when CUDA kernels missing

🤖 Generated with Claude Code

tscholak and others added 9 commits January 17, 2026 19:41
Remove all PyTorch fallback implementations to ensure fast CUDA kernels
are always used. The module now fails loudly at import/instantiation
if required kernels are missing.

Changes:
- Remove torch_causal_conv1d_fn and torch_causal_conv1d_update fallbacks
- Remove torch_selective_scan_fn and torch_selective_state_update stubs
- Remove torch_chunk_gated_delta_rule function
- Remove _recurrent_gated_delta_rule method from Apriel2GatedDeltaNet
- Remove _forward_local method from GatedRMSNormalization
- Remove TestFastVsSlowPath test class (no longer needed)
- Handle CausalConv1d seq_len==1 edge case via update() instead of fallback
- Add ImportError at module load for missing causal_conv1d/mamba_ssm
- Add ImportError at class init for missing FLA kernels

Required packages:
- causal_conv1d (for CausalConv1d)
- mamba_ssm (for Mamba/SSM operations)
- fla (for GDN, KDA, GatedRMSNormalization)
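A sketch of the two guard levels described above; the `fla` module path and the guard shape are assumptions, not the actual code:

```python
import torch

# Module load: record whether the FLA kernels are importable.
try:
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule
except ImportError:
    chunk_gated_delta_rule = None

class Apriel2GatedDeltaNet(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # Class init: fail loudly rather than silently degrading to a slow path.
        if chunk_gated_delta_rule is None:
            raise ImportError(
                "Apriel2GatedDeltaNet requires the `fla` package; "
                "no PyTorch fallback is provided."
            )
```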

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

The chunk_gated_delta_rule call was always passing initial_state=None,
ignoring any existing recurrent state from previous decode cycles.
This broke continued generation scenarios (prefill -> decode -> prefill).

Changed initial_state=None to initial_state=recurrent_state to match
the correct behavior already present in KDA's chunk_kda call.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Add test_vs_qwen3next_with_cache and test_vs_fla_with_cache tests that
verify mixer implementations through all inference phases:
- Phase 1: Initial prefill with cache population
- Phase 2: Single-token decode using cached states
- Phase 3: Prefill again (decode→prefill transition)

Tests compare outputs and recurrent states at each phase. Convolution
states are not compared due to different storage formats between
implementations (Apriel2 stores kernel_size-1, references store
kernel_size).

For GDN, Phase 3 documents expected divergence from Qwen3Next due to
its bug where chunk mode ignores initial_state.

For KDA, all phases should match since FLA correctly passes
initial_state in chunk mode.
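A condensed sketch of the three-phase driver, assuming a mixer that takes (hidden_states, cache) and updates the cache in place; this is an illustrative contract, not the exact Apriel2 signature:

```python
import torch

def run_three_phases(mixer, cache, hidden, prefill_len, decode_steps, prefill2_len):
    # Drive prefill -> decode -> prefill over one shared cache, concatenating
    # outputs so they can be compared against a no-cache reference run.
    outs, pos = [], 0
    for chunk in (prefill_len, *[1] * decode_steps, prefill2_len):
        outs.append(mixer(hidden[:, pos:pos + chunk], cache=cache))
        pos += chunk
    return torch.cat(outs, dim=1)
```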

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Merge test_vs_qwen3next and test_vs_qwen3next_with_cache into single
  parameterized test with use_cache fixture
- Merge test_vs_fla and test_vs_fla_with_cache similarly
- Add use_cache (False/True) and decode_steps (4) fixtures
- Use proper Apriel2Cache from cache.py instead of ad-hoc SimpleCache
- Use same total sequence length for both cache and non-cache modes
- Skip cache tests when seq_len < decode_steps + 2 (too small for 3 phases)
- Split sequence as: prefill=2/3, decode=4, prefill2=1/3 of remaining
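For example, one reading of that split as a hypothetical helper (not the actual test code):

```python
def split_phases(total_len: int, decode_steps: int = 4) -> tuple[int, int, int]:
    # prefill takes 2/3 of the non-decode tokens; prefill2 takes the rest.
    remaining = total_len - decode_steps
    prefill_len = remaining * 2 // 3
    return prefill_len, decode_steps, remaining - prefill_len

assert split_phases(16) == (8, 4, 4)
```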

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Fix KDA mode selection to match FLA: use fused_recurrent only when
  seq_len <= 64 AND not training (single expression instead of override)
- Replace use_cache fixture with explicit phase fixtures (prefill_len,
  decode_steps, prefill2_len) for clearer test parameterization
- Update test_chunked_vs_recurrent to use Apriel2Cache and fixtures
- Rename config_dict to mixer_config for consistency across all tests
- Remove unused qwen3_config fixture (recreated inline where needed)
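The corrected mode selection reduces to a single expression like the following (variable names are illustrative; the seq_len <= 64 threshold follows FLA's own heuristic):

```python
# fused_recurrent only for short sequences at inference time; chunk otherwise.
mode = "fused_recurrent" if seq_len <= 64 and not self.training else "chunk"
```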

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

CausalConv1d is now tested through KDA equivalence tests which use
CausalConv1d for q_conv, k_conv, v_conv. The isolated tests were also
obsolete since CPU fallback was removed.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Move all cache classes (_AttentionCache, _SSMCache, _DummyCacheLayer,
Apriel2Cache, _LayerListAccessor) into modeling_apriel2.py for better
tooling compatibility - modeling code is expected to be together.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Enable "fast" mode (bf16/sdpa) tests that were previously skipped
- Add test_dtype fixture parameter to all tests that create models
- Convert models to correct dtype with .to(device="cuda", dtype=test_dtype)
- Create input tensors with explicit dtype parameter
- Fix assert_close to cast tensors to same dtype before comparison

All 1718 mixer equivalence tests now pass in both fp32 and bf16 modes.
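A sketch of the dtype parameterization; the fixture and helper names follow the commit message, their bodies are illustrative:

```python
import pytest
import torch

@pytest.fixture(params=[torch.float32, torch.bfloat16])
def test_dtype(request):
    return request.param

def assert_close(actual, expected, **kwargs):
    # Cast to a common dtype so fp32 references compare cleanly with bf16 outputs.
    torch.testing.assert_close(actual, expected.to(actual.dtype), **kwargs)
```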

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Add gdn_mixer_config and kda_mixer_config fixtures to centralize
  mixer config dict construction (eliminates 6 duplicate dicts)
- Add kda_hidden_size fixture for derived hidden_size calculation
- Add make_apriel2_config() helper for minimal Apriel2TextConfig
  construction (eliminates 4 duplicate config blocks)
- Update all GDN and KDA tests to use new fixtures
- Consolidate duplicate imports within test methods

Net reduction: 47 lines (-125/+78)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Copilot AI left a comment:

Pull request overview

This PR fixes critical bugs in GDN and KDA mixer implementations, removes PyTorch fallback implementations to enforce CUDA kernel requirements, consolidates cache classes, and adds comprehensive 3-phase cache-aware equivalence tests.

Changes:

  • Fixed GDN chunk mode to pass recurrent_state instead of None, enabling proper prefill→decode→prefill generation cycles
  • Fixed KDA mode selection to always use chunk mode during training, preventing incorrect use of fused_recurrent mode for short sequences
  • Removed PyTorch fallback implementations and added clear ImportError messages when CUDA kernels are missing (causal_conv1d, mamba_ssm, fla)
  • Consolidated cache.py into modeling_apriel2.py for better tooling compatibility
  • Extended equivalence tests with 3-phase testing (prefill→decode→prefill2) to verify cache handling
  • Enabled bf16 tests and removed obsolete test file (544 lines)

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Summary per file:

  • fast_llm_external_models/apriel2/modeling_apriel2.py: Added cache classes from cache.py, removed PyTorch fallbacks, fixed GDN/KDA bugs, added CUDA kernel requirement checks
  • fast_llm_external_models/apriel2/cache.py: Deleted; consolidated into modeling_apriel2.py
  • fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py: Added new fixtures, extended tests to 3-phase testing, enabled bf16 tests, removed fast/slow path tests
  • fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py: Deleted; 544 lines of tests no longer needed
  • fast_llm_external_models/tests/test_apriel2/test_modeling.py: Updated import from cache to modeling_apriel2
  • fast_llm_external_models/tests/test_apriel2/test_model_structure.py: Updated import from cache to modeling_apriel2
  • fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py: Updated import from cache to modeling_apriel2
  • fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py: Updated import from cache to modeling_apriel2
  • fast_llm_external_models/tests/test_apriel2/conftest.py: Updated import from cache to modeling_apriel2


```python
self.cache = cache
self.attr = attr

def __getitem__(self, idx):
```
Copilot AI commented on Jan 18, 2026:

This method raises RuntimeError; it should raise a LookupError subclass (KeyError or IndexError) instead.
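A hedged sketch of the suggested change; the accessor internals here are assumptions, not the actual implementation:

```python
def __getitem__(self, idx):
    layers = getattr(self.cache, self.attr)
    if not 0 <= idx < len(layers):
        # IndexError is a LookupError subclass, which `in` checks and iteration expect.
        raise IndexError(f"no cached {self.attr} state at layer {idx}")
    return layers[idx]
```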

```diff
  recurrent_state = recurrent_state.to(hidden_states.dtype)
- output, last_recurrent_state = self._recurrent_gated_delta_rule(
-     query, key, value, g, beta_gate, recurrent_state
+ output, last_recurrent_state = fused_recurrent_gated_delta_rule(
```
@oleksost (Contributor) commented on Jan 19, 2026:

Nice, didn't know there was a fused implementation of the recurrent step.

tscholak merged commit c3a6b44 into oo/apriel_modeling_bug on Jan 19, 2026
6 checks passed
tscholak deleted the fix/require-cuda-kernels-no-fallbacks branch on January 19, 2026 at 14:42
@oleksost (Contributor) commented:

> GDN chunk mode ignoring cache state: chunk_gated_delta_rule was passing initial_state=None, breaking prefill→decode→prefill generation cycles
> KDA mode selection during training: fused_recurrent was incorrectly used for short sequences during training

Just for completeness: the bug in KDA affected not only training but also chunked prefill (whenever sequence length > 64).

As for GDN, we were using the buggy _recurrent_gated_delta_rule instead of fused_recurrent_gated_delta_rule from FLA, which also affected standard inference.
