Skip to content

Commit ef021f1

Browse files
committed
Simplify activation function selection.
1 parent 110bf77 commit ef021f1

File tree

10 files changed

+37
-97
lines changed

10 files changed

+37
-97
lines changed

src/shaders/matmul_mat_fp16.wgsl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,18 +124,10 @@ fn matmul(in: Input) {
124124
}
125125

126126
if all(u < vec2<u32>(ra.y, rb.y)) {
127-
#ifdef ACT_SQUARED_RELU
128-
local_sum[0] = squared_relu(local_sum[0]);
129-
local_sum[1] = squared_relu(local_sum[1]);
130-
local_sum[2] = squared_relu(local_sum[2]);
131-
local_sum[3] = squared_relu(local_sum[3]);
132-
#endif
133-
#ifdef ACT_TANH
134-
local_sum[0] = tanh(local_sum[0]);
135-
local_sum[1] = tanh(local_sum[1]);
136-
local_sum[2] = tanh(local_sum[2]);
137-
local_sum[3] = tanh(local_sum[3]);
138-
#endif
127+
local_sum[0] = ACT(local_sum[0]);
128+
local_sum[1] = ACT(local_sum[1]);
129+
local_sum[2] = ACT(local_sum[2]);
130+
local_sum[3] = ACT(local_sum[3]);
139131
#ifdef OUT_FP16
140132
output[compute_index(destination, in.uid.z, u.y + 0u, in.uid.x)] = pack4x16float(local_sum[0]);
141133
output[compute_index(destination, in.uid.z, u.y + 1u, in.uid.x)] = pack4x16float(local_sum[1]);

src/shaders/matmul_mat_int8.wgsl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,18 +138,10 @@ fn matmul(in: Input) {
138138
}
139139

140140
if all(u < vec2<u32>(ra.y, rb.y)) {
141-
#ifdef ACT_SQUARED_RELU
142-
local_sum[0] = squared_relu(local_sum[0]);
143-
local_sum[1] = squared_relu(local_sum[1]);
144-
local_sum[2] = squared_relu(local_sum[2]);
145-
local_sum[3] = squared_relu(local_sum[3]);
146-
#endif
147-
#ifdef ACT_TANH
148-
local_sum[0] = tanh(local_sum[0]);
149-
local_sum[1] = tanh(local_sum[1]);
150-
local_sum[2] = tanh(local_sum[2]);
151-
local_sum[3] = tanh(local_sum[3]);
152-
#endif
141+
local_sum[0] = ACT(local_sum[0]);
142+
local_sum[1] = ACT(local_sum[1]);
143+
local_sum[2] = ACT(local_sum[2]);
144+
local_sum[3] = ACT(local_sum[3]);
153145
#ifdef OUT_FP16
154146
output[compute_index(destination, in.uid.z, u.y + 0u, in.uid.x)] = pack4x16float(local_sum[0]);
155147
output[compute_index(destination, in.uid.z, u.y + 1u, in.uid.x)] = pack4x16float(local_sum[1]);

src/shaders/matmul_mat_nf4.wgsl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -205,18 +205,10 @@ fn matmul(in: Input) {
205205
}
206206

207207
if all(u < vec2<u32>(ra.y, rb.y)) {
208-
#ifdef ACT_SQUARED_RELU
209-
local_sum[0] = squared_relu(local_sum[0]);
210-
local_sum[1] = squared_relu(local_sum[1]);
211-
local_sum[2] = squared_relu(local_sum[2]);
212-
local_sum[3] = squared_relu(local_sum[3]);
213-
#endif
214-
#ifdef ACT_TANH
215-
local_sum[0] = tanh(local_sum[0]);
216-
local_sum[1] = tanh(local_sum[1]);
217-
local_sum[2] = tanh(local_sum[2]);
218-
local_sum[3] = tanh(local_sum[3]);
219-
#endif
208+
local_sum[0] = ACT(local_sum[0]);
209+
local_sum[1] = ACT(local_sum[1]);
210+
local_sum[2] = ACT(local_sum[2]);
211+
local_sum[3] = ACT(local_sum[3]);
220212
#ifdef OUT_FP16
221213
output[compute_index(destination, in.uid.z, u.y + 0u, in.uid.x, 4u)] = pack4x16float(local_sum[0]);
222214
output[compute_index(destination, in.uid.z, u.y + 1u, in.uid.x, 4u)] = pack4x16float(local_sum[1]);

src/shaders/matmul_vec_fp16.wgsl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,7 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
9494

9595
if index == 0u {
9696
let btc = compute_index(destination, batch, token, channel);
97-
var out = sketch[0];
98-
#ifdef ACT_SQUARED_RELU
99-
out = squared_relu(out);
100-
#endif
101-
#ifdef ACT_TANH
102-
out = tanh(out);
103-
#endif
97+
let out = ACT(sketch[0]);
10498
#ifdef OUT_FP16
10599
output[btc] = pack4x16float(out);
106100
#else

src/shaders/matmul_vec_int8.wgsl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,7 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
102102

103103
if index == 0u {
104104
let btc = compute_index(destination, batch, token, channel);
105-
var out = sketch[0];
106-
#ifdef ACT_SQUARED_RELU
107-
out = squared_relu(out);
108-
#endif
109-
#ifdef ACT_TANH
110-
out = tanh(out);
111-
#endif
105+
let out = ACT(sketch[0]);
112106
#ifdef OUT_FP16
113107
output[btc] = pack4x16float(out);
114108
#else

src/shaders/matmul_vec_nf4.wgsl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,7 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
154154

155155
if index == 0u {
156156
let btc = compute_index(destination, batch, token, channel, 2u);
157-
var out = sketch[0];
158-
#ifdef ACT_SQUARED_RELU
159-
out = squared_relu(out);
160-
#endif
161-
#ifdef ACT_TANH
162-
out = tanh(out);
163-
#endif
157+
let out = ACT(sketch[0]);
164158
#ifdef OUT_FP16
165159
output[btc] = pack4x16float(out);
166160
#else

src/shaders/subgroup/matmul_vec_fp16.wgsl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,7 @@ fn matmul(
120120

121121
if index == 0u {
122122
let btc = compute_index(destination, batch, token, channel);
123-
var out = sketch[0];
124-
#ifdef ACT_SQUARED_RELU
125-
out = squared_relu(out);
126-
#endif
127-
#ifdef ACT_TANH
128-
out = tanh(out);
129-
#endif
123+
let out = ACT(sketch[0]);
130124
#ifdef OUT_FP16
131125
output[btc] = pack4x16float(out);
132126
#else

src/shaders/subgroup/matmul_vec_int8.wgsl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,7 @@ fn matmul(
127127

128128
if index == 0u {
129129
let btc = compute_index(destination, batch, token, channel);
130-
var out = sketch[0];
131-
#ifdef ACT_SQUARED_RELU
132-
out = squared_relu(out);
133-
#endif
134-
#ifdef ACT_TANH
135-
out = tanh(out);
136-
#endif
130+
let out = ACT(sketch[0]);
137131
#ifdef OUT_FP16
138132
output[btc] = pack4x16float(out);
139133
#else

src/shaders/subgroup/matmul_vec_nf4.wgsl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,7 @@ fn matmul(
172172

173173
if index == 0u {
174174
let btc = compute_index(destination, batch, token, channel, 2u);
175-
var out = sketch[0];
176-
#ifdef ACT_SQUARED_RELU
177-
out = squared_relu(out);
178-
#endif
179-
#ifdef ACT_TANH
180-
out = tanh(out);
181-
#endif
175+
let out = ACT(sketch[0]);
182176
#ifdef OUT_FP16
183177
output[btc] = pack4x16float(out);
184178
#else

src/tensor/ops.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,6 @@ pub enum Activation {
150150
Tanh,
151151
}
152152

153-
impl std::fmt::Display for Activation {
154-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155-
match self {
156-
Activation::None => write!(f, "NONE"),
157-
Activation::SquaredRelu => write!(f, "SQUARED_RELU"),
158-
Activation::Tanh => write!(f, "TANH"),
159-
}
160-
}
161-
}
162-
163153
impl Macros {
164154
/// Define a `u32` macro `NF4_BLOCK_SIZE`.
165155
pub fn nf4(mut self, block_size: u32) -> Self {
@@ -196,6 +186,16 @@ impl Macros {
196186
}
197187
}
198188

189+
pub fn activate(mut self, name: impl Into<String>, value: Activation) -> Self {
190+
let name = name.into();
191+
match value {
192+
Activation::None => self.insert(name, "".into()),
193+
Activation::SquaredRelu => self.insert(name, "squared_relu".into()),
194+
Activation::Tanh => self.insert(name, "tanh".into()),
195+
};
196+
self
197+
}
198+
199199
/// Define the macro that specifies the input/output tensor data type.
200200
pub fn tensor<T: Float>(
201201
mut self,
@@ -654,7 +654,7 @@ impl TensorOp {
654654
.u32("BLOCK_SIZE", BLOCK_SIZE)
655655
.tensor(&input, Some("IN"))
656656
.tensor(&output, Some("OUT"))
657-
.custom(active, Some("ACT")),
657+
.activate("ACT", active),
658658
);
659659
#[cfg(feature = "subgroup-ops")]
660660
let pipeline = context.checkout_pipeline(
@@ -667,7 +667,7 @@ impl TensorOp {
667667
.u32("BLOCK_SIZE", BLOCK_SIZE)
668668
.tensor(&input, Some("IN"))
669669
.tensor(&output, Some("OUT"))
670-
.custom(active, Some("ACT")),
670+
.activate("ACT", active),
671671
);
672672
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
673673
label: None,
@@ -743,7 +743,7 @@ impl TensorOp {
743743
.int8(Self::INT8_BLOCK_SIZE)
744744
.tensor(&input, Some("IN"))
745745
.tensor(&output, Some("OUT"))
746-
.custom(active, Some("ACT")),
746+
.activate("ACT", active),
747747
);
748748
#[cfg(feature = "subgroup-ops")]
749749
let pipeline = context.checkout_pipeline(
@@ -757,7 +757,7 @@ impl TensorOp {
757757
.int8(Self::INT8_BLOCK_SIZE)
758758
.tensor(&input, Some("IN"))
759759
.tensor(&output, Some("OUT"))
760-
.custom(active, Some("ACT")),
760+
.activate("ACT", active),
761761
);
762762
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
763763
label: None,
@@ -837,7 +837,7 @@ impl TensorOp {
837837
.nf4(Self::NF4_BLOCK_SIZE)
838838
.tensor(&input, Some("IN"))
839839
.tensor(&output, Some("OUT"))
840-
.custom(active, Some("ACT")),
840+
.activate("ACT", active),
841841
);
842842
#[cfg(feature = "subgroup-ops")]
843843
let pipeline = context.checkout_pipeline(
@@ -851,7 +851,7 @@ impl TensorOp {
851851
.nf4(Self::NF4_BLOCK_SIZE)
852852
.tensor(&input, Some("IN"))
853853
.tensor(&output, Some("OUT"))
854-
.custom(active, Some("ACT")),
854+
.activate("ACT", active),
855855
);
856856
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
857857
label: None,
@@ -932,7 +932,7 @@ impl TensorOp {
932932
.u32("BLOCK_SIZE", BLOCK_SIZE)
933933
.tensor(&input, Some("IN"))
934934
.tensor(&output, Some("OUT"))
935-
.custom(active, Some("ACT")),
935+
.activate("ACT", active),
936936
);
937937
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
938938
label: None,
@@ -1013,7 +1013,7 @@ impl TensorOp {
10131013
.int8(Self::INT8_BLOCK_SIZE)
10141014
.tensor(&input, Some("IN"))
10151015
.tensor(&output, Some("OUT"))
1016-
.custom(active, Some("ACT")),
1016+
.activate("ACT", active),
10171017
);
10181018
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
10191019
label: None,
@@ -1098,7 +1098,7 @@ impl TensorOp {
10981098
.nf4(Self::NF4_BLOCK_SIZE)
10991099
.tensor(&input, Some("IN"))
11001100
.tensor(&output, Some("OUT"))
1101-
.custom(active, Some("ACT")),
1101+
.activate("ACT", active),
11021102
);
11031103
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
11041104
label: None,

0 commit comments

Comments
 (0)