-
Notifications
You must be signed in to change notification settings - Fork 82
/
standalone_hyenadna.py
1073 lines (888 loc) · 41.6 KB
/
standalone_hyenadna.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
863
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-
"""HyenaDNA training & inference example (Public)
This code is adapted from the original Colab tutorial on HyenaDNA. Check that out for an easier entry point into the code.
We provide the code here as an example for those who want something outside Colab, with Hugging Face integration.
Original file is located at
https://colab.research.google.com/drive/1wyVEQd4R3HYLTUOXEEQmp_I8aNC_aLhL
"""
#@title Imports
# for HyenaDNA specifically
import torch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from einops import rearrange
from typing import Optional
from functools import partial
from torch import Tensor
from torchvision.ops import StochasticDepth
from collections import namedtuple
import numpy as np
import os
import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
"""# HyenaDNA
"""
#@title Hyena layer
def fftconv(u, k, D):
    """Convolve ``u`` with the filter ``k`` in the Fourier domain and add a skip term.

    By the convolution theorem, a product of spectra is a convolution in the
    time domain.

    Args:
        u: input tensor; its last dimension is treated as the sequence axis.
        k: filter tensor, transformed along its last dimension. ``u`` is cast
            to ``k``'s dtype before the transform.
        D: per-channel skip weight; it is unsqueezed on the last axis and
            multiplied with ``u``.

    Returns:
        ``conv(u, k) + u * D`` truncated to the input length, in ``u``'s dtype.
    """
    seqlen = u.shape[-1]
    # Pad to twice the length so the circular FFT product does not wrap around.
    fft_size = 2 * seqlen

    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)

    # norm='forward' leaves the inverse unscaled; the 1/fft_size was applied to k_f.
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]

    out = y + u * D.unsqueeze(-1)
    return out.to(dtype=u.dtype)
@torch.jit.script
def mul_sum(q, y):
    """Multiply ``q`` and ``y`` elementwise and sum the product over dim 1."""
    product = q * y
    return product.sum(dim=1)
class OptimModule(nn.Module):
    """Module base class whose tensors can carry per-tensor optimizer settings."""

    def register(self, name, tensor, lr=None, wd=0.0):
        """Attach ``tensor`` to the module under ``name``.

        With ``lr == 0.0`` the tensor is stored as a buffer. Otherwise it
        becomes an ``nn.Parameter`` with an ``_optim`` dict attribute holding
        ``"lr"`` and/or ``"weight_decay"`` for whichever of ``lr`` / ``wd`` is
        not ``None``.
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        hyperparams = {}
        if lr is not None:
            hyperparams["lr"] = lr
        if wd is not None:
            hyperparams["weight_decay"] = wd
        getattr(self, name)._optim = hyperparams
class Sin(nn.Module):
    """Sine activation ``sin(freq * x)`` used inside the Hyena filter MLP."""

    def __init__(self, dim, w=10, train_freq=True):
        """Create a ``(1, dim)`` frequency tensor filled with ``w``.

        The frequency is an ``nn.Parameter`` when ``train_freq`` is true and a
        plain tensor attribute otherwise.
        """
        super().__init__()
        initial_freq = w * torch.ones(1, dim)
        if train_freq:
            self.freq = nn.Parameter(initial_freq)
        else:
            self.freq = initial_freq

    def forward(self, x):
        """Return the sine of ``x`` scaled by the stored frequency."""
        scaled = self.freq * x
        return torch.sin(scaled)
class PositionalEmbedding(OptimModule):
    """Time and complex-exponential position features for the Hyena filter."""

    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float=1e-5, **kwargs):
        """Precompute the embeddings for ``seq_len`` positions.

        ``z`` concatenates the normalized time with the real and imaginary
        parts of ``bands = (emb_dim - 1) // 2`` complex exponentials and is
        registered with learning rate ``lr_pos_emb``. ``t`` is the normalized
        time alone and is registered with ``lr=0.0``. Extra keyword arguments
        are accepted and ignored.
        """
        super().__init__()
        self.seq_len = seq_len

        # The time embedding fed to the filters is normalized so that t_f = 1
        time_grid = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2

        # Integer positions 0..seq_len-1, as opposed to the normalized grid above.
        positions = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        angles = 2 * math.pi * positions / seq_len  # 1, L, 1

        freqs = torch.linspace(1e-4, bands - 1, bands)[None, None]
        phasors = torch.exp(-1j * freqs * angles)
        features = torch.cat([time_grid, phasors.real, phasors.imag], dim=-1)

        self.register("z", features, lr=lr_pos_emb)
        self.register("t", time_grid, lr=0.0)

    def forward(self, L):
        """Return ``(z, t)`` restricted to the first ``L`` positions."""
        z_slice = self.z[:, :L]
        t_slice = self.t[:, :L]
        return z_slice, t_slice
class ExponentialModulation(OptimModule):
    """The window function applied to the output of the (MLP) filter function."""

    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulation_lr=0.0,
        modulate: bool=True,
        shift: float = 0.05,
        **kwargs
    ):
        """Build per-channel decay rates.

        Args:
            d_model: number of channels; one decay rate is created per channel.
            fast_decay_pct: divisor of ``log(target)`` giving one end of the range.
            slow_decay_pct: divisor of ``log(target)`` giving the other end.
            target: value whose logarithm is divided by the two percentages.
            modulation_lr: learning rate passed to ``register`` for ``deltas``
                (``0.0`` stores them as a buffer).
            modulate: when false, ``forward`` returns its input unchanged.
            shift: constant added to the decay envelope.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.modulate = modulate
        self.shift = shift
        max_decay = math.log(target) / fast_decay_pct
        min_decay = math.log(target) / slow_decay_pct
        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
        self.register("deltas", deltas, lr=modulation_lr)

    def forward(self, t, x):
        """Scale ``x`` by ``exp(-t * |deltas|) + shift`` when modulation is on."""
        if self.modulate:
            decay = torch.exp(-t * self.deltas.abs())
            x = x * (decay + self.shift)
        return x
