Skip to content

Commit 8048666

Browse files
committed
Refactor TensorOp and TensorGpuView.
1 parent 287bae3 commit 8048666

File tree

9 files changed

+413
-305
lines changed

9 files changed

+413
-305
lines changed

examples/puzzle15.rs renamed to examples/puzzle15/main.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,28 @@ use half::f16;
88
#[cfg(not(debug_assertions))]
99
use itertools::Itertools;
1010
use memmap2::Mmap;
11+
use ops::TensorOpExt;
1112
use safetensors::SafeTensors;
1213
use tokio::{
1314
fs::File,
1415
io::{AsyncReadExt, BufReader},
1516
};
1617
use web_rwkv::{
1718
context::{Context, ContextBuilder, InstanceExt},
19+
num::Float,
1820
runtime::{
1921
infer::{InferInput, InferInputBatch, InferOption},
2022
loader::Loader,
2123
model::{ContextAutoLimits, ModelBuilder, ModelInfo},
2224
v6, TokioRuntime,
2325
},
26+
tensor::ops::TensorOp,
2427
tokenizer::Tokenizer,
2528
wgpu,
2629
};
2730

31+
mod ops;
32+
2833
const PROMPT: &str = r"<input>
2934
<board>
3035
15 0 2 12
@@ -85,6 +90,25 @@ async fn load_tokenizer() -> Result<Tokenizer> {
8590
Ok(Tokenizer::new(&contents)?)
8691
}
8792

