Description
Nx.Defn.value_and_grad crashes when computing gradients through a model containing max_pool with :same padding on BinaryBackend. The backward pass for max_pool calls select_and_scatter, which computes a negative padding value, causing :lists.duplicate/2 to fail with a FunctionClauseError.
The same model works correctly with EXLA as the JIT compiler.
Minimal reproduction
Mix.install([{:axon, "~> 0.7"}, {:nx, "~> 0.10"}])
model =
  Axon.input("input", shape: {nil, 16, 16, 1})
  |> Axon.conv(4, kernel_size: 7, strides: [2, 2], padding: :same, name: "conv")
  |> Axon.max_pool(kernel_size: {3, 3}, strides: [2, 2], padding: :same, name: "pool")
  |> Axon.flatten()
  |> Axon.dense(4, name: "fc")
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
template = %{"input" => Nx.template({2, 16, 16, 1}, :f32)}
model_state = init_fn.(template, Axon.ModelState.empty())
# Forward pass works fine
output = predict_fn.(model_state, %{"input" => Nx.broadcast(0.1, {2, 16, 16, 1})})
IO.inspect(Nx.shape(output)) # => {2, 4}
# Gradient crashes
loss_fn = fn params ->
  state = %{model_state | data: params}
  output = predict_fn.(state, %{"input" => Nx.broadcast(0.1, {2, 16, 16, 1})})
  Nx.mean(output)
end
{_loss, _grads} = Nx.Defn.value_and_grad(model_state.data, loss_fn)
Error
** (FunctionClauseError) no function clause matching in :lists.duplicate/2

    The following arguments were given to :lists.duplicate/2:

        # 1
        -36

        # 2
        <<0, 0, 0, 0>>

    (stdlib 6.2.2.2) lists.erl:510: :lists.duplicate/2
    (nx 0.10.0) lib/nx/binary_backend.ex:1669: Nx.BinaryBackend.select_and_scatter/8
    (nx 0.10.0) lib/nx/defn/evaluator.ex:461: Nx.Defn.Evaluator.eval_apply/4
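For reference, the final frame can be reproduced in isolation, since :lists.duplicate/2 only accepts a non-negative count:

# Standalone reproduction of the failing call from the trace above
:lists.duplicate(-36, <<0, 0, 0, 0>>)
# ** (FunctionClauseError) no function clause matching in :lists.duplicate/2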
Expected behavior
value_and_grad should compute gradients through the max_pool layer without crashing.
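Concretely, the call should return a {loss, grads} tuple where grads mirrors the structure of model_state.data. A sketch of the expected (successful) result, with parameter names assumed from Axon's conventions and shapes omitted:

{loss, grads} = Nx.Defn.value_and_grad(model_state.data, loss_fn)
# loss:  scalar tensor
# grads: same tree structure as model_state.data, e.g.
#        %{"conv" => %{"kernel" => ..., "bias" => ...},
#          "fc"   => %{"kernel" => ..., "bias" => ...}}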
Root cause
BinaryBackend.select_and_scatter/8 appears to compute a negative padding count when the spatial dimensions after the convolution are small relative to the pool window. The -36 passed to :lists.duplicate/2 suggests a padding calculation along the lines of (output_size - 1) * stride + kernel_size - input_size going negative without being clamped at zero.
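To illustrate the suspected formula with made-up sizes (I have not inspected the exact values used inside select_and_scatter/8):

# Sketch of the suspected padding computation, not the actual implementation
pad = fn output_size, stride, kernel_size, input_size ->
  (output_size - 1) * stride + kernel_size - input_size
end

pad.(4, 2, 3, 8)  # => 1, non-negative, fine
pad.(2, 2, 3, 8)  # => -3, negative, and :lists.duplicate/2 rejects it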
Workaround
Use EXLA as the JIT compiler:
Nx.Defn.jit(
  fn params -> Nx.Defn.value_and_grad(params, loss_fn) end,
  compiler: EXLA
).(model_state.data)
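Note that the Mix.install above does not include EXLA, so this assumes something like {:exla, "~> 0.10"} (version assumed to match Nx) is added to the deps. Setting the default compiler once also appears to work:

# Alternative: set the process-local default compiler instead of jitting manually
Nx.Defn.default_options(compiler: EXLA)
{_loss, _grads} = Nx.Defn.value_and_grad(model_state.data, loss_fn)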
Relevant code
lib/nx/binary_backend.ex:1669, in Nx.BinaryBackend.select_and_scatter/8.
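One possible direction, purely as a sketch (names are hypothetical and clamping may not be the semantically correct fix; negative padding may instead need to be handled by slicing):

# Hypothetical guard before the :lists.duplicate/2 call in select_and_scatter/8
pad_count = max(pad_count, 0)
padding = :lists.duplicate(pad_count, <<0, 0, 0, 0>>)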
Environment
- Nx: 0.10.0
- Elixir: 1.18.4
- OTP: 27
- OS: Linux (WSL2)