Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
from keras.src.ops.numpy import vdot as vdot
from keras.src.ops.numpy import vectorize as vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
from keras.src.ops.numpy import vdot as vdot
from keras.src.ops.numpy import vectorize as vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
from keras.src.ops.numpy import vdot as vdot
from keras.src.ops.numpy import vectorize as vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
from keras.src.ops.numpy import vdot as vdot
from keras.src.ops.numpy import vectorize as vectorize
Expand Down
28 changes: 28 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,34 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
return jnp.trapezoid(y, x, dx=dx, axis=axis)


def vander(x, N=None, increasing=False):
x = convert_to_tensor(x)

if x.ndim != 1:
raise ValueError(
f"Input must be a one-dimensional array. Received: x.ndim={x.ndim}"
)

if N is not None:
if N < 0:
raise ValueError(
f"Argument 'N' must be nonnegative. Received: N={N}"
)

if not isinstance(N, int):
raise TypeError(
f"Argument 'N' must be integer. Received: dtype={type(N)}"
)

if not isinstance(increasing, bool):
raise TypeError(
"Argument 'increasing' must be bool. "
f"Received: dtype={type(increasing)}"
)

return jnp.vander(x, N=N, increasing=increasing)


def var(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
# `jnp.var` does not handle low precision (e.g., float16) overflow
Expand Down
30 changes: 30 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,36 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
return np.trapezoid(y, x, dx=dx, axis=axis).astype(result_dtype)


def vander(x, N=None, increasing=False):
x = convert_to_tensor(x)

if x.ndim != 1:
raise ValueError(
f"Input must be a one-dimensional array. Received: x.ndim={x.ndim}"
)

if N is not None:
if N < 0:
raise ValueError(
f"Argument 'N' must be nonnegative. Received: N={N}"
)

if not isinstance(N, int):
raise TypeError(
f"Argument 'N' must be integer. Received: dtype={type(N)}"
)

if not isinstance(increasing, bool):
raise TypeError(
"Argument 'increasing' must be bool. "
f"Received: dtype={type(increasing)}"
)

result_dtype = dtypes.result_type(x.dtype)
x = x.astype(config.floatx())
return np.vander(x, N=N, increasing=increasing).astype(result_dtype)


def var(x, axis=None, keepdims=False):
axis = standardize_axis_for_numpy(axis)
x = convert_to_tensor(x)
Expand Down
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ NumpyDtypeTest::test_trace
NumpyDtypeTest::test_trapezoid
NumpyDtypeTest::test_trunc
NumpyDtypeTest::test_unravel
NumpyDtypeTest::test_vander
NumpyDtypeTest::test_var
NumpyDtypeTest::test_vdot
NumpyDtypeTest::test_view
Expand Down Expand Up @@ -91,6 +92,7 @@ NumpyOneInputOpsCorrectnessTest::test_trace
NumpyOneInputOpsCorrectnessTest::test_trapezoid
NumpyOneInputOpsCorrectnessTest::test_trunc
NumpyOneInputOpsCorrectnessTest::test_unravel_index
NumpyOneInputOpsCorrectnessTest::test_vander
NumpyOneInputOpsCorrectnessTest::test_vectorize
NumpyOneInputOpsCorrectnessTest::test_vstack
NumpyOneInputOpsCorrectnessTest::test_view
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,10 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
)


def vander(x, N=None, increasing=False):
raise NotImplementedError("`vander` is not supported with openvino backend")


def var(x, axis=None, keepdims=False):
x = get_ov_output(x)
x_type = x.get_element_type()
Expand Down
44 changes: 44 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3080,6 +3080,50 @@ def _move_axis_to_last(tensor, axis):
return result


def vander(x, N=None, increasing=False):
x = convert_to_tensor(x)

if x.shape.rank != 1:
raise ValueError(
"Input must be a one-dimensional array. "
f"Received: x.ndim={x.shape.rank}"
)

if N is not None:
if N < 0:
raise ValueError(
f"Argument 'N' must be nonnegative. Received: N={N}"
)

if not isinstance(N, int):
raise TypeError(
f"Argument 'N' must be integer. Received: dtype={type(N)}"
)

if not isinstance(increasing, bool):
raise TypeError(
"Argument 'increasing' must be bool. "
f"Received: dtype={type(increasing)}"
)

result_dtype = dtypes.result_type(x.dtype)

if N is None:
N = tf.shape(x)[0]

if increasing:
powers = tf.range(N)
else:
powers = tf.range(N - 1, -1, -1)

x_exp = tf.expand_dims(x, axis=-1)

x_exp = tf.cast(x_exp, tf.float32)
powers = tf.cast(powers, tf.float32)
vander = tf.math.pow(x_exp, powers)
return tf.cast(vander, result_dtype)


