23 changes: 10 additions & 13 deletions fast_llm_external_models/apriel2/modeling_apriel2.py
@@ -7,17 +7,21 @@
 
 import torch
 import torch.nn.functional as F
+from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn
+from causal_conv1d import causal_conv1d_update as _causal_conv1d_update
 from einops import rearrange, repeat
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from torch import nn
 from transformers import GenerationMixin, PreTrainedModel
+from transformers.cache_utils import Cache
 from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.llama.modeling_llama import eager_attention_forward
 from transformers.models.mistral.modeling_mistral import MistralMLP, MistralRMSNorm, apply_rotary_pos_emb
 from transformers.processing_utils import Unpack
-from transformers.cache_utils import Cache
 from transformers.utils import logging
 from transformers.utils.import_utils import (
     is_causal_conv1d_available,
@@ -485,20 +489,13 @@ class PreprocessingOutput(TypedDict, total=False):
     attention_mask: Optional[torch.Tensor]
 
 
-
-
 # Require fast path CUDA kernels - no silent fallback to unoptimized code paths
 if not is_fast_path_available:
     raise ImportError(
         "CausalConv1d and Mamba require CUDA kernels from causal_conv1d and mamba_ssm. "
         "Install with: pip install causal-conv1d mamba-ssm"
     )
 
-from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn
-from causal_conv1d import causal_conv1d_update as _causal_conv1d_update
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-
 
 class CausalConv1d(nn.Conv1d):
     """
@@ -1298,8 +1295,7 @@ def __init__(self, hidden_size: int, eps: float = 1e-5, activation: str = "silu"
         super().__init__()
         if rms_norm_gated is None:
             raise ImportError(
-                "GatedRMSNormalization requires rms_norm_gated from fla library. "
-                "Install with: pip install fla-core"
+                "GatedRMSNormalization requires rms_norm_gated from fla library. " "Install with: pip install fla-core"
             )
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.eps = eps
@@ -1386,8 +1382,7 @@ def __init__(
         # Require FLA kernels - no silent fallback to unoptimized code paths
         if chunk_gated_delta_rule is None or fused_recurrent_gated_delta_rule is None:
             raise ImportError(
-                "GatedDeltaNet requires the fla library for optimized kernels. "
-                "Install with: pip install fla-core"
+                "GatedDeltaNet requires the fla library for optimized kernels. " "Install with: pip install fla-core"
             )
 
     def _fix_query_key_value_ordering(self, mixed_qkvz: torch.Tensor, mixed_ba: torch.Tensor):
@@ -1606,7 +1601,9 @@ def __init__(
         self.conv_kernel_size = conv_config.get("kernel_size", 4)
         norm_config = config_dict.get("normalization", {})
         self.norm_eps = norm_config.get("epsilon", 1e-5)
-        self.norm_activation = norm_config.get("activation", "sigmoid")
+        self.norm_activation = norm_config.get(
+            "activation", "silu"
+        )  # default to silu to be consistent with Fast-LLM's default. Note, Kimi uses sigmoid.
 
         # Derived dimensions
         self.projection_size = self.head_dim * self.num_heads
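
Note on the last hunk: the mixer reads its normalization activation from the nested config dict and now falls back to "silu" when no "activation" key is present. A minimal sketch of how that lookup resolves; the dict literals and the Kimi-style override below are illustrative, not taken from a real checkpoint config:

# Illustrative only: mirrors the .get(...) chain in the hunk above.
config_dict = {"normalization": {"epsilon": 1e-5}}  # no "activation" key
norm_config = config_dict.get("normalization", {})
norm_activation = norm_config.get("activation", "silu")  # -> "silu", Fast-LLM's default
assert norm_activation == "silu"

# A Kimi-style config would set the key explicitly to keep sigmoid gating.
kimi_config = {"normalization": {"epsilon": 1e-5, "activation": "sigmoid"}}
assert kimi_config["normalization"].get("activation", "silu") == "sigmoid"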
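
For reference, the fast-path requirement earlier in the diff (the @@ -485 hunk) is a hard import guard rather than a silent fallback to unoptimized code. A minimal sketch of that pattern, assuming `is_fast_path_available` is the conjunction of the two transformers availability helpers; the file's actual definition is outside the hunks shown here:

from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available

# Assumed composition of the flag; the module defines its own is_fast_path_available.
is_fast_path_available = is_causal_conv1d_available() and is_mamba_ssm_available()

if not is_fast_path_available:
    raise ImportError(
        "CausalConv1d and Mamba require CUDA kernels from causal_conv1d and mamba_ssm. "
        "Install with: pip install causal-conv1d mamba-ssm"
    )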