@@ -375,6 +375,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
   return self[tuple(dims)]


+@op(torch.ops.aten.positive)
 @op(torch.ops.aten.detach)
 def _aten_detach(self):
   return self
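
`torch.positive` is the identity on numeric tensors (it only raises for bool inputs), so reusing the `_aten_detach` lowering, which simply returns its input, matches eager semantics. A minimal sanity check in plain PyTorch:

    import torch

    # torch.positive returns its input unchanged for numeric dtypes, which
    # is why it can share the identity lowering registered for aten.detach.
    x = torch.tensor([-1.0, 0.0, 2.0])
    assert torch.equal(torch.positive(x), x)
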
@@ -2951,12 +2952,14 @@ def _aten_log2(x):


 # aten.logical_and
 @op(torch.ops.aten.logical_and)
+@op(torch.ops.aten.__and__)
 def _aten_logical_and(self, other):
   return jnp.logical_and(self, other)


 # aten.logical_or
 @op(torch.ops.aten.logical_or)
+@op(torch.ops.aten.__or__)
 def _aten_logical_or(self, other):
   return jnp.logical_or(self, other)

@@ -2998,6 +3001,7 @@ def _aten_logcumsumexp(self, dim=None):
 # aten.max_pool3d_backward
 # aten.logical_xor
 @op(torch.ops.aten.logical_xor)
+@op(torch.ops.aten.__xor__)
 def _aten_logical_xor(self, other):
   return jnp.logical_xor(self, other)

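
Aliasing `__and__`, `__or__`, and `__xor__` to the `logical_*` lowerings routes the Python operators `&`, `|`, and `^` through the same code path as explicit `torch.logical_*` calls. For boolean operands the two agree, as a quick check in plain JAX shows (for integer operands bitwise and logical semantics differ, so the aliasing is exact only in the boolean case):

    import jax.numpy as jnp

    # On boolean arrays, the bitwise operators and the logical_* functions
    # coincide, which is the case these dunder aliases cover.
    a = jnp.array([True, True, False])
    b = jnp.array([True, False, False])
    assert bool(jnp.all((a & b) == jnp.logical_and(a, b)))
    assert bool(jnp.all((a | b) == jnp.logical_or(a, b)))
    assert bool(jnp.all((a ^ b) == jnp.logical_xor(a, b)))
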
@@ -4946,7 +4950,7 @@ def _aten__linalg_solve_ex(a, b):
   res = jnp.linalg.solve(a, b)
   if batched:
     res = res.squeeze(-1)
-  info_shape = a.shape[0] if len(a.shape) >= 3 else []
+  info_shape = a.shape[:-2]
   info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
   return res, info

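
`torch.linalg.solve_ex` returns an `info` tensor with one int32 entry per matrix in the batch, i.e. its shape is the batch shape of `a`. `a.shape[:-2]` computes exactly that, whereas the old expression kept only the first batch dimension (and returned a bare `int` rather than a shape). A worked example:

    import jax.numpy as jnp

    a = jnp.zeros((2, 3, 4, 4))  # a (2, 3) batch of 4x4 matrices
    # Everything but the trailing two (matrix) dimensions is batch shape.
    assert a.shape[:-2] == (2, 3)  # one info entry per matrix
    # The old `a.shape[0]` would give 2, i.e. info of shape (2,), dropping
    # the second batch dimension; for an unbatched 2-D input, a.shape[:-2]
    # is (), a scalar info, as expected.
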
@@ -5497,6 +5501,50 @@ def _aten_pad(self, pad, mode='constant', value=None):
   )


+@op(torch.ops.aten.is_nonzero)
+def _aten_is_nonzero(a):
+  a = jnp.squeeze(a)
+  if a.shape == (0,):
+    raise RuntimeError('bool value of Tensor with no values is ambiguous')
+  if a.ndim != 0:
+    raise RuntimeError(
+        'bool value of Tensor with more than one value is ambiguous')
+  return a.item() != 0
+
+
+@op(torch.ops.aten.logit)
+def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array:
+  """
+  Computes the logit function of the input tensor.
+
+  logit(p) = log(p / (1 - p))
+
+  Args:
+    self: Input tensor.
+    eps: A small value to clip the input tensor to avoid log(0) or division
+      by zero. If None, no clipping is performed.
+
+  Returns:
+    A tensor with the logit of each element of the input.
+  """
+  if eps is not None:
+    self = jnp.clip(self, eps, 1.0 - eps)
+  res = jnp.log(self / (1.0 - self))
+  res = res.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return res
+
+
+@op(torch.ops.aten.floor_divide)
+def _aten_floor_divide(x, y):
+  res = jnp.floor_divide(x, y)
+  return res
+
+
+@op(torch.ops.aten._assert_tensor_metadata)
+def _aten__assert_tensor_metadata(*args, **kwargs):
+  pass
+
+
 mutation_ops_to_functional = {
     torch.ops.aten.add_:
         op_base.InplaceOp(torch.ops.aten.add),
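
The `logit` lowering clips to `[eps, 1 - eps]` before taking `log(p / (1 - p))`, so endpoint inputs stay finite. A small numerical check that mirrors the formula (standalone, not calling the lowering itself):

    import jax.numpy as jnp

    def logit(p, eps=None):
      # Same formula as the lowering above: optional clipping, then
      # log(p / (1 - p)).
      if eps is not None:
        p = jnp.clip(p, eps, 1.0 - eps)
      return jnp.log(p / (1.0 - p))

    p = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0])
    out = logit(p, eps=1e-6)
    assert bool(jnp.isfinite(out).all())  # endpoints clipped, no +/-inf
    assert abs(float(out[2])) < 1e-6      # logit(0.5) == 0
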
@@ -5565,6 +5613,10 @@ def _aten_pad(self, pad, mode='constant', value=None):
         op_base.InplaceOp(torch.ops.aten.scatter),
     torch.ops.aten.bitwise_or_:
         op_base.InplaceOp(torch.ops.aten.bitwise_or),
+    torch.ops.aten.floor_divide_:
+        op_base.InplaceOp(torch.ops.aten.floor_divide),
+    torch.ops.aten.remainder_:
+        op_base.InplaceOp(torch.ops.aten.remainder),
 }

 # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
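
Each `InplaceOp` entry lets an in-place variant reuse the functional lowering: in eager PyTorch, `x.floor_divide_(y)` is equivalent to rebinding `x` to `torch.floor_divide(x, y)`, and that rewrite is what the mapping expresses. A check of the equivalence in eager mode:

    import torch

    x = torch.tensor([7.0, -7.0])
    y = torch.tensor(2.0)
    expected = torch.floor_divide(x, y)  # tensor([ 3., -4.]): floor, not trunc
    x.floor_divide_(y)                   # in-place variant mutates x
    assert torch.equal(x, expected)
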