faced a token indexing bug: the code was finding the first occurrence of any of the output tokens instead of the last one (the actual label). fixed it by using flip() to locate the last occurrence, as in the sketch below.
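a minimal sketch of the flip() trick, assuming PyTorch tensors and a single sequence; the function and argument names are just illustrative, not the actual code:

```python
import torch

def last_label_token_index(token_ids: torch.Tensor, label_ids: torch.Tensor) -> torch.Tensor:
    """Return the position of the last token in token_ids that is one of label_ids."""
    # boolean mask: True wherever any output/label token appears
    # (assumes at least one label token is present in the sequence)
    mask = torch.isin(token_ids, label_ids)
    # flipping turns "last occurrence" into "first occurrence", which argmax finds;
    # then map that index back to the original (unflipped) order
    last_from_end = torch.flip(mask, dims=[0]).int().argmax()
    return token_ids.shape[0] - 1 - last_from_end
```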
hit out-of-memory issues: batch=4 used ~103 GB vs ~64 GB at batch=1, with the attention computation alone needing ~30 GB per forward pass.
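one way around this is gradient accumulation: keep the effective batch at 4 while only ever holding one sample's activations (including the attention buffers) in memory. this is only a sketch under that assumption, for a HuggingFace-style model that returns a .loss, not the exact loop used here:

```python
def train_epoch(model, loader, optimizer, accum_steps: int = 4):
    """Accumulate gradients over accum_steps micro-batches of size 1."""
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(loader):          # loader yields batch_size=1
        loss = model(**batch).loss / accum_steps   # scale so gradients match batch=4
        loss.backward()                            # only one sample's activations live at a time
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```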
one important rule when finetuning any model: the distribution of your labels should be balanced, especially for LoRA training. a quick balance check is sketched below.
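a quick sketch of a label-balance check to run before training; check_label_balance is just a hypothetical helper name, and the 10% tolerance is an arbitrary threshold:

```python
from collections import Counter

def check_label_balance(labels, tolerance: float = 0.10):
    """Flag any class whose share deviates from uniform by more than tolerance."""
    counts = Counter(labels)
    total = sum(counts.values())
    expected_share = 1.0 / len(counts)
    for label, count in counts.most_common():
        share = count / total
        flag = "  <-- imbalanced" if abs(share - expected_share) > tolerance else ""
        print(f"{label!r}: {count} ({share:.1%}){flag}")
    return counts
```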