Fix GDN/KDA bugs, require CUDA kernels, add cache-aware tests #451
Conversation
Remove all PyTorch fallback implementations to ensure fast CUDA kernels are always used. The module now fails loudly at import/instantiation if required kernels are missing.

Changes:
- Remove `torch_causal_conv1d_fn` and `torch_causal_conv1d_update` fallbacks
- Remove `torch_selective_scan_fn` and `torch_selective_state_update` stubs
- Remove `torch_chunk_gated_delta_rule` function
- Remove `_recurrent_gated_delta_rule` method from `Apriel2GatedDeltaNet`
- Remove `_forward_local` method from `GatedRMSNormalization`
- Remove `TestFastVsSlowPath` test class (no longer needed)
- Handle `CausalConv1d` seq_len==1 edge case via `update()` instead of fallback
- Add `ImportError` at module load for missing causal_conv1d/mamba_ssm
- Add `ImportError` at class init for missing FLA kernels

Required packages:
- `causal_conv1d` (for `CausalConv1d`)
- `mamba_ssm` (for Mamba/SSM operations)
- `fla` (for GDN, KDA, `GatedRMSNormalization`)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
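The fail-loud behavior this commit describes can be sketched with a generic guard. This is illustrative only: `require_kernels` and the package-probing approach are assumptions, not the actual code in `modeling_apriel2.py`.

```python
import importlib.util


def require_kernels(*packages: str) -> None:
    """Raise ImportError at module load / class init if any required CUDA
    kernel package (e.g. causal_conv1d, mamba_ssm, fla) is missing.
    There is deliberately no PyTorch fallback path."""
    missing = [p for p in packages if importlib.util.find_spec(p) is None]
    if missing:
        raise ImportError(
            f"Missing required CUDA kernel packages: {', '.join(missing)}. "
            "Install them to use this module; slow fallbacks were removed."
        )
```

Probing with `find_spec` avoids actually importing the heavy kernel packages just to check their presence.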
The `chunk_gated_delta_rule` call was always passing `initial_state=None`, ignoring any existing recurrent state from previous decode cycles. This broke continued generation scenarios (prefill → decode → prefill). Changed `initial_state=None` to `initial_state=recurrent_state` to match the correct behavior already present in KDA's `chunk_kda` call.
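Why the carried state matters can be seen with a toy linear recurrence standing in for the gated delta rule. This is a sketch only; the real kernel operates on query/key/value/gate tensors, not scalars.

```python
def chunk_scan(xs, decay, initial_state=0.0):
    """Toy recurrence s_t = decay * s_{t-1} + x_t, processed as one chunk.
    Carrying initial_state across chunks must reproduce a single full pass."""
    s = initial_state
    out = []
    for x in xs:
        s = decay * s + x
        out.append(s)
    return out, s


xs = [1.0, 2.0, 3.0, 4.0]
full, _ = chunk_scan(xs, 0.5)  # one uninterrupted pass

# Correct: pass the recurrent state from the first chunk into the second.
out1, state = chunk_scan(xs[:2], 0.5)
out2, _ = chunk_scan(xs[2:], 0.5, initial_state=state)

# The bug: initial_state dropped (reset to zero) silently forgets history.
bad2, _ = chunk_scan(xs[2:], 0.5)
```

With the state carried, the two chunks reproduce the full pass; with it dropped, the second chunk diverges, which is exactly what broke prefill→decode→prefill cycles.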
Add `test_vs_qwen3next_with_cache` and `test_vs_fla_with_cache` tests that verify mixer implementations through all inference phases:
- Phase 1: Initial prefill with cache population
- Phase 2: Single-token decode using cached states
- Phase 3: Prefill again (decode→prefill transition)

Tests compare outputs and recurrent states at each phase. Convolution states are not compared due to different storage formats between implementations (Apriel2 stores kernel_size-1, references store kernel_size). For GDN, Phase 3 documents expected divergence from Qwen3Next due to its bug where chunk mode ignores initial_state. For KDA, all phases should match since FLA correctly passes initial_state in chunk mode.
- Merge `test_vs_qwen3next` and `test_vs_qwen3next_with_cache` into single parameterized test with `use_cache` fixture
- Merge `test_vs_fla` and `test_vs_fla_with_cache` similarly
- Add `use_cache` (False/True) and `decode_steps` (4) fixtures
- Use proper `Apriel2Cache` from cache.py instead of ad-hoc SimpleCache
- Use same total sequence length for both cache and non-cache modes
- Skip cache tests when seq_len < decode_steps + 2 (too small for 3 phases)
- Split sequence as: prefill=2/3, decode=4, prefill2=1/3 of remaining
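The split in the last bullet can be sketched as a helper. This is a hypothetical reading of the commit message; the exact rounding in the test suite may differ.

```python
def split_three_phases(seq_len: int, decode_steps: int = 4):
    """Split a sequence into prefill / decode / prefill2 lengths:
    decode gets decode_steps single tokens, the first prefill takes
    roughly 2/3 of the remaining tokens, prefill2 takes the rest."""
    if seq_len < decode_steps + 2:
        raise ValueError("seq_len too small for 3 phases; skip the test")
    rest = seq_len - decode_steps
    prefill = max(1, (2 * rest) // 3)
    prefill2 = rest - prefill
    return prefill, decode_steps, prefill2
```

The `decode_steps + 2` guard mirrors the skip condition above: both prefill phases need at least one token each.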
- Fix KDA mode selection to match FLA: use `fused_recurrent` only when seq_len <= 64 AND not training (single expression instead of override)
- Replace `use_cache` fixture with explicit phase fixtures (`prefill_len`, `decode_steps`, `prefill2_len`) for clearer test parameterization
- Update `test_chunked_vs_recurrent` to use `Apriel2Cache` and fixtures
- Rename `config_dict` to `mixer_config` for consistency across all tests
- Remove unused `qwen3_config` fixture (recreated inline where needed)
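The corrected mode selection can be sketched as a single expression. The helper name and the `threshold` parameter are illustrative; only the condition itself comes from the commit message.

```python
def select_kda_mode(seq_len: int, training: bool, threshold: int = 64) -> str:
    """fused_recurrent only for short sequences outside training;
    chunk mode everywhere else -- one expression, no later override."""
    use_fused = seq_len <= threshold and not training
    return "fused_recurrent" if use_fused else "chunk"
```

Folding the training check into the same expression is what prevents the old bug, where a short training sequence could incorrectly take the `fused_recurrent` path.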
`CausalConv1d` is now tested through KDA equivalence tests, which use `CausalConv1d` for `q_conv`, `k_conv`, and `v_conv`. The isolated tests were also obsolete since the CPU fallback was removed.
Move all cache classes (`_AttentionCache`, `_SSMCache`, `_DummyCacheLayer`, `Apriel2Cache`, `_LayerListAccessor`) into modeling_apriel2.py for better tooling compatibility - modeling code is expected to be together.
- Enable "fast" mode (bf16/sdpa) tests that were previously skipped
- Add `test_dtype` fixture parameter to all tests that create models
- Convert models to correct dtype with `.to(device="cuda", dtype=test_dtype)`
- Create input tensors with explicit dtype parameter
- Fix `assert_close` to cast tensors to same dtype before comparison

All 1718 mixer equivalence tests now pass in both fp32 and bf16 modes.
- Add `gdn_mixer_config` and `kda_mixer_config` fixtures to centralize mixer config dict construction (eliminates 6 duplicate dicts)
- Add `kda_hidden_size` fixture for derived hidden_size calculation
- Add `make_apriel2_config()` helper for minimal `Apriel2TextConfig` construction (eliminates 4 duplicate config blocks)
- Update all GDN and KDA tests to use new fixtures
- Consolidate duplicate imports within test methods

Net reduction: 47 lines (-125/+78)
Pull request overview
This PR fixes critical bugs in GDN and KDA mixer implementations, removes PyTorch fallback implementations to enforce CUDA kernel requirements, consolidates cache classes, and adds comprehensive 3-phase cache-aware equivalence tests.
Changes:
- Fixed GDN chunk mode to pass `recurrent_state` instead of `None`, enabling proper prefill→decode→prefill generation cycles
- Fixed KDA mode selection to always use chunk mode during training, preventing incorrect use of `fused_recurrent` mode for short sequences
- Removed PyTorch fallback implementations and added clear ImportError messages when CUDA kernels are missing (causal_conv1d, mamba_ssm, fla)
- Consolidated `cache.py` into `modeling_apriel2.py` for better tooling compatibility
- Extended equivalence tests with 3-phase testing (prefill→decode→prefill2) to verify cache handling
- Enabled bf16 tests and removed obsolete test file (544 lines)
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| fast_llm_external_models/apriel2/modeling_apriel2.py | Added cache classes from cache.py, removed PyTorch fallbacks, fixed GDN/KDA bugs, added CUDA kernel requirement checks |
| fast_llm_external_models/apriel2/cache.py | Deleted - consolidated into modeling_apriel2.py |
| fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py | Added new fixtures, extended tests to 3-phase testing, enabled bf16 tests, removed fast/slow path tests |
| fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py | Deleted - 544 lines of tests no longer needed |
| fast_llm_external_models/tests/test_apriel2/test_modeling.py | Updated import from cache to modeling_apriel2 |
| fast_llm_external_models/tests/test_apriel2/test_model_structure.py | Updated import from cache to modeling_apriel2 |
| fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py | Updated import from cache to modeling_apriel2 |
| fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py | Updated import from cache to modeling_apriel2 |
| fast_llm_external_models/tests/test_apriel2/conftest.py | Updated import from cache to modeling_apriel2 |
```python
        self.cache = cache
        self.attr = attr

    def __getitem__(self, idx):
```
This method raises RuntimeError - should raise a LookupError (KeyError or IndexError) instead.
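A minimal sketch of the suggested fix, simplified from the `_LayerListAccessor` in the diff; the class and attribute names here are illustrative, not the actual implementation.

```python
class LayerListAccessor:
    """Proxy indexing into a per-layer attribute of a cache object.
    Raising IndexError (a LookupError) instead of RuntimeError keeps
    the accessor compatible with iteration and `in` checks, which rely
    on IndexError to terminate."""

    def __init__(self, cache, attr):
        self.cache = cache
        self.attr = attr

    def __getitem__(self, idx):
        layers = getattr(self.cache, self.attr)
        try:
            return layers[idx]
        except (IndexError, KeyError):
            raise IndexError(f"no cached layer at index {idx}") from None
```

Because `__getitem__` raises `IndexError` past the end, plain `for` loops and `list(...)` over the accessor work via the legacy sequence-iteration protocol, which a `RuntimeError` would break.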
```diff
 recurrent_state = recurrent_state.to(hidden_states.dtype)
-output, last_recurrent_state = self._recurrent_gated_delta_rule(
-    query, key, value, g, beta_gate, recurrent_state
+output, last_recurrent_state = fused_recurrent_gated_delta_rule(
```
Nice, didn't know there was a fused implementation of the recurrent step.
Just for completeness: the bug in KDA was not only affecting training but also chunked prefill (whenever seq. length > 64). As for the GDN, we were using buggy
Summary
Fix critical bugs in GDN and KDA mixer implementations, harden for production use, and add comprehensive cache-aware equivalence tests.
Bug Fixes
- `chunk_gated_delta_rule` was passing `initial_state=None`, breaking prefill→decode→prefill generation cycles
- `fused_recurrent` was incorrectly used for short sequences during training

Production Hardening
- Consolidated `cache.py` into `modeling_apriel2.py` for better tooling compatibility

Test Improvements
- Removed obsolete `test_causal_conv1d.py` (544 lines)

Test Plan
🤖 Generated with Claude Code