class HyenaFilter(OptimModule):
    """Implicit long convolution filter: positional features -> MLP -> decay window."""

    def __init__(
        self,
        d_model,
        emb_dim=3,  # dim of input to MLP, augments with positional encoding
        order=16,  # width of the implicit MLP
        fused_fft_conv=False,
        seq_len=1024,
        lr=1e-3,
        lr_pos_emb=1e-5,
        dropout=0.0,
        w=1,  # frequency of periodic activations
        wd=0,  # weight decay of kernel parameters
        bias=True,
        num_inner_mlps=2,
        normalized=False,
        **kwargs
    ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            fused_fft_conv: stored on the instance; not otherwise used in this class
            seq_len: number of positions precomputed by the positional embedding
            lr: learning rate attached (via `_optim`) to the filter MLP tensors
            lr_pos_emb: learning rate for the positional embedding
            dropout: probability for `self.dropout`
            w: initial frequency of the `Sin` activation
            wd: weight decay attached (via `_optim`) to the filter MLP tensors
            bias: stored as `use_bias`; the bias parameter is always created
            num_inner_mlps: number of inner linear layers inside filter MLP
            normalized: stored on the instance; not otherwise used in this class
            **kwargs: forwarded to `ExponentialModulation`

        Note:
            filter_dropout is not implemented (`self.dropout` is never applied
            in `filter`).
        """
        super().__init__()
        self.d_model = d_model
        self.use_bias = bias
        self.fused_fft_conv = fused_fft_conv
        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(dropout)

        # One activation instance is shared by every layer of the MLP.
        act = Sin(dim=order, w=w)
        self.emb_dim = emb_dim
        assert emb_dim % 2 != 0 and emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
        self.seq_len = seq_len

        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)

        self.implicit_filter = nn.Sequential(
            nn.Linear(emb_dim, order),
            act,
        )
        for _ in range(num_inner_mlps):
            self.implicit_filter.append(nn.Linear(order, order))
            self.implicit_filter.append(act)

        self.implicit_filter.append(nn.Linear(order, d_model, bias=False))

        self.modulation = ExponentialModulation(d_model, **kwargs)

        self.normalized = normalized
        # Tag every tensor of the filter MLP with its own optimizer settings.
        for child in self.implicit_filter.children():
            for name in child.state_dict():
                getattr(child, name)._optim = {"weight_decay": wd, "lr": lr}

    def filter(self, L, *args, **kwargs):
        """Generate the filter for the first ``L`` positions."""
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z)
        h = self.modulation(t, h)
        return h

    def forward(self, x, L, k=None, bias=None, *args, **kwargs):
        """Apply the long convolution to ``x``.

        Args:
            x: input passed to ``fftconv``.
            L: filter length, used only when ``k`` is not supplied.
            k: optional precomputed filter; a tuple is reduced to its first item.
            bias: optional skip weight; defaults to this module's ``bias``
                parameter (previously ``None`` was passed straight to
                ``fftconv`` and failed on ``None.unsqueeze``).
        """
        if k is None:
            k = self.filter(L)

        # Ensure compatibility with filters that return a tuple
        if isinstance(k, tuple):
            k = k[0]

        if bias is None:
            bias = self.bias

        return fftconv(x, k, bias)
class HyenaOperator(nn.Module):
    """Hyena sequence mixer: gated long convolutions in place of attention."""

    def __init__(
        self,
        d_model,
        l_max,
        order=2,
        filter_order=64,
        dropout=0.0,
        filter_dropout=0.0,
        **filter_args,
    ):
        r"""
        Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf

        Args:
            d_model (int): Dimension of the input and output embeddings (width of the layer)
            l_max: (int): Maximum input sequence length (required)
            order: (int): Depth of the Hyena recurrence. Defaults to 2
            filter_order: (int): Width of the implicit filter MLP. Defaults to 64
            dropout: (float): Dropout probability. Defaults to 0.0
            filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
            **filter_args: forwarded to ``HyenaFilter``
        """
        super().__init__()
        self.d_model = d_model
        self.l_max = l_max
        self.order = order
        # One value stream plus `order` gating streams, each d_model wide.
        inner_width = d_model * (order + 1)
        self.dropout = nn.Dropout(dropout)
        self.in_proj = nn.Linear(d_model, inner_width)
        self.out_proj = nn.Linear(d_model, d_model)

        # Depthwise conv; padding=2 with kernel 3 over-pads and the tail is cut in forward.
        self.short_filter = nn.Conv1d(
            inner_width,
            inner_width,
            3,
            padding=2,
            groups=inner_width
        )
        self.filter_fn = HyenaFilter(
            d_model * (order - 1),
            order=filter_order,
            seq_len=l_max,
            channels=1,
            dropout=filter_dropout,
            **filter_args
        )

    def forward(self, u, *args, **kwargs):
        """Mix ``u`` along its sequence axis.

        ``u`` is read as ``(b, l, d)`` by the rearrange patterns below. The
        output keeps that layout with the sequence axis cut to
        ``min(l, l_max)``. Extra positional and keyword arguments are ignored.
        """
        seq_len = u.size(-2)
        l_filter = min(seq_len, self.l_max)
        u = self.in_proj(u)
        u = rearrange(u, 'b l d -> b d l')

        uc = self.short_filter(u)[..., :l_filter]
        # The last d_model-wide chunk is the value; the earlier ones are gates.
        *x, v = uc.split(self.d_model, dim=1)

        k = self.filter_fn.filter(l_filter)[0]
        k = rearrange(k, 'l (o d) -> o d l', o=self.order - 1)
        bias = rearrange(self.filter_fn.bias, '(o d) -> o d', o=self.order - 1)

        for o, x_i in enumerate(reversed(x[1:])):
            v = self.dropout(v * x_i)
            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])

        y = rearrange(v * x[0], 'b d l -> b l d')
        y = self.out_proj(y)
        return y
#@title Self-Attention (alternative)
"""
If you'd like to try the HyenaDNA model using attention instead, you can, i.e.,
use a regular decoder-only Transformer.
"""
class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        causal: whether to apply a causal (upper-triangular) mask by default.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)

        Returns
        -------
            The attention output, (B, S, H, D).
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)

        if key_padding_mask is not None:
            # Additive mask: 0 where a key is kept, -10000 where it is masked out.
            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + padding_mask[:, None, None, :]  # b s -> b 1 1 s

        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)

        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
        return output
class MHA(nn.Module):
    """Multi-head self-attention built from a fused QKV projection."""

    def __init__(self, embed_dim, num_heads, bias=True, dropout=0.0,
                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False,return_residual=False,device=None, dtype=None) -> None:
        """
        embed_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        bias: whether the QKV projection has a bias.
        dropout, softmax_scale, causal: forwarded to the inner attention.
        layer_idx: stored on the instance.
        dwconv: if true, a depthwise kernel-3 Conv1d is applied to the QKV projection.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        tensor_opts = {'device': device, 'dtype': dtype}

        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.return_residual = return_residual
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        qkv_width = 3 * embed_dim
        # The residual-returning projection hands back its input next to the output.
        qkv_cls = LinearResidual if self.return_residual else nn.Linear
        self.Wqkv = qkv_cls(embed_dim, qkv_width, bias=bias, **tensor_opts)

        if self.dwconv:
            self.dwconv_qkv = nn.Conv1d(qkv_width, qkv_width, kernel_size=3, padding=2,
                                        groups=qkv_width)

        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale,
                                        attention_dropout=dropout)

        # output projection always have the bias (for now)
        self.out_proj = nn.Linear(embed_dim, embed_dim, **tensor_opts)

    def forward(self, x, key_padding_mask=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim).
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen).
            **kwargs: forwarded unchanged to the inner attention module.

        Returns the projected attention output, or (output, x) when
        return_residual is set.
        """
        kwargs = {'key_padding_mask': key_padding_mask, **kwargs}

        if self.return_residual:
            qkv, x = self.Wqkv(x)
        else:
            qkv = self.Wqkv(x)

        if self.dwconv:
            channels_first = rearrange(qkv, 'b s d -> b d s')
            # Drop the two extra trailing positions created by padding=2.
            convolved = self.dwconv_qkv(channels_first)[..., :-2]
            qkv = rearrange(convolved, 'b d s -> b s d').contiguous()

        qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
        context = self.inner_attn(qkv, **kwargs)
        merged_heads = rearrange(context, '... h d -> ... (h d)')
        out = self.out_proj(merged_heads)

        if self.return_residual:
            return out, x
        return out
#@title MLP layer
"""
The MLP layer after the mixer layer (HyenaOperator).
"""
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> linear."""

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 return_residual=False, device=None, dtype=None):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py

        hidden_features and out_features fall back to in_features when falsy.
        With return_residual set, forward returns (output, input).
        """
        super().__init__()
        tensor_opts = {'device': device, 'dtype': dtype}

        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, **tensor_opts)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, **tensor_opts)

    def forward(self, x):
        """Apply fc1, the activation and fc2 to ``x``."""
        out = self.fc2(self.activation(self.fc1(x)))
        if self.return_residual:
            return out, x
        return out
#@title Block layer (Hyena + MLP layers)
"""
A block consists of a Mixer layer (Hyena or attention), and a MLP layer.
"""
class LinearResidual(nn.Linear):
    """nn.Linear variant that also hands back its input. For compatibility with FusedDense.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the pair ``(linear(input), input)``."""
        projected = super().forward(input)
        return projected, input
class Block(nn.Module):
    """One residual layer: a sequence mixer followed by an MLP, each normalized."""

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
                 drop_path1=0., drop_path2=0.,
                 return_residual=False,
                 residual_in_fp32=False):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        mixer_cls: zero-argument factory for the mixer (it is called as mixer_cls()).
            Defaults to MHA with embed_dim=dim and dim // 64 heads.
        mlp_cls: factory called as mlp_cls(dim). Defaults to Mlp with 4 * dim hidden units.
        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
        if mixer_cls is None:
            # The mixer is built with no arguments below, so the width must be bound here.
            mixer_cls = partial(MHA, embed_dim=dim, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls()
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
            self.norm2 = norm_cls(dim)

    def forward(self, hidden_states, residual = None,
                mixer_subset=None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            mixer_kwargs: extra keyword arguments forwarded to the mixer.

        Returns:
            (hidden_states, residual) when prenorm, otherwise hidden_states only.
        """
        if self.prenorm:
            dropped = self.drop_path1(self.dropout1(hidden_states))
            residual = (dropped + residual) if residual is not None else dropped
            hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

            # Copy so the caller's dict is not modified when mixer_subset is added.
            mixer_kwargs = dict(mixer_kwargs) if mixer_kwargs is not None else {}
            if mixer_subset is not None:
                mixer_kwargs['mixer_subset'] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]

            if not isinstance(self.mlp, nn.Identity):
                dropped = self.drop_path2(self.dropout2(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual

        assert residual is None
        mixer_out = self.mixer(
            hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
        )
        if self.return_residual:  # mixer out is actually a pair here
            mixer_out, hidden_states = mixer_out
        hidden_states = self.norm1(
            (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                dtype=self.norm1.weight.dtype
            )
        )
        if not isinstance(self.mlp, nn.Identity):
            mlp_out = self.mlp(hidden_states)
            if self.return_residual:  # mlp out is actually a pair here
                mlp_out, hidden_states = mlp_out
            hidden_states = self.norm2(
                (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                    dtype=self.norm2.weight.dtype
                )
            )
        return hidden_states
def create_mixer_cls(layer=None,
                     attn_layer_idx=None, attn_cfg=None, layer_idx=None,
                     device=None, dtype=None):
    """Return a factory (a ``functools.partial``) for the mixer of one layer.

    Args:
        layer: keyword arguments for ``HyenaOperator``; required whenever the
            layer is not an attention layer.
        attn_layer_idx: collection of layer indices that use attention.
        attn_cfg: keyword arguments for ``MHA``. ``'causal'`` defaults to True
            when absent. The mapping is copied, not modified, so the same
            config can be reused for every layer.
        layer_idx: index of the layer being built.
        device, dtype: forwarded to ``MHA`` (not to ``HyenaOperator``).

    Note:
        ``Block`` calls the returned factory with no arguments, so any other
        argument ``MHA`` requires has to be present in ``attn_cfg``.
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    if attn_layer_idx is not None and layer_idx in attn_layer_idx:
        # Popping from the caller's dict would leave later layers without
        # 'causal' and silently reset them to the default.
        attn_kwargs = dict(attn_cfg) if attn_cfg is not None else {}
        causal = attn_kwargs.pop('causal', True)
        return partial(MHA, causal=causal, layer_idx=layer_idx,
                       **attn_kwargs, **factory_kwargs)
    return partial(HyenaOperator, **layer)
def create_mlp_cls(d_model, d_inner=None, device=None, dtype=None):
    """Return a factory for ``Mlp`` with a tanh-approximated GELU activation.

    The hidden width is ``d_inner`` when given, otherwise ``4 * d_model``.
    """
    hidden_width = 4 * d_model if d_inner is None else d_inner
    tanh_gelu = partial(F.gelu, approximate='tanh')
    return partial(Mlp, hidden_features=hidden_width, activation=tanh_gelu,
                   device=device, dtype=dtype)
def create_block(d_model, d_inner=None,
                 layer=None, attn_layer_idx=None,
                 attn_cfg=None, layer_norm_epsilon=1e-5,
                 resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False,
                 layer_idx=None,
                 device=None, dtype=None):
    """Assemble one prenorm ``Block`` and tag it with ``layer_idx``.

    The mixer and MLP factories come from ``create_mixer_cls`` and
    ``create_mlp_cls``; normalization is ``nn.LayerNorm`` with
    ``eps=layer_norm_epsilon``.
    """
    mixer_factory = create_mixer_cls(
        layer=layer,
        attn_layer_idx=attn_layer_idx,
        attn_cfg=attn_cfg,
        layer_idx=layer_idx,
        device=device,
        dtype=dtype,
    )
    mlp_factory = create_mlp_cls(d_model, d_inner=d_inner, device=device, dtype=dtype)
    norm_factory = partial(nn.LayerNorm, eps=layer_norm_epsilon, device=device, dtype=dtype)

    new_block = Block(
        d_model,
        mixer_factory,
        mlp_factory,
        norm_cls=norm_factory,
        prenorm=True,
        resid_dropout1=resid_dropout1,
        resid_dropout2=resid_dropout2,
        residual_in_fp32=residual_in_fp32,
    )
    new_block.layer_idx = layer_idx
    return new_block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True,
glu_act=False):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
# If using GLU activation for now, we scale the std by 2
elif name in ["output_linear.0.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
if not glu_act:
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
else:
out_features = p.shape[0]
# Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5
# on average.
nn.init.normal_(p[:out_features // 2], mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) * 2)
#@title Backbone model (stack of blocks)
"""
A backbone model consists of a stack of blocks. If you use attention, then
positional embeddings are included. When using Hyena, the positional
embeddings are disabled and do nothing.
"""
class GPT2Embeddings(nn.Module):
    """Token embeddings with optional learned position embeddings."""

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 word_embed_proj_dim=None, device=None, dtype=None):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
        then project up to embed_dim
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                                **factory_kwargs)
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
                                                padding_idx=padding_idx, **factory_kwargs)
            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
                                        **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)

    def forward(self, input_ids, position_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen); defaults to 0..seqlen-1 when position
            embeddings are enabled.

        Returns the word embeddings, plus position embeddings when enabled.
        """
        _, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.project_in is not None:
            embeddings = self.project_in(embeddings)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings
class LMBackbone(nn.Module):
    """Embedding layer, a stack of blocks, and a final dropout-add-LayerNorm."""

    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,
                 process_group=None, layer=None,
                 attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,
                 resid_dropout: float = 0.0, embed_dropout: float = 0.1,
                 layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False,
                 device=None, dtype=None, **kwargs) -> None:
        """Build ``n_layer`` blocks via ``create_block`` and initialize weights.

        ``embed_dropout`` is used as the first block's ``resid_dropout1``, so
        it acts on the embedding output; every other residual dropout uses
        ``resid_dropout``. ``initializer_cfg`` supplies extra keyword
        arguments for ``_init_weights``. Extra ``**kwargs`` are ignored.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.residual_in_fp32 = residual_in_fp32
        # note max_position_embeddings is 0 for Hyena, and therefore isn't used
        self.embeddings = GPT2Embeddings(d_model, vocab_size, max_position_embeddings,
                                         **factory_kwargs)

        self.layers = nn.ModuleList([create_block(
            d_model, d_inner=d_inner,
            layer=layer, attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg, layer_norm_epsilon=layer_norm_epsilon,
            resid_dropout1=embed_dropout if i == 0 else resid_dropout,
            resid_dropout2=resid_dropout, residual_in_fp32=residual_in_fp32,layer_idx=i,
            **factory_kwargs,
        ) for i in range(n_layer)])

        self.drop_f = nn.Dropout(resid_dropout)
        self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs)

        self.apply(partial(_init_weights, n_layer=n_layer,
                           **(initializer_cfg if initializer_cfg is not None else {})))

    def forward(self, input_ids, position_ids=None):
        """Embed ``input_ids`` and return the final normalized hidden states."""
        hidden_states = self.embeddings(input_ids, position_ids=position_ids,)
        residual = None

        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)

        # Close the last block's pending residual connection before the final norm.
        dropped = self.drop_f(hidden_states)
        residual = (dropped + residual) if residual is not None else dropped
        hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))

        return hidden_states
#@title Decoder head layer
"""
A simple decoder head (using MLP) to predict a sequence level classification.
You have the option to average across all the tokens in a sequence or using the
"last" token to classify. At least, those 2 worked best for us, but we provide
other "modes" as well.
We only need this for classification. Otherwise we'll use the hidden
states of the backbone as embeddings.
"""
class SequenceDecoder(nn.Module):
    """Reduce a sequence of hidden states to ``l_output`` positions and project them."""

    def __init__(
        self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last"
    ):
        """
        d_model: width of the incoming hidden states.
        d_output: output width; None keeps the hidden states unprojected.
        l_output: number of positions to keep. None keeps the whole sequence
            (or the value passed to forward); 0 keeps one position and then
            squeezes that axis away.
        use_lengths: restrict each sample to its own length before reducing.
        mode: one of 'last', 'first', 'pool', 'sum', 'ragged'.
        """
        super().__init__()

        self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output)

        if l_output is None:
            self.l_output = None
            self.squeeze = False
        elif l_output == 0:
            # Equivalent to getting an output of length 1 and then squeezing
            self.l_output = 1
            self.squeeze = True
        else:
            assert l_output > 0
            self.l_output = l_output
            self.squeeze = False

        self.use_lengths = use_lengths
        self.mode = mode

        if mode == 'ragged':
            assert not use_lengths

    def forward(self, x, state=None, lengths=None, l_output=None):
        """
        x: (n_batch, l_seq, d_model)
        lengths: per-sample lengths; required for use_lengths and 'ragged'.
        l_output: override, only honoured when the module was built with
            l_output=None.
        Returns: (n_batch, l_output, d_output)
        """
        if self.l_output is None:
            if l_output is not None:
                assert isinstance(l_output, int)  # Override by pass in
            else:
                # Grab entire output
                l_output = x.size(-2)
            squeeze = False
        else:
            l_output = self.l_output
            squeeze = self.squeeze

        if self.mode == "last":
            restrict = lambda x: x[..., -l_output:, :]
        elif self.mode == "first":
            restrict = lambda x: x[..., :l_output, :]
        elif self.mode == "pool":
            def restrict(x):
                # Running means of the last l_output prefixes, computed from
                # the total sum minus reversed partial sums of the tail.
                L = x.size(-2)
                s = x.sum(dim=-2, keepdim=True)
                if l_output > 1:
                    c = torch.cumsum(x[..., -(l_output - 1):, :].flip(-2), dim=-2)
                    c = F.pad(c, (0, 0, 1, 0))
                    s = s - c  # (B, l_output, D)
                    s = s.flip(-2)
                # unsqueeze so the prefix lengths divide along the sequence
                # axis rather than broadcasting over the feature axis.
                denom = torch.arange(
                    L - l_output + 1, L + 1, dtype=x.dtype, device=x.device
                ).unsqueeze(-1)
                s = s / denom
                return s
        elif self.mode == "sum":
            restrict = lambda x: torch.cumsum(x, dim=-2)[..., -l_output:, :]
            # TODO use same restrict function as pool case
        elif self.mode == 'ragged':
            assert lengths is not None, "lengths must be provided for ragged mode"
            # remove any additional padding (beyond max length of any sequence in the batch)
            restrict = lambda x: x[..., : max(lengths), :]
        else:
            raise NotImplementedError(
                "Mode must be ['last' | 'first' | 'pool' | 'sum' | 'ragged']"
            )

        # Restrict to actual length of sequence
        if self.use_lengths:
            assert lengths is not None
            x = torch.stack(
                [
                    restrict(out[..., :length, :])
                    for out, length in zip(torch.unbind(x, dim=0), lengths)
                ],
                dim=0,
            )
        else:
            x = restrict(x)

        if squeeze:
            assert x.size(-2) == 1
            x = x.squeeze(-2)

        x = self.output_transform(x)

        return x

    def step(self, x, state=None):
        """Project a single step; all length logic is ignored."""
        return self.output_transform(x)
#@title Model (backbone + head)
"""
Putting it all together, the model consists of a backbone model
and a decoder head (you can turn off head for embeddings only too).
Here we use a simple head to do multi-classification, but
can also swap the head to do next token prediction too. We defer to the main
HyenaDNA for that code, since pretraining with next token prediction isn't quite
feasible on colab.
"""
class HyenaDNAModel(nn.Module):
    """HyenaDNA backbone with an optional pooled classification head."""

    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,
                 layer=None, attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,
                 resid_dropout: float = 0.0, embed_dropout: float = 0.1,
                 layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False,
                 pad_vocab_size_multiple: int = 1, use_head=False, n_classes: int = 2,
                 device=None, dtype=None, **kwargs) -> None:
        """Build the backbone and, when ``use_head`` is set, the decoder head.

        ``vocab_size`` is rounded up to a multiple of
        ``pad_vocab_size_multiple``. ``layer`` is the mixer config; ``d_model``
        is added to a copy of it when missing. The remaining arguments and
        ``**kwargs`` are forwarded to ``LMBackbone``. ``n_classes`` is the
        output width of the head.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            # Round up; plain assignment here would replace the vocabulary
            # size with the pad amount.
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)

        self.use_head = use_head

        # check if layer (config) has d_model (HF code differs from main Safari code)
        if 'd_model' not in layer:
            # Work on a copy so the caller's config is left untouched.
            layer = {**layer, 'd_model': d_model}

        self.backbone = LMBackbone(
            d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size,
            layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg,
            max_position_embeddings=max_position_embeddings,
            resid_dropout=resid_dropout, embed_dropout=embed_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_cfg=initializer_cfg, residual_in_fp32=residual_in_fp32,
            **factory_kwargs, **kwargs
        )

        # we only need a head if doing classification, otherwise we'll use the
        # hidden states as embeddings
        if self.use_head:
            self.head = SequenceDecoder(d_model=d_model, d_output=n_classes, l_output=0, mode='pool')

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=n_layer,
                           **(initializer_cfg if initializer_cfg is not None else {})))

        # if self.use_head:
        #     self.tie_weights()

    # def tie_weights(self):
    #     self.head.weight = self.backbone.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, state=None):  # state for the repo interface
        """Return head outputs when ``use_head`` is set, else the backbone hidden states."""
        hidden_states = self.backbone(input_ids, position_ids=position_ids)

        if self.use_head:
            return self.head(hidden_states)
        return hidden_states
