-
Notifications
You must be signed in to change notification settings - Fork 233
Expand file tree
/
Copy pathvmamba.py
More file actions
2483 lines (2173 loc) · 101 KB
/
vmamba.py
File metadata and controls
2483 lines (2173 loc) · 101 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
##########################################################
# simplified version
# just one file and include everything
# written by MzeroMiko
##########################################################
##########################################################
# usage:
# conda create -n vmamba python=3.10
# pip install torch==2.2 torchvision torchaudio triton pytest chardet yacs termcolor fvcore seaborn packaging ninja einops numpy==1.24.4 timm==0.4.12
# pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.2cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
# python vmamba.py
##########################################################
##########################################################
# csm_triton.py
##########################################################
import torch
import warnings

# Module-level switch: True when the Triton kernels below can be used.
WITH_TRITON = True
# WITH_TRITON = False
try:
    import triton
    import triton.language as tl
except Exception:
    # Stay best-effort (any import-time failure disables the Triton path), but
    # unlike a bare ``except:`` do not swallow KeyboardInterrupt / SystemExit.
    WITH_TRITON = False
    warnings.warn("Triton not installed, fall back to pytorch implements.")

# to make sure cached_property can be loaded for triton
if WITH_TRITON:
    try:
        from functools import cached_property
    except ImportError:
        # Only a missing name is expected here (functools without cached_property).
        warnings.warn("if you are using py37, add this line to functools.py: "
            "cached_property = lambda func: property(lru_cache()(func))")
# torch implementation ========================================
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Unfold a 2D feature map into four 1D sequences (pure PyTorch path).

    Args:
        x: feature map, unpacked as (B, C, H, W) when ``in_channel_first``,
            otherwise as (B, H, W, C).
        in_channel_first: layout of ``x``.
        out_channel_first: layout of the result, (B, 4, C, H * W) when True and
            (B, H * W, 4, C) when False.
        scans: which four sequences are produced:
            0 -> row-major, column-major, and the reverses of both;
            1 -> four copies of the row-major sequence;
            2 -> two row-major copies followed by two reversed copies;
            3 -> the map rotated by 0, 1, 2 and 3 quarter turns (``torch.rot90``).

    Returns:
        The four sequences stacked in the requested layout.
    """
    # Describe the input layout once so one set of branches serves both layouts.
    if in_channel_first:
        B, C, H, W = x.shape
        h_ax, w_ax = 2, 3          # spatial axes of x
        k_ax, l_ax = 1, -1         # direction / sequence axes of the (B, 4, C, L) result
        lone = (B, 1, C, H * W)
    else:
        B, H, W, C = x.shape
        h_ax, w_ax = 1, 2
        k_ax, l_ax = 2, 1          # result is (B, L, 4, C)
        lone = (B, H * W, 1, C)

    def tiled(times):
        # ``times`` copies of the row-major sequence along the direction axis.
        reps = [1, 1, 1, 1]
        reps[k_ax] = times
        return x.view(lone).repeat(*reps)

    if scans == 0:
        forward_pair = torch.stack(
            [x.flatten(h_ax, w_ax), x.transpose(h_ax, w_ax).flatten(h_ax, w_ax)],
            dim=k_ax,
        )
        y = torch.cat([forward_pair, forward_pair.flip(dims=[l_ax])], dim=k_ax)
    elif scans == 1:
        y = tiled(4)
    elif scans == 2:
        forward_pair = tiled(2)
        y = torch.cat([forward_pair, forward_pair.flip(dims=[l_ax])], dim=k_ax)
    elif scans == 3:
        y = torch.stack(
            [torch.rot90(x, turns, dims=(h_ax, w_ax)).flatten(h_ax, w_ax) for turns in range(4)],
            dim=k_ax,
        )

    # Convert between (B, 4, C, L) and (B, L, 4, C) only when the layouts differ.
    if in_channel_first != out_channel_first:
        order = (0, 3, 1, 2) if in_channel_first else (0, 2, 3, 1)
        y = y.permute(*order).contiguous()
    return y
def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Fold four scanned sequences back into one sequence by undoing each scan and summing.

    Args:
        y: scanned sequences, unpacked as (B, K, D, H, W) when
            ``out_channel_first``, otherwise as (B, H, W, K, D).
        in_channel_first: layout of the result, (B, D, H * W) when True and
            (B, H * W, D) when False.
        out_channel_first: layout of ``y``.
        scans: scan pattern to invert (0 cross scan, 1 unidirectional,
            2 bidirectional, 3 quarter-turn rotations).

    Returns:
        The merged sequence in the requested layout.
    """
    if out_channel_first:
        B, K, D, H, W = y.shape
        k_ax, l_ax = 1, -1         # direction / sequence axes of (B, K, D, L)
        h_ax, w_ax = 2, 3          # spatial axes once one direction is selected
        seqs = y.view(B, K, D, -1)
        grid_wh, grid_hw = (B, D, W, H), (B, D, H, W)
    else:
        B, H, W, K, D = y.shape
        k_ax, l_ax = 2, 1          # (B, L, K, D)
        h_ax, w_ax = 1, 2
        seqs = y.view(B, -1, K, D)
        grid_wh, grid_hw = (B, W, H, D), (B, H, W, D)

    # An unrecognised ``scans`` value leaves the sequences unmerged.
    merged = seqs
    if scans == 0:
        # Un-reverse directions 2 and 3 onto 0 and 1, then un-transpose direction 1.
        folded = seqs.narrow(k_ax, 0, 2) + seqs.narrow(k_ax, 2, 2).flip(dims=[l_ax])
        col_major = folded.select(k_ax, 1).view(grid_wh).transpose(h_ax, w_ax)
        merged = folded.select(k_ax, 0) + col_major.flatten(h_ax, w_ax)
    elif scans == 1:
        merged = seqs.sum(k_ax)
    elif scans == 2:
        folded = seqs.narrow(k_ax, 0, 2) + seqs.narrow(k_ax, 2, 2).flip(dims=[l_ax])
        merged = folded.sum(k_ax)
    elif scans == 3:
        merged = seqs.select(k_ax, 0)
        for turns in (1, 2, 3):
            # Odd rotations swapped H and W, so read those sequences as (W, H) grids.
            grid = grid_wh if turns % 2 else grid_hw
            rotated = seqs.select(k_ax, turns).view(grid)
            merged = merged + torch.rot90(rotated, -turns, dims=(h_ax, w_ax)).flatten(h_ax, w_ax)

    # (B, D, L) <-> (B, L, D) only when the two layouts differ.
    if in_channel_first != out_channel_first:
        merged = merged.permute(0, 2, 1).contiguous()
    return merged
