
Introduce MsT technologies into unsloth to extend sequence length #1082

Open · wants to merge 13 commits into base: nightly

Conversation

@wdlctc commented Sep 30, 2024

Description
This pull request introduces optimizations to the LLaMA model implementation, specifically targeting the language modeling head and forward pass. The main changes include:

  1. Implement a custom _LM_head using torch.autograd.Function for more efficient forward and backward passes.
  2. Introduce an LMheadWarpper class to manage the custom LM head.
  3. Add a minis_processing function to handle mini-batch processing of hidden states and labels.
  4. Modify the CausalLM_fast_forward function to use the new mini-batch processing and custom LM head.
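For intuition, here is a minimal sketch of the pattern these pieces implement: compute the LM-head cross-entropy chunk-by-chunk along the sequence, checkpointing the hidden states and recomputing each chunk's logits in backward. This is illustrative only and not the PR's exact _LM_head / LMheadWarpper / minis_processing code; the class name, chunk size, the -100 ignore index, and pre-shifted labels are all assumptions for this example.

```python
# Illustrative sketch only -- not the PR's exact _LM_head / LMheadWarpper /
# minis_processing code. Chunk size, the -100 ignore index, and pre-shifted
# labels are assumptions made for this example.
import torch
import torch.nn.functional as F

class ChunkedLMHeadLoss(torch.autograd.Function):
    """Cross-entropy over the LM head, computed chunk-by-chunk along the sequence,
    so the full (batch, seq, vocab) logits tensor is never materialized."""

    @staticmethod
    def forward(ctx, hidden, lm_weight, labels, chunk_size):
        # Checkpoint the (much smaller) hidden states instead of the logits.
        ctx.save_for_backward(hidden, lm_weight, labels)
        ctx.chunk_size = chunk_size
        ctx.n_valid = (labels != -100).sum().clamp(min=1)
        loss = hidden.new_zeros((), dtype=torch.float32)
        for s in range(0, hidden.shape[1], chunk_size):
            logits = hidden[:, s:s + chunk_size, :] @ lm_weight.t()   # (B, chunk, V)
            loss = loss + F.cross_entropy(
                logits.flatten(0, 1).float(),
                labels[:, s:s + chunk_size].flatten(),
                ignore_index=-100, reduction="sum")
        return loss / ctx.n_valid

    @staticmethod
    def backward(ctx, grad_out):
        hidden, lm_weight, labels = ctx.saved_tensors
        grad_hidden = torch.zeros_like(hidden)
        grad_weight = torch.zeros_like(lm_weight)
        B, _, H = hidden.shape
        for s in range(0, hidden.shape[1], ctx.chunk_size):
            h = hidden[:, s:s + ctx.chunk_size, :].flatten(0, 1)      # (N, H)
            y = labels[:, s:s + ctx.chunk_size].flatten()             # (N,)
            probs = (h @ lm_weight.t()).float().softmax(-1)           # logits recomputed on the fly
            valid = y != -100
            probs[valid, y[valid]] -= 1.0                             # dCE/dlogits = softmax - one_hot
            d_logits = (probs * valid.unsqueeze(-1) * grad_out / ctx.n_valid).to(h.dtype)
            grad_hidden[:, s:s + ctx.chunk_size, :] = (d_logits @ lm_weight).view(B, -1, H)
            grad_weight += d_logits.t() @ h
        return grad_hidden, grad_weight, None, None

# Usage sketch (names are hypothetical):
#   loss = ChunkedLMHeadLoss.apply(hidden_states, lm_head.weight, shift_labels, 2048)
```

The effect is that only one (batch, chunk, vocab) logits block exists at any time, while the full (batch, seq, vocab) tensor is never allocated.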

Changes

  • Added _LM_head class for custom autograd function
  • Introduced LMheadWarpper class
  • Implemented minis_processing function for mini-batch handling
  • Updated CausalLM_fast_forward to use new optimizations
  • Removed some redundant code and comments

Benefits

  • Improved memory efficiency through mini-batch processing
  • Increased the maximum supported sequence length by 4x

Testing
Please test this implementation thoroughly, especially:

  • Performance comparison with the original implementation
  • Correctness of the loss calculation and gradient computation
  • Memory usage across various input sizes

@shimmyshimmer (Collaborator) commented

Thank you @wdlctc! We will review it and hopefully be able to push it in after our multimodal release! :)

@wdlctc (Author) commented Oct 4, 2024

Thank you @shimmyshimmer. For your review, I have added detailed training info for reference:

  1. Standard training, with slightly better loss: unsloth (1.192900) vs unsloth+MST (1.165600)
  2. 2x longer sequence length on LLAMA2: unsloth OOMs at 25k but works at 12k, while unsloth+MST works at 25k

For more implementation details, you can refer to our blog: https://wdlctc.github.io/mst.html or our paper: https://www.arxiv.org/abs/2407.15892

If you need other fine-tuning settings, I can try them another time.

@wdlctc (Author) commented Oct 5, 2024

Rewrote it with unsloth's fast_cross_entropy. We were surprised to find that integrating MST with unsloth not only improves memory behavior but also introduces a speedup.

The key difference: checkpointing the LM-head's hidden_state (input) instead of checkpointing the logits (output):

  • Saves memory by splitting the logits into mini-sequences
  • Speeds up training, since CPU-offloading the logits ((b, s, 32000) for llama) is time consuming, while CPU-offloading the hidden_state ((b, s, 4096) for llama) is much faster
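For a rough sense of that gap, a back-of-the-envelope sketch (fp16 activations and b=1, s=12k are illustrative assumptions, not measurements from this PR):

```python
# Back-of-the-envelope sizes, assuming fp16 activations and illustrative b=1, s=12k
# (not measurements from this PR):
b, s, vocab, hidden_dim, bytes_per_el = 1, 12_000, 32_000, 4_096, 2
logits_bytes = b * s * vocab * bytes_per_el        # (b, s, 32000): ~0.72 GiB
hidden_bytes = b * s * hidden_dim * bytes_per_el   # (b, s, 4096):  ~0.09 GiB, ~7.8x smaller
print(f"logits: {logits_bytes / 2**30:.2f} GiB, hidden_state: {hidden_bytes / 2**30:.2f} GiB")
```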

@danielhanchen (Contributor) commented
@wdlctc Thanks a lot again!! I'll test it and verify all losses match! Appreciate it!

@danielhanchen changed the base branch from main to nightly on October 11, 2024, 08:44
@wdlctc (Author) commented Oct 14, 2024

10/14/2024: Resolved the conflicts with the nightly branch

@danielhanchen (Contributor) commented
Sorry for the delay - I was planning to add this together with Vision support :) It might take a few more days!

@danielhanchen deleted the branch unslothai:nightly on October 18, 2024, 03:43
@danielhanchen reopened this on Oct 20, 2024
@danielhanchen (Contributor) commented
Oh lol I noticed I accidentally deleted this PR after I deleted the nightly branch - whoops so sorry!

@danielhanchen (Contributor) commented
Interesting - so I looked through the paper and code; essentially you're proposing to do gradient accumulation inside of each sequence?
I.e. the usual approach is chunking the CE Loss kernel amongst large columnar blocks, but you're suggesting the 2nd option - chunking the rows themselves.

And the trick is since it's row chunked, we also do not materialize the full logits but instead re-compute them on the fly?

@wdlctc (Author) commented Oct 20, 2024

Yes! The key insight is that the full logits are too big, especially when the vocabulary size is large, as on LLAMA3 (128k) and Gemma2 (256k), so re-computing them on the fly effectively reduces memory (only one chunk is computed at a time and the previous chunk is discarded) and time (for offloading).

We do suggest doing row chunking, but you can also do both rows and columns, since for the LM-head and MLP the rows and columns (batch and seq) are independent. Row chunking alone is effective because long-context training typically uses local_batch_size=1.
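For intuition, the "recompute the logits chunk on the fly" behavior can also be sketched with plain activation checkpointing over row chunks. This is an illustrative alternative, not the PR's actual implementation (which uses a custom autograd Function); the chunk size and helper names are made up.

```python
# Illustrative sketch only: row-chunked cross entropy where each chunk's logits are
# recomputed during backward via activation checkpointing, instead of being stored.
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def _chunk_ce_sum(h_chunk, y_chunk, lm_weight):
    logits = h_chunk @ lm_weight.t()                     # (B, chunk, vocab), freed after this call
    return F.cross_entropy(logits.flatten(0, 1).float(), y_chunk.flatten(),
                           ignore_index=-100, reduction="sum")

def row_chunked_ce(hidden, labels, lm_weight, chunk_size=2048):
    n_valid = (labels != -100).sum().clamp(min=1)
    loss = hidden.new_zeros((), dtype=torch.float32)
    for s in range(0, hidden.shape[1], chunk_size):
        # Only this chunk's logits exist at any time; backward re-runs _chunk_ce_sum.
        loss = loss + checkpoint(_chunk_ce_sum,
                                 hidden[:, s:s + chunk_size, :],
                                 labels[:, s:s + chunk_size],
                                 lm_weight, use_reentrant=False)
    return loss / n_valid
```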

