Description
After the attention layers were split across all nodes, I missed the implications this introduced.

Long story short: to calculate the attention for a single head from the Q output, I need the whole corresponding head from the K output. For Q head x I need the whole K head floor(x / (nHeads / nKvHeads)) to calculate the result.
For example, Llama 3 8B:
💡 headSize: 128
💡 nHeads: 32
💡 nKvHeads: 8
Q head 0 => floor( 0 / ( 32 / 8) ) => K head 0
Q head 1 => floor( 1 / ( 32 / 8) ) => K head 0
Q head 2 => floor( 2 / ( 32 / 8) ) => K head 0
...
Q head 8 => floor( 8 / ( 32 / 8) ) => K head 2
Q head 9 => floor( 9 / ( 32 / 8) ) => K head 2
...
Q head 31 => floor( 31 / ( 32 / 8) ) => K head 7
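To make the mapping concrete, here is a minimal standalone sketch (the variable names are mine, not taken from the Distributed Llama sources) that reproduces the table above:

```cpp
#include <cstdio>

int main() {
    // Attention config from the example above (Llama 3 8B)
    const int nHeads = 32;   // Q heads
    const int nKvHeads = 8;  // K/V heads
    const int group = nHeads / nKvHeads; // Q heads that share one K/V head

    for (int qHead = 0; qHead < nHeads; qHead++) {
        // Q head x depends on the whole K/V head floor(x / (nHeads / nKvHeads));
        // integer division is floor() here because all values are non-negative
        int kvHead = qHead / group;
        printf("Q head %2d => K head %d\n", qHead, kvHead);
    }
    return 0;
}
```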
Because of this, it is currently not possible to split the model across more than nKvHeads nodes.

The same problem applies to the V layer.
How could this be fixed?
1. Synchronize missing outputs

For nSlices > nKvHeads setups, a new synchronization step could be introduced. This step would synchronize the missing K/V outputs across nodes. Of course, synchronization is already the slowest part of Distributed Llama.
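As a rough illustration of why such a step would be needed (a standalone sketch, assuming K/V outputs are sliced evenly and contiguously across nodes, which may not match the actual Distributed Llama slicing): with nSlices = 16, every K head is spread over two nodes, so a node that needs the whole head has to fetch the part it does not own.

```cpp
#include <cstdio>

int main() {
    // Numbers from the example above (Llama 3 8B)
    const int dim = 4096;
    const int nHeads = 32;
    const int nKvHeads = 8;
    const int headSize = dim / nHeads;           // 128
    const int kvDim = (dim * nKvHeads) / nHeads; // 1024
    const int nSlices = 16;                      // more nodes than K/V heads
    const int kvDim0 = kvDim / nSlices;          // 64 K/V outputs computed per node

    // For every K/V head, list which nodes hold a piece of it. With
    // nSlices > nKvHeads each head is spread over headSize / kvDim0 nodes,
    // so a node that needs the whole head must fetch the missing parts.
    for (int kvHead = 0; kvHead < nKvHeads; kvHead++) {
        int headStart = kvHead * headSize;       // first output index of this head
        int headEnd = headStart + headSize;      // one past the last output index
        printf("K head %d (outputs %4d..%4d) lives on nodes:", kvHead, headStart, headEnd - 1);
        for (int node = 0; node < nSlices; node++) {
            int nodeStart = node * kvDim0;
            int nodeEnd = nodeStart + kvDim0;
            if (nodeStart < headEnd && nodeEnd > headStart)
                printf(" %d", node);
        }
        printf("\n");
    }
    return 0;
}
```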
2. Redundancy

Redundancy could be introduced for the K/V layers. These layers would be split with alignment to headSize. This way there is no synchronization, and the amount of redundant calculation seems small (headSize - kvDim0 outputs per node).
For example, Llama 3 8B:
headSize = dim / nHeads = 4096 / 32 = 128
kvDim = (dim * nKvHeads) / nHeads = (4096 * 8) / 32 = 1024

nSlices = 16
kvDim0 = kvDim / nSlices = 1024 / 16 = 64
redundancy = headSize - kvDim0 = 128 - 64 = 64 outputs of K & V per node

nSlices = 32
kvDim0 = kvDim / nSlices = 1024 / 32 = 32
redundancy = headSize - kvDim0 = 128 - 32 = 96 outputs of K & V per node
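A quick standalone check of these numbers (a sketch, assuming the head-aligned split means every node computes at least one whole K/V head, so the redundant part per node is headSize - kvDim0):

```cpp
#include <cstdio>

int main() {
    // Numbers from the example above (Llama 3 8B)
    const int dim = 4096;
    const int nHeads = 32;
    const int nKvHeads = 8;
    const int headSize = dim / nHeads;           // 128
    const int kvDim = (dim * nKvHeads) / nHeads; // 1024

    const int sliceCounts[] = {16, 32};
    for (int nSlices : sliceCounts) {
        int kvDim0 = kvDim / nSlices;            // outputs a node owns without redundancy
        // With a head-aligned split the node computes the whole head anyway,
        // so the extra (redundant) K and V outputs per node are:
        int redundancy = headSize - kvDim0;
        printf("nSlices = %2d: kvDim0 = %3d, redundancy = %d outputs of K & V per node\n",
               nSlices, kvDim0, redundancy);
    }
    return 0;
}
```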