@@ -150,16 +150,6 @@ pub enum Activation {
150150 Tanh ,
151151}
152152
153- impl std:: fmt:: Display for Activation {
154- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
155- match self {
156- Activation :: None => write ! ( f, "NONE" ) ,
157- Activation :: SquaredRelu => write ! ( f, "SQUARED_RELU" ) ,
158- Activation :: Tanh => write ! ( f, "TANH" ) ,
159- }
160- }
161- }
162-
163153impl Macros {
164154 /// Define a `u32` macro `NF4_BLOCK_SIZE`.
165155 pub fn nf4 ( mut self , block_size : u32 ) -> Self {
@@ -196,6 +186,16 @@ impl Macros {
196186 }
197187 }
198188
189+ pub fn activate ( mut self , name : impl Into < String > , value : Activation ) -> Self {
190+ let name = name. into ( ) ;
191+ match value {
192+ Activation :: None => self . insert ( name, "" . into ( ) ) ,
193+ Activation :: SquaredRelu => self . insert ( name, "squared_relu" . into ( ) ) ,
194+ Activation :: Tanh => self . insert ( name, "tanh" . into ( ) ) ,
195+ } ;
196+ self
197+ }
198+
199199 /// Define the macro specifies input/output tensor data type.
200200 pub fn tensor < T : Float > (
201201 mut self ,
@@ -654,7 +654,7 @@ impl TensorOp {
654654 . u32 ( "BLOCK_SIZE" , BLOCK_SIZE )
655655 . tensor ( & input, Some ( "IN" ) )
656656 . tensor ( & output, Some ( "OUT" ) )
657- . custom ( active , Some ( "ACT" ) ) ,
657+ . activate ( "ACT" , active ) ,
658658 ) ;
659659 #[ cfg( feature = "subgroup-ops" ) ]
660660 let pipeline = context. checkout_pipeline (
@@ -667,7 +667,7 @@ impl TensorOp {
667667 . u32 ( "BLOCK_SIZE" , BLOCK_SIZE )
668668 . tensor ( & input, Some ( "IN" ) )
669669 . tensor ( & output, Some ( "OUT" ) )
670- . custom ( active , Some ( "ACT" ) ) ,
670+ . activate ( "ACT" , active ) ,
671671 ) ;
672672 let bindings = vec ! [ context. device. create_bind_group( & BindGroupDescriptor {
673673 label: None ,
@@ -743,7 +743,7 @@ impl TensorOp {
743743 . int8 ( Self :: INT8_BLOCK_SIZE )
744744 . tensor ( & input, Some ( "IN" ) )
745745 . tensor ( & output, Some ( "OUT" ) )
746- . custom ( active , Some ( "ACT" ) ) ,
746+ . activate ( "ACT" , active ) ,
747747 ) ;
748748 #[ cfg( feature = "subgroup-ops" ) ]
749749 let pipeline = context. checkout_pipeline (
@@ -757,7 +757,7 @@ impl TensorOp {
757757 . int8 ( Self :: INT8_BLOCK_SIZE )
758758 . tensor ( & input, Some ( "IN" ) )
759759 . tensor ( & output, Some ( "OUT" ) )
760- . custom ( active , Some ( "ACT" ) ) ,
760+ . activate ( "ACT" , active ) ,
761761 ) ;
762762 let bindings = vec ! [ context. device. create_bind_group( & BindGroupDescriptor {
763763 label: None ,
@@ -837,7 +837,7 @@ impl TensorOp {
837837 . nf4 ( Self :: NF4_BLOCK_SIZE )
838838 . tensor ( & input, Some ( "IN" ) )
839839 . tensor ( & output, Some ( "OUT" ) )
840- . custom ( active , Some ( "ACT" ) ) ,
840+ . activate ( "ACT" , active ) ,
841841 ) ;
842842 #[ cfg( feature = "subgroup-ops" ) ]
843843 let pipeline = context. checkout_pipeline (
@@ -851,7 +851,7 @@ impl TensorOp {
851851 . nf4 ( Self :: NF4_BLOCK_SIZE )
852852 . tensor ( & input, Some ( "IN" ) )
853853 . tensor ( & output, Some ( "OUT" ) )
854- . custom ( active , Some ( "ACT" ) ) ,
854+ . activate ( "ACT" , active ) ,
855855 ) ;
856856 let bindings = vec ! [ context. device. create_bind_group( & BindGroupDescriptor {
857857 label: None ,
@@ -932,7 +932,7 @@ impl TensorOp {
932932 . u32 ( "BLOCK_SIZE" , BLOCK_SIZE )
933933 . tensor ( & input, Some ( "IN" ) )
934934 . tensor ( & output, Some ( "OUT" ) )
935- . custom ( active , Some ( "ACT" ) ) ,
935+ . activate ( "ACT" , active ) ,
936936 ) ;
937937 let bindings = vec ! [ context. device. create_bind_group( & BindGroupDescriptor {
938938 label: None ,
@@ -1013,7 +1013,7 @@ impl TensorOp {
10131013 . int8 ( Self :: INT8_BLOCK_SIZE )
10141014 . tensor ( & input, Some ( "IN" ) )
10151015 . tensor ( & output, Some ( "OUT" ) )
1016- . custom ( active , Some ( "ACT" ) ) ,
1016+ . activate ( "ACT" , active ) ,
10171017 ) ;
10181018 let bindings = vec ! [ context. device. create_bind_group( & BindGroupDescriptor {
10191019 label: None ,
@@ -1098,7 +1098,7 @@ impl TensorOp {
10981098 . nf4 ( Self :: NF4_BLOCK_SIZE )
10991099 . tensor ( & input, Some ( "IN" ) )
11001100 . tensor ( & output, Some ( "OUT" ) )
1101- . custom ( active , Some ( "ACT" ) ) ,
1101+ . activate ( "ACT" , active ) ,
11021102 ) ;
11031103 let bindings = vec ! [ context. device. create_bind_group( & BindGroupDescriptor {
11041104 label: None ,
0 commit comments