def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Scan four separate feature maps, one map per scan direction (pure PyTorch path).

    Unlike ``cross_scan_fwd``, the input already carries one map per direction,
    so direction ``k`` of the output is built from map ``k`` only.

    Args:
        x: stacked maps, unpacked as (B, 4, C, H, W) when ``in_channel_first``,
            otherwise as (B, H, W, 4, C).
        in_channel_first: layout of ``x``.
        out_channel_first: layout of the result, (B, 4, C, H * W) when True and
            (B, H * W, 4, C) when False.
        scans: 0 -> row-major, column-major, reversed row-major, reversed column-major;
            1 -> every map flattened row-major;
            2 -> maps 0/1 row-major, maps 2/3 reversed row-major;
            3 -> map k rotated by k quarter turns (``torch.rot90``).

    Returns:
        The scanned sequences in the requested layout.
    """
    if in_channel_first:
        B, _, C, H, W = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            # x is 5D here, so the spatial axes are 3 and 4; flatten(2, 3) would
            # merge C with H and return (B, 4, C * H, W).
            y = x.flatten(3, 4)
        elif scans == 2:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 3:
            y = torch.stack([
                x[:, 0, :, :, :].flatten(2, 3),
                torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3),
                torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3),
                torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3),
            ], dim=1)
    else:
        B, H, W, _, C = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = x.flatten(1, 2)
        elif scans == 2:
            # The direction axis is axis 3 in this layout (x[:, k] would index H),
            # and the sequence axis to reverse is axis 1 (-1 would reverse channels).
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 3:
            # Stack on axis 2 so the result is (B, L, 4, C) like every other
            # channel-last branch; axis 1 would give (B, 4, L, C).
            y = torch.stack([
                x[:, :, :, 0, :].flatten(1, 2),
                torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2),
                torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2),
                torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2),
            ], dim=2)

    # Convert between (B, 4, C, L) and (B, L, 4, C) when the layouts differ.
    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y
def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Undo each of the four scans separately, keeping the four results apart (no summation).

    Args:
        y: scanned sequences, unpacked as (B, K, D, H, W) when
            ``out_channel_first``, otherwise as (B, H, W, K, D).
        in_channel_first: layout of the result, (B, 4, D, H * W) when True and
            (B, H * W, 4, D) when False.
        out_channel_first: layout of ``y``.
        scans: scan pattern to invert (0 cross scan, 1 unidirectional,
            2 bidirectional, 3 quarter-turn rotations).

    Returns:
        The four un-scanned sequences in the requested layout.
    """
    if out_channel_first:
        B, K, D, H, W = y.shape
        k_ax, l_ax = 1, -1         # direction axis of (B, K, D, L); sequence axis of one direction
        h_ax, w_ax = 2, 3          # spatial axes once one direction is selected
        seqs = y.view(B, K, D, -1)
        grid_wh, grid_hw = (B, D, W, H), (B, D, H, W)
    else:
        B, H, W, K, D = y.shape
        k_ax, l_ax = 2, 1          # (B, L, K, D)
        h_ax, w_ax = 1, 2
        seqs = y.view(B, -1, K, D)
        grid_wh, grid_hw = (B, W, H, D), (B, H, W, D)

    def direction(k):
        return seqs.select(k_ax, k)

    def untranspose(seq):
        # A column-major sequence is a (W, H) grid; swap it back to row-major.
        return seq.view(grid_wh).transpose(h_ax, w_ax).flatten(h_ax, w_ax)

    # ``scans == 1`` (and any unrecognised value) passes the sequences through untouched.
    out = seqs
    if scans == 0:
        out = torch.stack([
            direction(0),
            untranspose(direction(1)),
            direction(2).flip(dims=[l_ax]),
            untranspose(direction(3)).flip(dims=[l_ax]),
        ], dim=k_ax)
    elif scans == 2:
        out = torch.stack([
            direction(0),
            direction(1),
            direction(2).flip(dims=[l_ax]),
            direction(3).flip(dims=[l_ax]),
        ], dim=k_ax)
    elif scans == 3:
        parts = [direction(0)]
        for turns in (1, 2, 3):
            # Odd rotations swapped H and W, so read those sequences as (W, H) grids.
            grid = grid_wh if turns % 2 else grid_hw
            rotated = direction(turns).view(grid)
            parts.append(torch.rot90(rotated, -turns, dims=(h_ax, w_ax)).flatten(h_ax, w_ax))
        out = torch.stack(parts, dim=k_ax)

    # Convert between (B, 4, D, L) and (B, L, 4, D) when the layouts differ.
    if out_channel_first and (not in_channel_first):
        out = out.permute(0, 3, 1, 2).contiguous()
    elif in_channel_first and (not out_channel_first):
        out = out.permute(0, 2, 3, 1).contiguous()
    return out