"""# Data pipeline
"""
#@title Tokenizer
"""
Just a simple character level tokenizer.
From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py
CharacterTokenizer for Hugging Face Transformers.
This is heavily inspired from CanineTokenizer in transformers package.
"""
class CharacterTokenizer(PreTrainedTokenizer):
def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs):
    """Character tokenizer for Hugging Face transformers.

    Args:
        characters (Sequence[str]): List of desired characters. Any character which
            is not included in this list will be replaced by a special token called
            [UNK] with id=6. Following are list of all of the special tokens with
            their corresponding ids:
                "[CLS]": 0
                "[SEP]": 1
                "[BOS]": 2
                "[MASK]": 3
                "[PAD]": 4
                "[RESERVED]": 5
                "[UNK]": 6
            an id (starting at 7) will be assigned to each character.
        model_max_length (int): Model maximum sequence length.
        padding_side (str): forwarded to the base tokenizer; defaults to 'left'.
        **kwargs: forwarded to the base tokenizer.
    """
    self.characters = characters
    self.model_max_length = model_max_length

    # Built before super().__init__() so the mapping already exists should the
    # base initializer consult the vocabulary.
    # NOTE(review): newer transformers releases appear to do so - confirm
    # against the pinned version.
    self._vocab_str_to_int = {
        "[CLS]": 0,
        "[SEP]": 1,
        "[BOS]": 2,
        "[MASK]": 3,
        "[PAD]": 4,
        "[RESERVED]": 5,
        "[UNK]": 6,
        **{ch: i + 7 for i, ch in enumerate(characters)},
    }
    self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

    bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
    # [SEP] doubles as the end-of-sequence token.
    sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
    cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
    pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
    unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)
    mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

    super().__init__(
        bos_token=bos_token,
        eos_token=sep_token,
        sep_token=sep_token,
        cls_token=cls_token,
        pad_token=pad_token,
        mask_token=mask_token,
        unk_token=unk_token,
        add_prefix_space=False,
        model_max_length=model_max_length,
        padding_side=padding_side,
        **kwargs,
    )
@property
def vocab_size(self) -> int:
    """Number of entries in the token-to-id mapping (special tokens included)."""
    vocab = self._vocab_str_to_int
    return len(vocab)
def _tokenize(self, text: str) -> List[str]:
return list(text)