Merged
240 changes: 114 additions & 126 deletions fast_llm/functional/entropy_loss.py

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions fast_llm/functional/triton/cross_entropy.py
@@ -140,7 +140,8 @@ def triton_cross_entropy_forward_backward(
# TODO: Improve assumptions.
assert logits.is_contiguous()
assert target.is_contiguous()
n_rows, n_cols = logits.shape
n_rows = logits.shape[:-1].numel()
n_cols = logits.size(-1)
block_size = triton.next_power_of_2(n_cols)
assert block_size <= TritonConfig.MAX_BLOCK_SIZE_BYTES
num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16)
@@ -155,8 +156,8 @@ def triton_cross_entropy_forward_backward(
losses,
None if grad_output is None else grad_output / n_rows,
n_cols,
logits.stride(0),
None if grad_output is None else grad_logits.stride(0),
logits.stride(-2),
None if grad_output is None else grad_logits.stride(-2),
logits_scale_factor,
block_size=block_size,
num_warps=num_warps,
@@ -172,9 +173,9 @@ def triton_cross_entropy_forward_backward(
losses,
None if grad_output is None else grad_output / n_rows,
n_cols,
logits.stride(0),
target.stride(0),
None if grad_output is None else grad_logits.stride(0),
logits.stride(-2),
target.stride(-2),
None if grad_output is None else grad_logits.stride(-2),
logits_scale_factor,
block_size=block_size,
num_warps=num_warps,
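
A side note on the stride change above (an illustration, not from the PR): a minimal sketch of why flattening the leading dimensions with `shape[:-1].numel()` and indexing the row stride as `stride(-2)` reduces to the old 2-D behaviour for contiguous logits. The shapes below are made up.

```python
import torch

# Hypothetical shapes, chosen only to illustrate the indexing above.
logits = torch.randn(4, 7, 32)  # (batch, sequence, vocab), contiguous

n_rows = logits.shape[:-1].numel()  # 4 * 7 = 28 flattened rows
n_cols = logits.size(-1)            # 32 columns per row

# For a contiguous tensor, the stride of the second-to-last dimension is the
# distance between consecutive rows, which is what the kernel iterates over.
assert logits.stride(-2) == n_cols

# The old 2-D code path is the special case where the batch dimensions are already flat.
flat = logits.reshape(-1, n_cols)
assert flat.shape == (n_rows, n_cols)
assert flat.stride(0) == logits.stride(-2)
```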
1 change: 1 addition & 0 deletions fast_llm/layers/language_model/loss/dpo.py
@@ -49,6 +49,7 @@ def dpo_loss(
beta: float = 1.0,
logits_scale_factor: float = 1.0,
) -> torch.Tensor:
logits = logits.float()

if logits_scale_factor != 1.0:
# TODO: Make more efficient.
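
For context on the added `logits = logits.float()` (an illustration, not from the PR): log-probabilities computed directly in bfloat16 can differ noticeably from the float32 result, which is presumably why the loss upcasts first. The vocabulary size below is made up.

```python
import torch

torch.manual_seed(0)
# Made-up vocabulary size, for illustration only.
logits_bf16 = torch.randn(2, 32_000, dtype=torch.bfloat16) * 10

log_probs_bf16 = torch.log_softmax(logits_bf16, dim=-1)
log_probs_fp32 = torch.log_softmax(logits_bf16.float(), dim=-1)

# bf16 keeps roughly 3 significant decimal digits, so the gap is typically visible here.
print((log_probs_bf16.float() - log_probs_fp32).abs().max())
```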
71 changes: 67 additions & 4 deletions fast_llm/layers/language_model/loss/entropy_loss.py
@@ -1,17 +1,20 @@
import typing

import torch
from torch._C._distributed_c10d import ProcessGroup

from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig
from fast_llm.functional.entropy_loss import entropy_loss_forward_backward
from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward
from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward
from fast_llm.layers.language_model.loss.config import (
LanguageModelDistillationLossConfig,
LanguageModelLabelEntropyLossConfig,
)
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
from fast_llm.utils import Assert


def _get_imlementation(
def _get_implementation(
default: EntropyLossImplementation = EntropyLossImplementation.auto,
loss_type: EntropyLossType = EntropyLossType.cross_entropy,
vocab_parallel: bool = False,
@@ -34,7 +37,7 @@ def _get_imlementation(
class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._implementation = _get_imlementation(
self._implementation = _get_implementation(
self._config.implementation, self._config.loss_type, self._vocab_parallel
)

@@ -63,7 +66,7 @@ def __init__(self, *args, **kwargs):
if self._prediction_distance > 0:
raise NotImplementedError()

self._implementation = _get_imlementation(
self._implementation = _get_implementation(
self._config.implementation, self._config.loss_type, self._vocab_parallel
)

@@ -84,3 +87,63 @@ def forward_backward(
target_format=TargetFormat.logits,
entropy_loss_type=self._config.loss_type,
)


_ENTROPY_LOSS_IMPLEMENTATIONS = {
EntropyLossImplementation.torch: torch_entropy_loss_forward_backward,
EntropyLossImplementation.fused: fused_entropy_loss_forward_backward,
EntropyLossImplementation.triton: triton_cross_entropy_forward_backward,
}


def entropy_loss_forward_backward(
logits: torch.Tensor, # (*batch, vocab)
target: torch.Tensor, # (*batch,) or (*batch, vocab)
loss_mask: torch.Tensor | None, # (*batch,)
grad_output: float | None,
group: ProcessGroup | None = None,
implementation: EntropyLossImplementation = EntropyLossImplementation.fused,
logits_scale_factor: float = 1.0,
temperature: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Select the appropriate implementation of cross-entropy.
The triton implementation from the triton submodule is the fastest and recommended one.
It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way,
which is faster and has a relatively small memory overhead.
"""
if target_format == TargetFormat.labels:
Assert.eq(target.shape, logits.shape[:-1])
Assert.eq(target.dtype, torch.int64)
assert loss_mask is None
else:
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])
if group:
Assert.eq(implementation, EntropyLossImplementation.fused)
return fused_entropy_loss_forward_backward(
logits,
target,
loss_mask,
grad_output,
logits_scale_factor,
target_format,
entropy_loss_type,
group,
temperature,
)
else:
return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation](
logits,
target,
loss_mask,
grad_output,
logits_scale_factor,
target_format,
entropy_loss_type,
temperature=temperature,
)
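
A hypothetical call to the dispatcher above, with made-up shapes and the single-GPU `torch` implementation chosen explicitly; this is not taken from the PR's tests, and it assumes the package is importable as shown in the diff.

```python
import torch

from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat
from fast_llm.layers.language_model.loss.entropy_loss import entropy_loss_forward_backward

batch, sequence, vocab = 2, 5, 11
logits = torch.randn(batch, sequence, vocab)
labels = torch.randint(vocab, (batch, sequence))  # int64, as the labels format requires

loss, grad = entropy_loss_forward_backward(
    logits,
    labels,
    loss_mask=None,   # must be None for label targets (see the assert above)
    grad_output=1.0,  # pass None to skip the backward computation
    implementation=EntropyLossImplementation.torch,
    target_format=TargetFormat.labels,
    entropy_loss_type=EntropyLossType.cross_entropy,
)
```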
83 changes: 55 additions & 28 deletions fast_llm/layers/language_model/loss/grpo.py
@@ -2,59 +2,86 @@

import torch

from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base
from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
from fast_llm.layers.language_model.loss.dpo import get_target_log_probabilities
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss


class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Support vocab_parallel
if self._prediction_distance > 0:
raise NotImplementedError()
if self._vocab_parallel:
raise NotImplementedError()

def forward_backward(
self,
logits: "torch.Tensor",
kwargs: dict[str, typing.Any],
split_index: int = 0,
) -> "tuple[torch.Tensor, torch.Tensor | None]":
return loss_forward_backward(
self._get_grad_output(kwargs),
grpo_loss,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return grpo_loss_forward_backward(
logits,
self._get_loss_mask(kwargs, split_index),
self._get_labels(kwargs, split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], kwargs, split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], kwargs, split_index),
self._config.epsilon_low,
self._config.epsilon_high,
self._logits_scale_factor,
grad_output=self._get_grad_output(kwargs),
group=self._parallel_dim.group if self._vocab_parallel else None,
epsilon_low=self._config.epsilon_low,
epsilon_high=self._config.epsilon_high,
logits_scale_factor=self._logits_scale_factor,
)


@torch.compile
def grpo_loss(
logits: torch.Tensor,
loss_mask: "torch.Tensor | None",
labels: torch.Tensor,
advantages: torch.Tensor,
old_log_probabilities: torch.Tensor,
def grpo_loss_forward_backward(
logits: torch.Tensor, # (*batch, vocab)
target: torch.Tensor, # (*batch,)
advantages: torch.Tensor, # (*batch,)
old_log_probabilities: torch.Tensor, # (*batch,)
grad_output: float | None,
group: torch.distributed.ProcessGroup | None = None,
epsilon_low: float = 0.2,
epsilon_high: float = 0.2,
logits_scale_factor: float = 1.0,
) -> torch.Tensor:
if logits_scale_factor != 1.0:
# TODO: Make more efficient.
logits = logits * logits_scale_factor
probability_ratio = torch.exp(get_target_log_probabilities(logits, labels) - old_log_probabilities)
loss = -torch.min(
) -> tuple[torch.Tensor, torch.Tensor | None]:
grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor
loss_mask = target >= 0

logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group)
predicted_logits, target_masked, target_mask = fused_predicted_logits_from_labels(
logits_norm, target, loss_mask, group
)
probability_ratio = (predicted_logits - sum_exp_logits.log() - old_log_probabilities).exp()

per_sample_loss = -torch.min(
probability_ratio * advantages,
torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages,
)
if loss_mask is not None:
loss = loss * loss_mask
return loss.mean()
per_sample_loss = per_sample_loss * loss_mask
loss = per_sample_loss.mean()

if grad_output is None:
grad = None
else:
# loss[a>=0] = -a * min(x, 1 + epsilon_high) => grad[a>=0] = -a * (x <= 1 + epsilon_high)
# loss[a<=0] = -a * max(x, 1 - epsilon_low) => grad[a<=0] = -a * (x >= 1 - epsilon_low)
probability_ratio_grad = (
grad_output
* (
torch.clamp_min(advantages, 0) * (probability_ratio <= 1 + epsilon_high)
+ torch.clamp_max(advantages, 0) * (probability_ratio >= 1 - epsilon_low)
)
* loss_mask
)

# d(probability_ratio)/d(logits) = - probability_ratio * (predicted_probabilities - target_probabilities)
# (Sign absorbed in probability_ratio_grad)
predicted_probabilities = exp_logits / sum_exp_logits.unsqueeze_(-1)
grad = (probability_ratio_grad * probability_ratio).unsqueeze(-1) * predicted_probabilities.scatter_add(
-1,
target_masked.unsqueeze(-1),
-(loss_mask if target_mask is None else target_mask).unsqueeze(-1).to(torch.float32),
)
grad = grad.to(logits.dtype)

return loss, grad
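
Not part of the PR: a quick autograd sanity check of the piecewise derivative documented in the comments above, on a toy probability-ratio tensor. The values stay away from the clip boundaries, where the subgradient is not unique.

```python
import torch

# loss(x) = -min(a * x, a * clamp(x, 1 - eps_low, 1 + eps_high)), differentiated w.r.t. x.
eps_low, eps_high = 0.2, 0.2
for a in (1.5, -1.5):
    advantages = torch.tensor(a)
    ratio = torch.tensor([0.5, 0.7, 0.9, 1.0, 1.1, 1.3, 1.5], requires_grad=True)

    loss = -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * advantages,
    ).sum()
    loss.backward()

    # Same expression as probability_ratio_grad above, up to the overall sign that the
    # code folds into the (predicted - target) softmax term.
    manual = -(
        torch.clamp_min(advantages, 0) * (ratio <= 1 + eps_high)
        + torch.clamp_max(advantages, 0) * (ratio >= 1 - eps_low)
    )
    assert torch.allclose(ratio.grad, manual)
```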
2 changes: 1 addition & 1 deletion fast_llm/layers/language_model/loss/loss.py
@@ -116,6 +116,6 @@ def loss_forward_backward(
grad = None
else:
loss.backward(torch.full_like(loss, grad_output))
grad = input_.grad.detach().to(input_.dtype)
grad = input_.grad.detach()

return loss, grad
48 changes: 40 additions & 8 deletions fast_llm/layers/language_model/loss/z_loss.py
@@ -2,8 +2,9 @@

import torch

from fast_llm.functional.entropy_loss import fused_softmax_base
from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss


class LanguageModelZLoss[ConfigType: LanguageModelZLossConfig](LanguageModelLoss[ConfigType]):
@@ -19,12 +20,12 @@ def forward_backward(
kwargs: dict[str, typing.Any],
split_index: int = 0,
) -> "tuple[torch.Tensor, torch.Tensor | None]":
return loss_forward_backward(
self._get_grad_output(kwargs),
z_loss,
return z_loss_forward_backward(
logits,
self._get_loss_mask(kwargs, split_index),
self._logits_scale_factor,
grad_output=self._get_grad_output(kwargs),
group=self._parallel_dim.group if self._vocab_parallel else None,
logits_scale_factor=self._logits_scale_factor,
)


@@ -34,10 +35,41 @@ def z_loss(
loss_mask: "torch.Tensor | None" = None,
logits_scale_factor: float = 1.0,
) -> torch.Tensor:
"""
Z-loss = mean(logsumexp(logits, dim=-1) ** 2)
"""
# TODO: Replace usage in MoE, move to testing.
logits = logits.float()
out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2
if loss_mask is not None:
out = out * loss_mask
return torch.mean(out)


@torch.compile
def z_loss_forward_backward(
logits: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
group: torch.distributed.ProcessGroup | None = None,
logits_scale_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Z-loss = mean(logsumexp(logits, dim=-1) ** 2)
Grad = 2 * log_sum_exp_logits * softmax(logits)
"""
grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor
logits_norm, exp_logits, sum_exp_logits, logits_max = fused_softmax_base(logits, logits_scale_factor, group)
log_sum_exp_logits = sum_exp_logits.log() + logits_max

per_sample_loss = log_sum_exp_logits**2
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask
loss = per_sample_loss.mean()

if grad_output is None:
grad = None
else:
grad_base = 2 * grad_output * (log_sum_exp_logits / sum_exp_logits)
if loss_mask is not None:
grad_base = grad_base * loss_mask
grad = (grad_base.unsqueeze(-1) * exp_logits).to(logits.dtype)

return loss, grad
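
Not part of the PR: the gradient identity from the docstring above, checked against autograd on a small float64 tensor with made-up shapes.

```python
import torch

# d/dlogits (logsumexp(logits, dim=-1) ** 2) = 2 * logsumexp(logits) * softmax(logits)
torch.manual_seed(0)
logits = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)

loss = (torch.logsumexp(logits, dim=-1) ** 2).mean()
loss.backward()

log_sum_exp = torch.logsumexp(logits, dim=-1, keepdim=True)
# The mean over the flattened batch dimensions contributes the 1 / n_rows factor.
manual = 2 * log_sum_exp * torch.softmax(logits, dim=-1) / logits.shape[:-1].numel()
assert torch.allclose(logits.grad, manual)
```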