class CrossScanF(torch.autograd.Function):
    """Autograd function for the PyTorch cross scan; its gradient is the matching cross merge."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
        # y: (B, 4, C, H * W) | (B, H * W, 4, C)
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        # Recover (B, C, H, W) from whichever of the four input layouts is in use.
        if one_by_one:
            if in_channel_first:
                B, _, C, H, W = x.shape
            else:
                B, H, W, _, C = x.shape
        elif in_channel_first:
            B, C, H, W = x.shape
        else:
            B, H, W, C = x.shape
        ctx.shape = (B, C, H, W)

        scan = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
        return scan(x, in_channel_first, out_channel_first, scans)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape

        # Give the incoming gradient the 5D shape the merge functions unpack.
        if ctx.out_channel_first:
            grad = ys.view(B, -1, C, H, W)
        else:
            grad = ys.view(B, H, W, -1, C)

        merge = cross_merge1b1_fwd if ctx.one_by_one else cross_merge_fwd
        grad = merge(grad, ctx.in_channel_first, ctx.out_channel_first, ctx.scans)

        # Restore the spatial shape of the forward input.
        if ctx.one_by_one:
            shape = (B, 4, -1, H, W) if ctx.in_channel_first else (B, H, W, 4, -1)
        else:
            shape = (B, -1, H, W) if ctx.in_channel_first else (B, H, W, -1)
        # One gradient per forward argument; the four flags are not differentiable.
        return grad.view(shape), None, None, None, None
class CrossMergeF(torch.autograd.Function):
    """Autograd function for the PyTorch cross merge; its gradient is the matching cross scan."""

    @staticmethod
    def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
        # y: (B, 4, C, H * W) | (B, H * W, 4, C)
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        # ``ys`` is always 5D here; only the axis order depends on the layout.
        if out_channel_first:
            B, _, C, H, W = ys.shape
        else:
            B, H, W, _, C = ys.shape
        ctx.shape = (B, C, H, W)

        merge = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
        return merge(ys, in_channel_first, out_channel_first, scans)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, h, w)
        B, C, H, W = ctx.shape

        # Un-flatten the gradient into the shape the scan functions unpack.
        if ctx.one_by_one:
            grid = (B, 4, C, H, W) if ctx.in_channel_first else (B, H, W, 4, C)
        else:
            grid = (B, C, H, W) if ctx.in_channel_first else (B, H, W, C)
        grad = x.view(grid)

        scan = cross_scan1b1_fwd if ctx.one_by_one else cross_scan_fwd
        grad = scan(grad, ctx.in_channel_first, ctx.out_channel_first, ctx.scans)

        out_shape = (B, 4, C, H, W) if ctx.out_channel_first else (B, H, W, 4, C)
        # One gradient per forward argument; the four flags are not differentiable.
        return grad.view(out_shape), None, None, None, None
# triton implements ========================================
# The kernel is only defined when Triton imported successfully. Decorating it
# unconditionally raises NameError on ``triton`` / ``tl`` at import time, which
# defeats the PyTorch fallback announced by the import warning.
if WITH_TRITON:
    # Single kernel serving scan and merge, both layouts and all scan patterns.
    # Each program handles one (BH x BW) spatial tile of up to BC channels of
    # one batch element; the launch grid is (NH * NW, NC, B).
    @triton.jit
    def triton_cross_scan_flex(
        x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
        y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)
        x_layout: tl.constexpr,
        y_layout: tl.constexpr,
        operation: tl.constexpr,
        onebyone: tl.constexpr,
        scans: tl.constexpr,
        BC: tl.constexpr,
        BH: tl.constexpr,
        BW: tl.constexpr,
        DC: tl.constexpr,
        DH: tl.constexpr,
        DW: tl.constexpr,
        NH: tl.constexpr,
        NW: tl.constexpr,
    ):
        # x_layout = 0
        # y_layout = 1 # 0 BCHW, 1 BHWC
        # operation = 0 # 0 scan, 1 merge
        # onebyone = 0 # 0 false, 1 true
        # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional, 3 rot90 steps
        i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
        i_h, i_w = (i_hw // NW), (i_hw % NW)
        # Tiles on the bottom / right edge may stick out past DH / DW.
        _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
        _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
        _mask_hw = _mask_h[:, None] & _mask_w[None, :]
        # Number of channels actually covered by this channel block.
        _for_C = min(DC - i_c * BC, BC)

        # Forward and mirrored coordinates of every element in the tile.
        pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
        pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
        neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
        neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
        # HWRoute<k>: flat position of each tile element inside sequence k.
        if scans == 0:
            # none; trans; flip; trans + flip;
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = pos_w * DH + pos_h # trans
            HWRoute2 = neg_h * DW + neg_w # flip
            HWRoute3 = neg_w * DH + neg_h # trans + flip
        elif scans == 1:
            # none; none; none; none;
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = HWRoute0
            HWRoute2 = HWRoute0
            HWRoute3 = HWRoute0
        elif scans == 2:
            # none; none; flip; flip;
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = HWRoute0
            HWRoute2 = neg_h * DW + neg_w # flip
            HWRoute3 = HWRoute2
        elif scans == 3:
            # none; rot90; rot180==flip; rot270;
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = neg_w * DH + pos_h
            HWRoute2 = neg_h * DW + neg_w
            HWRoute3 = pos_w * DH + neg_h

        # Size of one (C, H, W) slab; y holds four of them per batch element.
        _tmp1 = DC * DH * DW

        y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
        if y_layout == 0:
            p_y1 = y_ptr_base + HWRoute0
            p_y2 = y_ptr_base + _tmp1 + HWRoute1
            p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
            p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
        else:
            p_y1 = y_ptr_base + HWRoute0 * 4 * DC
            p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
            p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
            p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC

        if onebyone == 0:
            # One shared x map feeds (scan) or collects (merge) all four sequences.
            x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
            if x_layout == 0:
                p_x = x_ptr_base + HWRoute0
            else:
                p_x = x_ptr_base + HWRoute0 * DC

            if operation == 0:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    _x = tl.load(p_x + _idx_x, mask=_mask_hw)
                    tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
            elif operation == 1:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
                    _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
                    _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
                    _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
                    tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)

        else:
            # One-by-one: x carries four maps, map k pairs with sequence k only.
            x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
            if x_layout == 0:
                p_x1 = x_ptr_base + HWRoute0
                p_x2 = p_x1 + _tmp1
                p_x3 = p_x2 + _tmp1
                p_x4 = p_x3 + _tmp1
            else:
                p_x1 = x_ptr_base + HWRoute0 * 4 * DC
                p_x2 = p_x1 + DC
                p_x3 = p_x2 + DC
                p_x4 = p_x3 + DC

            if operation == 0:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
            else:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    # Loads are masked like every other load in this kernel: on
                    # edge tiles the unmasked pointers fall outside the tensor.
                    tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y, mask=_mask_hw), mask=_mask_hw)
else:
    # Placeholder so the name exists; the dispatch helpers only choose the
    # Triton-backed classes when WITH_TRITON is true.
    triton_cross_scan_flex = None
class CrossScanTritonF(torch.autograd.Function):
    """Autograd function for the Triton cross scan; backward launches the same kernel in merge mode."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # Recover (B, C, H, W) from whichever of the four input layouts is in use.
        dims = x.shape
        if one_by_one and in_channel_first:
            B, _, C, H, W = dims
        elif one_by_one:
            B, H, W, _, C = dims
        elif in_channel_first:
            B, C, H, W = dims
        else:
            B, H, W, C = dims
        B, C, H, W = int(B), int(C), int(H), int(W)

        # Tile sizes (channels, rows, cols) and the resulting number of tiles per axis.
        BC, BH, BW = 1, 32, 32
        NH = triton.cdiv(H, BH)
        NW = triton.cdiv(W, BW)
        NC = triton.cdiv(C, BC)

        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)

        out_shape = (B, 4, C, H * W) if out_channel_first else (B, H * W, 4, C)
        y = x.new_empty(out_shape)

        # Kernel flags: layout 0 = channel-first, 1 = channel-last; operation 0 = scan.
        x_layout = 0 if in_channel_first else 1
        y_layout = 0 if out_channel_first else 1
        onebyone = 1 if one_by_one else 0
        grid = (NH * NW, NC, B)
        triton_cross_scan_flex[grid](
            x.contiguous(), y,
            x_layout, y_layout, 0, onebyone, scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return y

    @staticmethod
    def backward(ctx, y: torch.Tensor):
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape

        # The gradient has the shape of the forward input.
        if ctx.one_by_one:
            grad_shape = (B, 4, C, H, W) if ctx.in_channel_first else (B, H, W, 4, C)
        else:
            grad_shape = (B, C, H, W) if ctx.in_channel_first else (B, H, W, C)
        x = y.new_empty(grad_shape)

        # Same kernel with operation 1 (merge).
        x_layout = 0 if ctx.in_channel_first else 1
        y_layout = 0 if ctx.out_channel_first else 1
        onebyone = 1 if ctx.one_by_one else 0
        grid = (NH * NW, NC, B)
        triton_cross_scan_flex[grid](
            x, y.contiguous(),
            x_layout, y_layout, 1, onebyone, ctx.scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        # One gradient per forward argument; the four flags are not differentiable.
        return x, None, None, None, None
class CrossMergeTritonF(torch.autograd.Function):
    """Autograd function for the Triton cross merge; backward launches the same kernel in scan mode."""

    @staticmethod
    def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Merge the scanned sequences ``y`` into ``x``.

        Args:
            y: unpacked as (B, 4, C, H, W) when ``out_channel_first``,
                otherwise as (B, H, W, 4, C).
            in_channel_first: layout of the returned tensor.
            out_channel_first: layout of ``y``.
            one_by_one: keep the four directions separate instead of summing them.
            scans: scan pattern flag passed through to the kernel.

        Returns:
            (B, C, H * W) or (B, H * W, C); with ``one_by_one`` set,
            (B, 4, C, H * W) or (B, H * W, 4, C).
        """
        if out_channel_first:
            B, _, C, H, W = y.shape
        else:
            B, H, W, _, C = y.shape
        B, C, H, W = int(B), int(C), int(H), int(W)
        # Tile sizes (channels, rows, cols) and the resulting number of tiles per axis.
        BC, BH, BW = 1, 32, 32
        NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
        if one_by_one:
            x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
        else:
            x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
        # Kernel flags: layout 0 = channel-first, 1 = channel-last; operation 1 = merge.
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x, y.contiguous(),
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return x

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Scan the incoming gradient ``x`` back into the layout of the forward input."""
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape
        y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
        # Same kernel with operation 0 (scan).
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x.contiguous(), y,
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        # Exactly one gradient per forward argument (y plus four flags), matching
        # the other three autograd functions in this file.
        return y, None, None, None, None
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    """Cross-scan ``x``, using the Triton kernel when possible and PyTorch otherwise.

    Args:
        x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
        in_channel_first: layout of ``x``.
        out_channel_first: layout of the result.
        one_by_one: ``x`` carries one map per scan direction.
        scans: 0: cross scan; 1 unidirectional; 2: bidirectional; 3: rot90 steps.
        force_torch: use the PyTorch implementation even on CUDA.

    Returns:
        y: (B, 4, C, L) | (B, L, 4, C)
    """
    call_args = (x, in_channel_first, out_channel_first, one_by_one, scans)

    # CPU tensors always take the PyTorch path.
    if not x.is_cuda:
        return CrossScanF.apply(*call_args)

    impl = CrossScanF if (force_torch or not WITH_TRITON) else CrossScanTritonF
    # Run with the tensor's own CUDA device selected as the current device.
    with torch.cuda.device(x.device):
        return impl.apply(*call_args)
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    """Cross-merge ``y``, using the Triton kernel when possible and PyTorch otherwise.

    Args:
        y: scanned sequences; both implementations unpack it as
            (B, 4, C, H, W) | (B, H, W, 4, C).
        in_channel_first: layout of the result.
        out_channel_first: layout of ``y``.
        one_by_one: keep the four directions separate instead of summing them.
        scans: 0: cross scan; 1 unidirectional; 2: bidirectional; 3: rot90 steps.
        force_torch: use the PyTorch implementation even on CUDA.

    Returns:
        x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
    """
    call_args = (y, in_channel_first, out_channel_first, one_by_one, scans)

    # CPU tensors always take the PyTorch path.
    if not y.is_cuda:
        return CrossMergeF.apply(*call_args)

    impl = CrossMergeF if (force_torch or not WITH_TRITON) else CrossMergeTritonF
    # Run with the tensor's own CUDA device selected as the current device.
    with torch.cuda.device(y.device):
        return impl.apply(*call_args)
# checks =================================================================
# class CHECK:
# def check_csm_triton():
# B, C, H, W = 256, 192, 56, 57
# dtype=torch.float16
# dtype=torch.float32
# x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
# y = torch.randn((B, 4, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
# x1 = x.clone().detach().requires_grad_(True)
# y1 = y.clone().detach().requires_grad_(True)
# def cross_scan(x: torch.Tensor):
# B, C, H, W = x.shape
# L = H * W
# xs = torch.stack([
# x.view(B, C, L),
# torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
# torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
# torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
# ], dim=1).view(B, 4, C, L)
# return xs
# def cross_merge(out_y: torch.Tensor):
# B, K, D, H, W = out_y.shape
# L = H * W
# out_y = out_y.view(B, K, D, L)
# inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
# wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
# invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
# y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
# return y
# def cross_scan_1b1(x: torch.Tensor):
# B, K, C, H, W = x.shape
# L = H * W
# xs = torch.stack([
# x[:, 0].view(B, C, L),
# torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
# torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
# torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
# ], dim=1).view(B, 4, C, L)
# return xs
# def unidi_scan(x):
# B, C, H, W = x.shape
# x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
# return x
# def unidi_merge(ys):
# B, K, C, H, W = ys.shape
# return ys.view(B, 4, -1, H * W).sum(1)
# def bidi_scan(x):
# B, C, H, W = x.shape
# x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
# x = torch.cat([x, x.flip(dims=[-1])], dim=1)
# return x
# def bidi_merge(ys):
# B, K, D, H, W = ys.shape
# ys = ys.view(B, K, D, -1)
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
# return ys.contiguous().sum(1)
# if True:
# res0 = triton.testing.do_bench(lambda :cross_scan(x))
# res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False))
# # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x))
# res3 = triton.testing.do_bench(lambda :cross_merge(y))
# res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False))
# # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y))
# # print(res0, res1, res2, res3, res4, res5)
# print(res0, res1, res3, res4)
# res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward())
# res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False).sum().backward())
# # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward())
# res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward())
# res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False).sum().backward())
# # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward())
# # print(res0, res1, res2, res3, res4, res5)
# print(res0, res1, res3, res4)
# print("test cross scan")
# for (cs0, cm0, cs1, cm1) in [
# # channel_first -> channel_first
# (cross_scan, cross_merge, cross_scan_fn, cross_merge_fn),
# (unidi_scan, unidi_merge, lambda x: cross_scan_fn(x, scans=1), lambda x: cross_merge_fn(x, scans=1)),
# (bidi_scan, bidi_merge, lambda x: cross_scan_fn(x, scans=2), lambda x: cross_merge_fn(x, scans=2)),
# # flex: BLC->BCL; BCL->BLC; BLC->BLC;
# (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn(x, in_channel_first=False).permute(0, 2, 1)),
# (cross_scan, cross_merge, lambda x: cross_scan_fn(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
# (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),
# # previous
# # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)),
# # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)),
# # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)),
# ]:
# x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
# o0 = cs0(x)
# o1 = cs1(x1)
# o0.backward(y.view(B, 4, C, H * W))
# o1.backward(y.view(B, 4, C, H * W))
# print((o0 - o1).abs().max())
# print((x.grad - x1.grad).abs().max())
# o0 = cm0(y)
# o1 = cm1(y1)
# o0.backward(x.view(B, C, H * W))
# o1.backward(x.view(B, C, H * W))
# print((o0 - o1).abs().max())
# print((y.grad - y1.grad).abs().max())
# x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
# print("===============", flush=True)
# print("test cross scan one by one")
# for (cs0, cs1) in [
# (cross_scan_1b1, lambda x: cross_scan_fn(x, one_by_one=True)),
# # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)),
# ]:
# o0 = cs0(y)
# o1 = cs1(y1)
# o0.backward(y.view(B, 4, C, H * W))
# o1.backward(y.view(B, 4, C, H * W))
# print((o0 - o1).abs().max())
# print((y.grad - y1.grad).abs().max())
# x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
# print("===============", flush=True)
# def check_csm_scan3():
# if False:
# x = torch.arange(0, 16).view(1, 1, 4, 4).cuda()
# out1 = cross_scan_fn(x, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
# out2 = cross_merge_fn(out1, scans=3, force_torch=True).view(1, 1, 4, 4)
# out4 = cross_merge_fn(out1, one_by_one=True, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
# out3 = cross_scan_fn(out4, one_by_one=True, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
# out5 = cross_scan_fn(x.view(1, 4, 4, 1), in_channel_first=False, out_channel_first=False, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
# out6 = cross_merge_fn(out5, in_channel_first=False, out_channel_first=False, scans=3, force_torch=True).view(1, 4, 4, 1)
# out8 = cross_merge_fn(out5, in_channel_first=False, out_channel_first=False, one_by_one=True, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
# out7 = cross_scan_fn(out8, in_channel_first=False, out_channel_first=False, one_by_one=True, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
# print(out1.view(4, -1))
# print(out2.view(-1))
# print(out3.view(4, -1))
# print(out4.view(4, -1))
# print(out5.view(-1, 4).t())
# print(out6.view(-1))
# print(out7.view(-1, 4).t())
# print(out8.view(-1, 4).t())
# B, C, H, W = 27, 253, 57, 58
# x = torch.randn((B, C, H, W)).cuda()
# for scans in [0, 1, 2, 3]:
# o1 = cross_scan_fn(x, scans=scans, force_torch=True).view(B, 4, C, H, W)
# print((cross_scan_fn(x, scans=scans) == cross_scan_fn(x, scans=scans, force_torch=True)).all())
# print((cross_merge_fn(o1, scans=scans) == cross_merge_fn(o1, scans=scans, force_torch=True)).all())
# kwargs = dict(in_channel_first=False, out_channel_first=False)
# x2 = x.permute(0, 2, 3, 1).contiguous()
# o2 = o1.permute(0, 3, 4, 1, 2).contiguous()
# print((cross_scan_fn(x, scans=scans, **kwargs) == cross_scan_fn(x, scans=scans, force_torch=True, **kwargs)).all())
# print((cross_merge_fn(o2, scans=scans, **kwargs) == cross_merge_fn(o2, scans=scans, force_torch=True, **kwargs)).all())
# breakpoint()
# if __name__ == "__main__":
# CHECK.check_csm_scan3()
# CHECK.check_csm_triton()
##########################################################
# csms6s.py
##########################################################
import time
import torch
import warnings
WITH_SELECTIVESCAN_MAMBA = True
try:
import selective_scan_cuda
except ImportError:
WITH_SELECTIVESCAN_MAMBA = False
def selective_scan_torch(
    u: torch.Tensor,  # (B, K * C, L)
    delta: torch.Tensor,  # (B, K * C, L)
    A: torch.Tensor,  # (K * C, N)
    B: torch.Tensor,  # (B, K, N, L)
    C: torch.Tensor,  # (B, K, N, L)
    D: torch.Tensor = None,  # (K * C)
    delta_bias: torch.Tensor = None,  # (K * C)
    delta_softplus=True,
    oflex=True,
    *args,
    **kwargs
):
    """Sequential pure-PyTorch selective scan.

    For every step ``i`` along the last axis the state is advanced as
    ``x_i = exp(delta_i * A) * x_{i-1} + delta_i * B_i * u_i`` and the output
    is ``y_i = sum_n x_i[n] * C_i[n]``; ``u * D`` is added when ``D`` is given.

    Args:
        u: input of shape (B, K * C, L).
        delta: step sizes of shape (B, K * C, L).
        A: state matrix of shape (K * C, N).
        B: input projection of shape (B, K, N, L); each of the K groups is
            shared by C consecutive channels.
        C: output projection, same shape as ``B``.
        D: optional skip weights of shape (K * C).
        delta_bias: optional bias of shape (K * C) added to ``delta``.
        delta_softplus: apply softplus to ``delta`` (after the bias).
        oflex: if true, return the float32 result; otherwise cast the result
            back to the dtype ``u`` had on entry.
        *args, **kwargs: accepted and ignored, so extra trailing arguments
            (e.g. a ``backend`` selector) do not break the call.

    Returns:
        Tensor of shape (B, K * C, L).

    Raises:
        AssertionError: if the argument shapes are inconsistent.
    """
    dtype_in = u.dtype
    Batch, K, N, L = B.shape
    KCdim = u.shape[1]
    assert u.shape == (Batch, KCdim, L)
    assert delta.shape == (Batch, KCdim, L)
    assert A.shape == (KCdim, N)
    assert C.shape == B.shape
    # Integer division; a remainder would only surface later as an opaque
    # view() failure, so reject it here with a readable message.
    assert KCdim % K == 0, f"channel dim {KCdim} is not divisible by K={K}"
    Cdim = KCdim // K

    if delta_bias is not None:
        delta = delta + delta_bias[..., None]
    if delta_softplus:
        delta = torch.nn.functional.softplus(delta)

    # The recurrence is always evaluated in float32.
    u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
    # Expand the K groups so that channel k * Cdim + c uses group k.
    B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
    C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)

    x = A.new_zeros((Batch, KCdim, N))
    ys = []
    for i in range(L):
        x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
        ys.append(torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]))
    if ys:
        y = torch.stack(ys, dim=2)  # (B, K * C, L)
    else:
        # torch.stack rejects an empty list; a zero-length sequence simply
        # yields a zero-length output.
        y = u.new_zeros((Batch, KCdim, L))

    out = y if D is None else y + u * D.unsqueeze(-1)
    return out if oflex else out.to(dtype=dtype_in)