def var(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
compute_dtype = dtypes.result_type(x.dtype, "float32")
Expand Down
29 changes: 29 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,6 +1819,35 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
return torch.trapz(y, dx=dx, dim=axis)


def vander(x, N=None, increasing=False):
x = convert_to_tensor(x)

if x.ndim != 1:
raise ValueError(
f"Input must be a one-dimensional array. Received: x.ndim={x.ndim}"
)

if N is not None:
if N < 0:
raise ValueError(
f"Argument 'N' must be nonnegative. Received: N={N}"
)

if not isinstance(N, int):
raise TypeError(
f"Argument 'N' must be integer. Received: dtype={type(N)}"
)

if not isinstance(increasing, bool):
raise TypeError(
"Argument 'increasing' must be bool. "
f"Received: dtype={type(increasing)}"
)

result_dtype = dtypes.result_type(x.dtype)
return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)


def var(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
compute_dtype = dtypes.result_type(x.dtype, "float32")
Expand Down
47 changes: 47 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7279,6 +7279,53 @@ def mean(x, axis=None, keepdims=False):
return backend.numpy.mean(x, axis=axis, keepdims=keepdims)


class Vander(Operation):
def __init__(self, N=None, increasing=False, *, name=None):
super().__init__(name=name)
self.N = N
self.increasing = increasing

def call(self, x):
return backend.numpy.vander(x, self.N, self.increasing)

def compute_output_spec(self, x):
if self.N is None:
N = x.shape[0]
else:
N = self.N

out_shape = list(x.shape)
out_shape.append(N)
return KerasTensor(tuple(out_shape), dtype=x.dtype)


@keras_export(["keras.ops.vander", "keras.ops.numpy.vander"])
def vander(x, N=None, increasing=False):
"""Generate a Vandermonde matrix.

Args:
x: 1D input tensor.
N: Number of columns. If None, `N` = `len(x)`.
increasing: Order of powers. If True, powers increase left to right.

Returns:
Output tensor, vandermonde matrix of shape `(len(x), N)`.

Example:
>>> import numpy as np
>>> import keras
>>> x = np.array([1, 2, 3, 5])
>>> keras.ops.vander(x)
array([[ 1, 1, 1, 1],
[ 8, 4, 2, 1],
[ 27, 9, 3, 1],
[125, 25, 5, 1]])
"""
if any_symbolic_tensors((x,)):
return Vander(N=N, increasing=increasing).symbolic_call(x)
return backend.numpy.vander(x, N=N, increasing=increasing)


class Var(Operation):
def __init__(self, axis=None, keepdims=False, *, name=None):
super().__init__(name=name)
Expand Down
45 changes: 45 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,10 @@ def test_trapezoid(self):
x = KerasTensor((None, 3, 3))
self.assertEqual(knp.trapezoid(x, axis=1).shape, (None, 3))

def test_vander(self):
x = KerasTensor((None,))
self.assertEqual(knp.vander(x).shape, (None, None))

def test_var(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.var(x).shape, ())
Expand Down Expand Up @@ -1899,6 +1903,10 @@ def test_trapezoid(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.trapezoid(x).shape, (2,))

def test_vander(self):
x = KerasTensor((2,))
self.assertEqual(knp.vander(x).shape, (2, 2))

def test_var(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.var(x).shape, ())
Expand Down Expand Up @@ -3717,6 +3725,25 @@ def test_trapezoid(self):
np.trapezoid(y, x=x, axis=1),
)

def test_vander(self):
x = np.random.random((3,))
N = 6
increasing = True

self.assertAllClose(knp.vander(x), np.vander(x))
self.assertAllClose(knp.vander(x, N=N), np.vander(x, N=N))
self.assertAllClose(
knp.vander(x, N=N, increasing=increasing),
np.vander(x, N=N, increasing=increasing),
)

self.assertAllClose(knp.Vander().call(x), np.vander(x))
self.assertAllClose(knp.Vander(N=N).call(x), np.vander(x, N=N))
self.assertAllClose(
knp.Vander(N=N, increasing=increasing).call(x),
np.vander(x, N=N, increasing=increasing),
)

def test_var(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.var(x), np.var(x))
Expand Down Expand Up @@ -9162,6 +9189,24 @@ def test_trapezoid(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_vander(self, dtype):
import jax.numpy as jnp

x = knp.ones((2,), dtype=dtype)
x_jax = jnp.ones((2,), dtype=dtype)

if dtype == "bool":
self.skipTest("vander does not support bool")

expected_dtype = standardize_dtype(jnp.vander(x_jax).dtype)

self.assertEqual(standardize_dtype(knp.vander(x).dtype), expected_dtype)
self.assertEqual(
standardize_dtype(knp.Vander().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_var(self, dtype):
import jax.numpy as jnp
Expand Down