BinaryBackend.select_and_scatter/8 crashes with negative padding during max_pool gradient #1675

@blasphemetheus

Description

@blasphemetheus

Description

Nx.Defn.value_and_grad crashes when computing gradients through a model containing max_pool with :same padding on BinaryBackend. The backward pass for max_pool calls select_and_scatter, which computes a negative padding value, causing :lists.duplicate/2 to fail with a FunctionClauseError.

The same model works correctly with EXLA as the JIT compiler.

Minimal reproduction

  Mix.install([{:axon, "~> 0.7"}, {:nx, "~> 0.10"}])

  model =
    Axon.input("input", shape: {nil, 16, 16, 1})
    |> Axon.conv(4, kernel_size: 7, strides: [2, 2], padding: :same, name: "conv")
    |> Axon.max_pool(kernel_size: {3, 3}, strides: [2, 2], padding: :same, name: "pool")
    |> Axon.flatten()
    |> Axon.dense(4, name: "fc")

  {init_fn, predict_fn} = Axon.build(model, mode: :inference)
  template = %{"input" => Nx.template({2, 16, 16, 1}, :f32)}
  model_state = init_fn.(template, Axon.ModelState.empty())

  # Forward pass works fine
  output = predict_fn.(model_state, %{"input" => Nx.broadcast(0.1, {2, 16, 16, 1})})
  IO.inspect(Nx.shape(output))  # => {2, 4}

  # Gradient crashes
  loss_fn = fn params ->
    state = %{model_state | data: params}
    output = predict_fn.(state, %{"input" => Nx.broadcast(0.1, {2, 16, 16, 1})})
    Nx.mean(output)
  end

  {_loss, _grads} = Nx.Defn.value_and_grad(model_state.data, loss_fn)

Error

  ** (FunctionClauseError) no function clause matching in :lists.duplicate/2

      The following arguments were given to :lists.duplicate/2:

          # 1
          -36

          # 2
          <<0, 0, 0, 0>>

      (stdlib 6.2.2.2) lists.erl:510: :lists.duplicate/2
      (nx 0.10.0) lib/nx/binary_backend.ex:1669: Nx.BinaryBackend.select_and_scatter/8
      (nx 0.10.0) lib/nx/defn/evaluator.ex:461: Nx.Defn.Evaluator.eval_apply/4

Expected behavior

value_and_grad should compute gradients through the max_pool layer without crashing.

Root cause

BinaryBackend.select_and_scatter/8 computes a negative padding value when the spatial dimensions after convolution are small relative to the pool kernel size. The -36 suggests a "same"-style padding calculation of the form (output_size - 1) * stride + kernel_size - input_size going negative, and the negative count is then passed to :lists.duplicate/2, which only accepts non-negative integers.
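As a rough illustration (this is a sketch of the suspected arithmetic, not the actual code at binary_backend.ex:1669; the helper name `pad_total` is hypothetical), the padding formula above can be written as:

```elixir
# Hypothetical sketch of the "same"-style padding arithmetic suspected above.
# :lists.duplicate/2 raises FunctionClauseError for any negative count,
# so a negative result here would reproduce the crash signature.
pad_total = fn input_size, kernel_size, stride ->
  # :same padding targets output_size = ceil(input_size / stride)
  output_size = ceil(input_size / stride)
  (output_size - 1) * stride + kernel_size - input_size
end

# Pool layer in the repro: 8x8 spatial input (after the strided conv),
# kernel 3, stride 2 -> total padding of 1, which is fine on its own.
pad_total.(8, 3, 2)
```

With the repro's shapes this particular formula stays non-negative, so the -36 presumably comes from a different or additional term inside select_and_scatter/8; the sketch only shows where a sign flip of this kind would feed into :lists.duplicate/2.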

Workaround

Use EXLA as the JIT compiler:
  Nx.Defn.jit(fn params -> Nx.Defn.value_and_grad(params, loss_fn) end, compiler: EXLA).(model_state.data)

Where to look

lib/nx/binary_backend.ex:1669 in select_and_scatter/8.

Environment

  • Nx: 0.10.0
  • Elixir: 1.18.4
  • OTP: 27
  • OS: Linux (WSL2)