93+
fn make_hooks<F: Float>(info: &ModelInfo) -> Result<v6::HookMap<F>> {
94+
let mut hooks = v6::HookMap::new();
95+
96+
for layer in 0..info.num_layer {
97+
hooks.insert(
98+
v6::Hook::PreAttTimeDecayActivate(layer),
99+
Box::new(move |frame: v6::Frame<F>| {
100+
let ops = vec![TensorOp::mul_exp(
101+
&frame.buffer.time_decay,
102+
&frame.buffer.att_k,
103+
)?];
104+
Ok(TensorOp::List(ops))
105+
}),
106+
);
107+
}
108+
109+
Ok(hooks)
110+
}
111+
88112
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ValueEnum)]
89113
enum EmbedDevice {
90114
#[default]
@@ -138,12 +162,13 @@ async fn main() -> Result<()> {
138162

139163
let embed_device = cli.embed_device.unwrap_or(EmbedDevice::Cpu).into();
140164

165+
let hooks = make_hooks(&info)?;
141166
let model = ModelBuilder::new(&context, model)
142167
.embed_device(embed_device)
143168
.rescale(0)
144169
.build_v6()
145170
.await?;
146-
let bundle = v6::Bundle::<f16>::new(model, 1);
171+
let bundle = v6::Bundle::<f16>::new_with_hooks(model, 1, hooks);
147172
let runtime = TokioRuntime::new(bundle).await;
148173

149174
let tokens = tokenizer.encode(PROMPT.as_bytes())?;

examples/puzzle15/mul_exp.wgsl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Layout metadata describing one tensor view; bound as a uniform for both `source`
// and `destination`. `compute_index` reads `stride.x`, `stride.y` and `offset.xyz`,
// dividing the `x` components by 4 because the storage buffers are addressed in
// 4-element packets.
struct View {
2+
shape: vec4<u32>,
3+
stride: vec4<u32>,
4+
offset: vec4<u32>,
5+
};
6+
7+
@group(0) @binding(0) var<uniform> source: View;
8+
@group(0) @binding(1) var<uniform> destination: View;
9+
10+
#ifdef IN_FP16
11+
@group(0) @binding(2) var<storage, read> input: array<vec2<u32>>; // (B, T, C)
12+
#else
13+
@group(0) @binding(2) var<storage, read> input: array<vec4<f32>>; // (B, T, C)
14+
#endif
15+
#ifdef OUT_FP16
16+
@group(0) @binding(3) var<storage, read_write> output: array<vec2<u32>>; // (B, T, C)
17+
#else
18+
@group(0) @binding(3) var<storage, read_write> output: array<vec4<f32>>; // (B, T, C)
19+
#endif
20+
21+
// Packs four f32 lanes into two u32 words, two half-precision values per word.
fn pack4x16float(x: vec4<f32>) -> vec2<u32> {
    let lo = pack2x16float(x.xy);
    let hi = pack2x16float(x.zw);
    return vec2<u32>(lo, hi);
}
24+
25+
// Inverse of `pack4x16float`: expands two u32 words into four f32 lanes.
fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
    let lo = unpack2x16float(x.x);
    let hi = unpack2x16float(x.y);
    return vec4<f32>(lo, hi);
}
28+
29+
// Maps a (batch, token, index) coordinate of `view` to a linear position in the
// packed storage buffer. `index` counts 4-element packets, so the x stride and
// x offset of the view are divided by 4 before use.
fn compute_index(view: View, batch: u32, token: u32, index: u32) -> u32 {
    let row = view.stride.x >> 2u;
    let plane = view.stride.y * row;

    let b = batch + view.offset.z;
    let t = token + view.offset.y;
    let i = index + (view.offset.x >> 2u);

    return b * plane + t * row + i;
}
34+
35+
// Entry point: output = exp(min(input, 0)) * output, one 4-element packet per invocation.
// When the source view has a single token row (`source.shape.y == 1`), that row is
// reused for every destination token.
@compute @workgroup_size(BLOCK_SIZE, 1, 1)
fn mul_exp(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
    let index = invocation_id.x;
    let token = invocation_id.y;
    let batch = invocation_id.z;

    // Invocations past the last packet of the row have nothing to do.
    let packets = destination.shape.x / 4u;
    if index >= packets {
        return;
    }

    let source_token = select(token, 0u, source.shape.y == 1u);
    let src = compute_index(source, batch, source_token, index);
#ifdef IN_FP16
    let x = unpack4x16float(input[src]);
#else
    let x = input[src];
#endif
    // Clamp to non-positive values before exponentiating.
    let factor = exp(min(x, vec4<f32>(0.0)));

    let bti = compute_index(destination, batch, token, index);
#ifdef OUT_FP16
    output[bti] = pack4x16float(factor * unpack4x16float(output[bti]));
#else
    output[bti] = factor * output[bti];
#endif
}

examples/puzzle15/ops.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
use web_rwkv::{
2+
context::Macros,
3+
num::Float,
4+
tensor::{ops::TensorOp, TensorError, TensorGpuView, TensorShape},
5+
wgpu::{BindGroupDescriptor, BindGroupEntry},
6+
};
7+
8+
pub trait TensorOpExt: Sized {
9+
/// Multiply `input` to exponential of `output`.
10+
/// - `input` shape: `[C, 1, B]` or `[C, T, B]`.
11+
/// - `output` shape: `[C, T, B]`.
12+
fn mul_exp<'a, 'b, F0: Float, F1: Float>(
13+
input: impl Into<TensorGpuView<'a, F0>>,
14+
output: impl Into<TensorGpuView<'b, F1>>,
15+
) -> Result<Self, TensorError>;
16+
}
17+
18+
impl TensorOpExt for TensorOp {
19+
fn mul_exp<'a, 'b, F0: Float, F1: Float>(
20+
input: impl Into<TensorGpuView<'a, F0>>,
21+
output: impl Into<TensorGpuView<'b, F1>>,
22+
) -> Result<Self, TensorError> {
23+
const BLOCK_SIZE: u32 = 128;
24+
25+
let input: TensorGpuView<_> = input.into();
26+
let output: TensorGpuView<_> = output.into();
27+
28+
let shape = {
29+
let [index, token, batch, _] = output.shape().into();
30+
input
31+
.check_shape([index, 1, batch, 1])
32+
.or(input.check_shape([index, token, batch, 1]))?;
33+
output.check_shape([index, token, batch, 1])?;
34+
output.shape()
35+
};
36+
37+
let context = output.context();
38+
let pipeline = context.checkout_pipeline(
39+
"mul_exp",
40+
include_str!("mul_exp.wgsl"),
41+
"mul_exp",
42+
None,
43+
Macros::new()
44+
.u32("BLOCK_SIZE", BLOCK_SIZE)
45+
.tensor(&input, Some("IN"))
46+
.tensor(&output, Some("OUT")),
47+
);
48+
let bindings = vec![context.device.create_bind_group(&BindGroupDescriptor {
49+
label: None,
50+
layout: &pipeline.layout,
51+
entries: &[
52+
BindGroupEntry {
53+
binding: 0,
54+
resource: input.meta_binding(),
55+
},
56+
BindGroupEntry {
57+
binding: 1,
58+
resource: output.meta_binding(),
59+
},
60+
BindGroupEntry {
61+
binding: 2,
62+
resource: input.binding(),
63+
},
64+
BindGroupEntry {
65+
binding: 3,
66+
resource: output.binding(),
67+
},
68+
],
69+
})];
70+
71+
Ok(Self::Atom {
72+
pipeline,
73+
bindings,
74+
dispatch: [
75+
u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
76+
shape[1] as u32,
77+
shape[2] as u32,
78+
],
79+
})
80+
}
81+
}

0 commit comments

Comments
 (0)