Skip to content

Commit 19e3045

Browse files
Refactor adaptive pooling with shared utils and base classes
1 parent 1603dd9 commit 19e3045

15 files changed

+1011
-781
lines changed

keras/src/backend/common/backend_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import math
23
import operator
34
import re
45
import warnings
@@ -539,3 +540,10 @@ def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
539540
-1 - axis
540541
)
541542
return x[tuple(slices)]
543+
544+
545+
def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
546+
"""Compute small and big window sizes for adaptive pooling."""
547+
small = math.ceil(input_dim / output_dim)
548+
big = small + 1
549+
return small, big

keras/src/backend/jax/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
from keras.src.backend.jax.core import shape
2626
from keras.src.backend.jax.core import stop_gradient
2727
from keras.src.backend.jax.core import vectorized_map
28-
from keras.src.backend.jax.nn import adaptive_avg_pool
29-
from keras.src.backend.jax.nn import adaptive_max_pool
3028
from keras.src.backend.jax.rnn import cudnn_ok
3129
from keras.src.backend.jax.rnn import gru
3230
from keras.src.backend.jax.rnn import lstm

0 commit comments

Comments
 (0)