class SelectiveScanCuda(torch.autograd.Function):
    """Autograd wrapper around the compiled selective-scan kernels.

    ``forward`` dispatches to the extension selected by ``backend`` and saves
    the tensors the matching ``bwd`` kernel needs; ``backward`` calls that
    kernel and returns one gradient per ``forward`` argument.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
        """Run the forward kernel.

        Args:
            u, delta, A, B, C, D, delta_bias: tensors handed to the kernel
                unchanged.
            delta_softplus: forwarded to the kernel and remembered for
                ``backward``.
            oflex: only forwarded to the "oflex" kernel.
            backend: "oflex", "mamba", or ``None``; ``None`` resolves to
                "mamba" when ``WITH_SELECTIVESCAN_MAMBA`` is true.

        Returns:
            The first value returned by the kernel's ``fwd``.

        Raises:
            ValueError: if the resolved backend is neither "oflex" nor "mamba".
        """
        ctx.delta_softplus = delta_softplus
        # backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend
        # backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend
        backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
        ctx.backend = backend
        if backend == "oflex":
            # NOTE(review): selective_scan_cuda_oflex is not imported in the
            # nearby import section -- confirm it is in scope before relying
            # on this branch.
            out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
        elif backend == "mamba":
            out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        else:
            # Without this branch the code fell through and failed on the
            # unbound names `out` / `x`.
            raise ValueError(f"unsupported selective scan backend: {backend!r}")
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        """Run the backward kernel of the backend used in ``forward``.

        Returns:
            Gradients for ``u, delta, A, B, C, D, delta_bias`` followed by
            ``None`` for ``delta_softplus``, ``oflex`` and ``backend``.

        Raises:
            ValueError: if the stored backend is neither "oflex" nor "mamba".
        """
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        backend = ctx.backend
        # Make the incoming gradient unit-stride along its last dimension.
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if backend == "oflex":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
                u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
            )
        elif backend == "mamba":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
                u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
                False
            )
        else:
            raise ValueError(f"unsupported selective scan backend: {backend!r}")
        return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
def selective_scan_fn(
    u: torch.Tensor,  # (B, K * C, L)
    delta: torch.Tensor,  # (B, K * C, L)
    A: torch.Tensor,  # (K * C, N)
    B: torch.Tensor,  # (B, K, N, L)
    C: torch.Tensor,  # (B, K, N, L)
    D: torch.Tensor = None,  # (K * C)
    delta_bias: torch.Tensor = None,  # (K * C)
    delta_softplus=True,
    oflex=True,
    backend=None,
):
    """Run the selective scan with the PyTorch or the CUDA implementation.

    ``selective_scan_torch`` is used when ``backend == "torch"`` or when
    ``WITH_SELECTIVESCAN_MAMBA`` is false; every other case goes through
    ``SelectiveScanCuda.apply``. All arguments, ``backend`` included, are
    handed on positionally and unchanged.
    """
    call_args = (u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
    if backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA):
        return selective_scan_torch(*call_args)
    return SelectiveScanCuda.apply(*call_args)
# fvcore flops =======================================
def print_jit_input_names(inputs):
    """Print the ``debugName()`` of up to the first ten ``inputs`` on one line.

    This is a best-effort debugging aid: if reading a name fails, the listing
    stops at that point and the line is still terminated.
    """
    print("input params: ", end=" ", flush=True)
    try:
        # zip() against range(10) caps the listing at ten entries and ends
        # cleanly on shorter inputs, instead of relying on a swallowed
        # IndexError to leave the loop.
        for _, value in zip(range(10), inputs):
            print(value.debugName(), end=" ", flush=True)
    except Exception:
        # Deliberately ignored: a failure while printing debug names must not
        # abort the caller.
        pass
    print("", flush=True)
def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
    """Closed-form FLOP estimate for one selective-scan call.

    Operand shapes the estimate is written for:
        u: r(B D L)
        delta: r(B D L)
        A: r(D N)
        B: r(B N L)
        C: r(B N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32

    Not counted:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]

    Returns ``9 * B * L * D * N`` plus ``B * D * L`` for each of ``with_D``
    and ``with_Z`` that is set. ``with_complex`` must be false.
    """
    assert not with_complex
    # https://github.com/state-spaces/mamba/issues/110
    scan_cost = 9 * B * L * D * N
    elementwise_cost = B * D * L
    # One elementwise pass for the D skip term and one for the z gate.
    extras = [elementwise_cost for enabled in (with_D, with_Z) if enabled]
    return sum(extras, scan_cost)
# this is only for selective_scan_ref...
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """FLOP estimate for the step-by-step reference scan, einsum by einsum.

    Operand shapes the estimate is written for:
        u: r(B D L)
        delta: r(B D L)
        A: r(D N)
        B: r(B N L)
        C: r(B N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32

    Not counted:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]

    Each einsum is costed from the "Optimized FLOP" line of NumPy's
    ``einsum_path`` report, halved so that a multiply-add counts once.
    ``with_Group`` selects which einsum equations are costed;
    ``with_complex`` must be false.
    """
    import numpy as np

    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        operands = [np.zeros(shape) for shape in input_shapes]
        report = np.einsum_path(equation, *operands, optimize="optimal")[1]
        for row in report.split("\n"):
            if "optimized flop" not in row.lower():
                continue
            # divided by 2 because we count MAC (multiply-add counted as one flop)
            return float(np.floor(float(row.split(":")[-1]) / 2))

    assert not with_complex

    # Pick the einsum operands/equations for the chosen layout up front.
    if with_Group:
        update_shapes, update_eq = [[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln"
        readout_shapes, readout_eq = [[B, D, N], [B, D, N]], "bdn,bdn->bd"
    else:
        update_shapes, update_eq = [[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln"
        readout_shapes, readout_eq = [[B, D, N], [B, N]], "bdn,bn->bd"

    total = 0
    total += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    total += get_flops_einsum(update_shapes, update_eq)

    # Work done inside the per-step loop, repeated L times.
    per_step = B * D * N
    per_step += get_flops_einsum(readout_shapes, readout_eq)
    total += L * per_step

    if with_D:
        total += B * D * L
    if with_Z:
        total += B * D * L
    return total
def selective_scan_flop_jit(inputs, outputs, backend="prefixsum", verbose=True):
    """FLOP handler for a traced selective-scan call.

    Sizes are read from ``inputs[0]`` (three dims, unpacked as B, D, L) and
    from the second dim of ``inputs[2]`` (N) -- presumably ``u`` and ``A`` of
    the traced call; confirm against the op being profiled. ``outputs`` is
    not used. ``backend == "naive"`` selects ``flops_selective_scan_ref``;
    any other value selects ``flops_selective_scan_fn``.
    """
    if verbose:
        print_jit_input_names(inputs)
    if backend == "naive":
        count_flops = flops_selective_scan_ref
    else:
        count_flops = flops_selective_scan_fn
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    return count_flops(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
# if __name__ == "__main__":
# def params(B, K, C, N, L, device = torch.device("cuda"), itype = torch.float):
# As = (-0.5 * torch.rand(K * C, N, device=device, dtype=torch.float32)).requires_grad_()
# Bs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
# Cs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
# Ds = torch.randn((K * C), device=device, dtype=torch.float32).requires_grad_()
# u = torch.randn((B, K * C, L), device=device, dtype=itype).requires_grad_()
# delta = (0.5 * torch.rand((B, K * C, L), device=device, dtype=itype)).requires_grad_()
# delta_bias = (0.5 * torch.rand((K * C), device=device, dtype=torch.float32)).requires_grad_()
# return u, delta, As, Bs, Cs, Ds, delta_bias
# def bench(func, xs, Warmup=30, NTimes=20):
# import time
# torch.cuda.synchronize()
# for r in range(Warmup):
# for x in xs:
# func(x)