Description
⚠️ Please check that this feature request hasn't been suggested before.
- I searched previous Ideas in Discussions and didn't find any similar feature requests.
- I searched previous Issues and didn't find any similar feature requests.
🔖 Feature description
Currently, Axolotl's LoRA kernel optimizations target only the attention and MLP projection modules; no optimized kernels exist for lm_head or embed_tokens. As a result, full-model fine-tuning with LoRA suffers from reduced throughput and increased memory usage.
Axolotl's current LoRA setup already provides optimized kernels for the following configuration:
```yaml
lora_mlp_kernel: true
lora_o_kernel: true
lora_qkv_kernel: true
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - gate_proj
  - down_proj
  - up_proj
```
However, there are no analogous optimized kernels for:
- lm_head
- embed_tokens
When fine-tuning the entire model, the absence of these kernels leads to a significant drop in throughput and increased GPU memory usage.
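For illustration, a minimal sketch of the configuration this request is about, assuming lm_head and embed_tokens are simply appended to lora_target_modules (which is how one would adapt them today):

```yaml
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - gate_proj
  - down_proj
  - up_proj
  # also adapt the output head and input embeddings
  - lm_head
  - embed_tokens
# note: none of the existing lora_*_kernel options apply to lm_head or
# embed_tokens, so these two layers take the unoptimized code path
```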
✔️ Solution
Extend the existing LoRA kernel optimizations (e.g., via Triton or custom CUDA kernels) to cover lm_head and embed_tokens. Benchmarks should demonstrate restored training throughput and a lower peak memory footprint. Once validated, integrate the new kernels into Axolotl's fine-tuning pipeline.
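To make the request concrete, one possible configuration surface, sketched with hypothetical option names (lora_lm_head_kernel and lora_embed_tokens_kernel do not exist in Axolotl today and are used here purely for illustration):

```yaml
# hypothetical flags, mirroring the existing lora_qkv/o/mlp kernel options
lora_lm_head_kernel: true       # fused LoRA kernel for the output projection (lm_head)
lora_embed_tokens_kernel: true  # fused LoRA kernel for the input embeddings (embed_tokens)
```

Keeping the naming parallel to the existing flags would let users opt in per layer, just as they already do for the attention and MLP kernels.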
❓ Alternatives
No response
📝 Additional Context
No response
Acknowledgements
- My issue title is concise, descriptive, and in title casing.
- I have searched the existing issues to make sure this feature has not been requested yet.
- I have provided enough information for the maintainers to understand and evaluate this request.