About "norm_dim_in" in self.to_out #10
@buyeah1109 hey Anthon! this is a good question which i don't really know the answer to. in the end, i went by this sentence and presumed the embedding (model) dimension for all weight matrices. however, there are papers that do allude that your way of doing it may work (look up the post-activation layernorm in the feedforward of the normformer paper, as well as the norm on the aggregated values coming out of attention, but before combining the heads, named sub-ln)
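rough sketch of where those two norms sit, in case it helps (illustrative pytorch, not the exact code from either paper):

```python
import torch
from torch import nn

dim, dim_inner, heads, dim_head, seq = 512, 2048, 8, 64, 128

# normformer-style feedforward: an extra layernorm right after the activation,
# before projecting back down to the model dimension
ff = nn.Sequential(
    nn.Linear(dim, dim_inner),
    nn.GELU(),
    nn.LayerNorm(dim_inner),   # post-activation norm
    nn.Linear(dim_inner, dim)
)

# sub-ln style: normalize the aggregated values coming out of attention,
# per head, before the heads are combined and projected out
values = torch.randn(2, heads, seq, dim_head)      # (batch, heads, seq, dim_head)
normed = nn.LayerNorm(dim_head)(values)            # norm before combining heads
merged = normed.transpose(1, 2).reshape(2, seq, heads * dim_head)
out = nn.Linear(heads * dim_head, dim)(merged)     # to_out projection afterwards
```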
Thanks for your answer! That helps a lot. May I confirm that you are presuming the normalizing dimension is consistent across the entire model? Thus the normalizations in the code are all based on the "dim" dimension. There is some good news: I observed a significant speed-up from nGPT in pretraining on the OpenWebText dataset. To reach the same loss, nGPT requires far fewer tokens or training iterations than the standard GPT2 baseline.
@buyeah1109 wow! you may be the first (although a lot of researchers keep their cards close) to report a positive result! (at least for the speedup) how significant was it, and was it holding all hyperparameters constant?
yes, that's what i'm assuming based on the "all vectors forming the embedding dimension" phrase, but looking at some other papers (like the cosine sim network paper from 2017), it could be done the other way around when projecting out in the end. you can only normalize one dimension of the weight matrix, and i can see both perspectives
@buyeah1109 maybe when i get some time i'll wire up the training script for openwebtext in this repo then... thanks for sharing your replication efforts
@buyeah1109 i could add some option to change up the normalizing dimension for those output projections in the
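roughly what such an option could look like, as a simplified sketch (not the actual NormLinear in this repo):

```python
import torch
from torch import nn
import torch.nn.functional as F

class NormLinear(nn.Module):
    # simplified sketch of a weight-normalized linear with a switchable norm dimension
    def __init__(self, dim_in, dim_out, norm_dim_in = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in))
        # weight is stored (dim_out, dim_in) like nn.Linear, so dim=-1 is the input dim
        self.norm_dim = -1 if norm_dim_in else 0

    def forward(self, x):
        w = F.normalize(self.weight, dim = self.norm_dim)
        return F.linear(x, w)

# norm_dim_in = True  -> each output unit's weight vector is unit length (cosine-sim reading)
# norm_dim_in = False -> each column is unit length along the embedding dimension
to_out = NormLinear(2048, 512, norm_dim_in = False)
```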
Pretty good, it's about 3x faster than standard GPT2-124M. For hyperparameters, I just copied the default settings of standard GPT2 and tried different learning rates. I believe there is a lot of room to improve with hyperparameter tuning.
@buyeah1109 so even without a ton of tuning.. that's great news.. 🙏 looks like i'll put some more work into this next month. research isn't done here yet
@buyeah1109 you didn't happen to train with mixed precision, did you? also, was your baseline using rotary embeddings or not?
I trained with mixed precision. My GPU is Ampere, so I think BF16 is supported and used in the mixed-precision training. I used to train with V100s but ended up with NaNs. For the baseline, I tested it both with and without rotary embeddings; RoPE slightly improved the baseline's training loss.
@buyeah1109 wonderful! thank you!
Hey, would you care to share your baseline and nGPT code? I tried to reproduce the results at the 124M scale but only got comparable results between nGPT and GPT.
@alxndrTL hey Alexandre, thanks for chiming in. i'm also seeing only "comparable" results, but then again, i've never been that great of an experimentalist, so i'll reserve judgement for a bit longer
Sure! I used the code from nanoGPT to train the GPT2-124M baseline on the OpenWebText dataset. I used the default nanoGPT training configuration for GPT2 except for the batch size, since I don't have 8x H100s lol. For nGPT, I kept the nanoGPT training script and directly imported the model from this wonderful project, aligning the depth, width and dimension with GPT2. I didn't use QK-norm for nGPT.
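For reference, a rough sketch of that setup at the GPT2-124M shape; the constructor keyword names below are assumptions from memory, so check the actual signature in this repo:

```python
# rough sketch of swapping the nanoGPT model for nGPT at the GPT2-124M scale
from nGPT_pytorch import nGPT  # assumed import path

model = nGPT(
    num_tokens = 50304,   # nanoGPT pads GPT-2's 50257-token vocab up to 50304
    dim = 768,            # n_embd of GPT2-124M
    depth = 12,           # n_layer
    heads = 12,           # n_head (assumed kwarg name)
    dim_head = 64,        # 12 heads x 64 = 768 (assumed kwarg name)
    attn_norm_qk = False  # assumed kwarg for disabling QK-norm, per the comment above
)

# the rest of the nanoGPT training loop (optimizer, lr schedule, bf16 autocast)
# stays as in the default GPT2 config, aside from the batch size
```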
I can also share some data points. GPT2-124M achieves a training loss of around 3.04 after training on 16B tokens. nGPT achieves 2.94 after 16B tokens, and 3.0 after only 8B tokens.
I think your "comparable" results also matter, @alxndrTL. You could compare with my setting and find the differences, which could tell us which part of the nGPT implementation is the most significant one for the training speed-up. It would help a lot!
Thanks for the details @buyeah1109. Yes, I will try again; what you saw gives me hope.
In my experiments (different modalities including mel-spectrograms and other embeddings, no text tokens) i noticed that setting
@inspirit yes, that makes sense, as the magnitude in the continuous data will be lost without the network being given some room to encode it as phase. i also got curious and spent the whole day yesterday trying out a normalized MLP for a small RL task, but while it learns and is stable, it's not really that much better than SOTA. while we are on this topic, i think i'll also bring up the xval paper, which shows that a transformer can make use of magnitude on a token to generalize better for numerical tasks. just to play devil's advocate for this approach
@lucidrains I wonder what you would recommend as an input projection for continuous data, especially if the data dimensionality is 2-3x larger than the transformer dim? I now have 2 ideas: the first I have already verified and it works, setting norm_dim_in=True for the input layer; the second would be using just a normal Linear layer without weight norm and l2norm-ing the output of the layer.
I guess another way is to use a small MLP as the input/output projection; the question is what weight norm options to use to project from data_dim to model_dim
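A quick sketch of those two ideas side by side (illustrative sizes; the norm_dim_in=True behaviour is approximated here by normalizing a raw weight along its input dimension):

```python
import torch
from torch import nn
import torch.nn.functional as F

data_dim, model_dim = 1536, 512          # illustrative: continuous input ~3x wider than the model
x = torch.randn(2, 64, data_dim)         # (batch, seq, data_dim), e.g. mel-spectrogram frames

# idea 1: weight-normalized projection, normalizing along the input (data) dimension
w_in = torch.randn(model_dim, data_dim)
h_a = F.linear(x, F.normalize(w_in, dim = -1))   # unit-length weight rows; activations keep their magnitude

# idea 2: plain linear, no weight norm, then l2-normalize the activations instead
proj = nn.Linear(data_dim, model_dim, bias = False)
h_b = F.normalize(proj(x), dim = -1)             # token embeddings forced onto the unit hypersphere
```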
@inspirit in the RL setting, what worked was just a linear followed by an activation, but if heavily normalized networks like these somehow take off, i'm sure people will be combing the lit for better ways to encode magnitude into phase. in other words, i don't know
Thanks for the great work. I notice that in the Attention and FFN blocks, the output matrix (i.e., self.to_out) is normalized differently: along the first dimension instead of the last dimension (normalizing along the last dimension is the default in your code, and this behavior is controlled by the flag "norm_dim_in").
I am wondering why the normalization is different for the output matrix. I was thinking that the author's goal with weight normalization is to turn the dot-product computation into a cosine similarity. But if we normalize along the first dimension of the output matrix, then we are not calculating the cosine similarity between the intermediate state in the FFN and the weight vectors of the output matrix. Correct me if I am wrong; I would really appreciate it.
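To make the concern concrete, here is a tiny numeric sketch (illustrative shapes only, not the code in this repo):

```python
import torch
import torch.nn.functional as F

dim, dim_inner = 512, 2048
hidden = F.normalize(torch.randn(dim_inner), dim = -1)   # unit-norm FFN intermediate state
to_out = torch.randn(dim, dim_inner)                     # output projection, shape (dim, dim_inner) here

# normalized along the last (input) dimension: every row is a unit vector, so the
# matmul is a true cosine similarity between `hidden` and each weight row
rows_unit = F.normalize(to_out, dim = -1)
print(rows_unit.norm(dim = -1)[:3])       # ~1.0 each
print((rows_unit @ hidden).abs().max())   # bounded by 1.0

# normalized along the first (embedding) dimension instead: rows are no longer
# unit vectors, so the same product is a scaled dot product, not a cosine similarity
cols_unit = F.normalize(to_out, dim = 0)
print(cols_unit.norm(dim = -1)[:3])       # generally not 1.0
print(cols_unit.norm(dim = 0)[:3])        # the columns are the unit vectors here
```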