
Support nSlices > nKvHeads #70

Open
@b4rtaz


After the attention layers were split across all nodes, I missed the implications of that change.


Long story short: to calculate the attention for a single Q head, I need the whole corresponding K head. For Q head x, I need the whole K head floor(x / (nHeads / nKvHeads)) to calculate the result.

For example Llama 3 8B:

💡 dim: 4096
💡 nHeads: 32
💡 nKvHeads: 8

Q head 0  => floor(  0 / ( 32 / 8) ) => K head 0
Q head 1  => floor(  1 / ( 32 / 8) ) => K head 0
Q head 2  => floor(  2 / ( 32 / 8) ) => K head 0
...
Q head 8  => floor(  8 / ( 32 / 8) ) => K head 2
Q head 9  => floor(  9 / ( 32 / 8) ) => K head 2
...
Q head 31 => floor( 31 / ( 32 / 8) ) => K head 7
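
For illustration, a minimal sketch of this mapping in C++ (the helper name `qHeadToKvHead` is hypothetical, not something from the Distributed Llama code; integer division of non-negative indices plays the role of floor):

```cpp
#include <cstdio>

// Hypothetical helper: maps a Q head to the K/V head whose outputs it needs
// (grouped-query attention). For non-negative values, integer division == floor.
static unsigned int qHeadToKvHead(unsigned int qHead, unsigned int nHeads, unsigned int nKvHeads) {
    unsigned int group = nHeads / nKvHeads; // how many Q heads share one K/V head
    return qHead / group;
}

int main() {
    const unsigned int nHeads = 32, nKvHeads = 8; // Llama 3 8B
    for (unsigned int q : {0u, 1u, 2u, 8u, 9u, 31u})
        printf("Q head %2u => K head %u\n", q, qHeadToKvHead(q, nHeads, nKvHeads));
    return 0;
}
```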

Because of this, it is currently not possible to split the model across more than nKvHeads nodes.

^ The same problem applies to the V layer.


How could this be fixed?

1. Synchronize missing outputs

For nSlices > nKvHeads setups, a new synchronization step could be introduced. This step would synchronize the missing K/V outputs across nodes. Of course, synchronization is the slowest part of Distributed Llama.
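
To get a feel for what that step would have to move, here is a rough sketch that assumes Q heads and K outputs are sliced evenly and contiguously across nodes (all names are illustrative, not the actual Distributed Llama API):

```cpp
#include <cstdio>

// Sketch: for the K (or V) outputs that one node needs but does not compute
// itself, print which node would have to send them in the new sync step.
int main() {
    const int dim = 4096, nHeads = 32, nKvHeads = 8, nSlices = 16;
    const int headSize = dim / nHeads;      // 128
    const int kvDim = headSize * nKvHeads;  // 1024
    const int kvDim0 = kvDim / nSlices;     // 64 K outputs computed per node
    const int group = nHeads / nKvHeads;    // 4 Q heads per K/V head
    const int qHeads0 = nHeads / nSlices;   // 2 Q heads per node

    const int node = 0;                                 // inspect the first node
    int firstQ = node * qHeads0;
    int lastQ = firstQ + qHeads0 - 1;
    int neededStart = (firstQ / group) * headSize;      // first K output needed
    int neededEnd = ((lastQ / group) + 1) * headSize;   // one past the last

    for (int i = neededStart; i < neededEnd; i++) {
        int owner = i / kvDim0;                         // node that computes output i
        if (owner != node)
            printf("K output %4d must be sent from node %d to node %d\n", i, owner, node);
    }
    return 0;
}
```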

2. Redundancy

Redundancy could be introduced for the K/V layers. These layers would be split with alignment to headSize. This way there is no extra synchronization, and the amount of redundant computation seems to be small (headSize - kvDim0 outputs per node).

For example Llama 3 8B:

headSize = dim / nHeads = 128
kvDim = (dim * nKvHeads) / nHeads = 1024

nSlices = 16
kvDim0 = kvDim / nSlices = 64
redundancy = 128 - 64 = 64 outputs of K & V

nSlices = 32
kvDim0 = kvDim / nSlices = 32
redundancy = 128 - 32 = 96 outputs of K & V
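
A quick check of these numbers (just a sketch; it assumes each node computes the whole K/V head its slice belongs to, and that kvDim0 of those outputs are its own share):

```cpp
#include <cstdio>

// Sketch: redundant K (and V) outputs per node when the K/V slices are
// aligned to headSize, for Llama 3 8B.
int main() {
    const int dim = 4096, nHeads = 32, nKvHeads = 8;
    const int headSize = dim / nHeads;            // 128
    const int kvDim = (dim * nKvHeads) / nHeads;  // 1024

    for (int nSlices : {16, 32}) {
        int kvDim0 = kvDim / nSlices;
        int redundancy = headSize - kvDim0;       // redundant outputs of K & V per node
        printf("nSlices = %2d: kvDim0 = %2d, redundancy = %2d\n", nSlices, kvDim0, redundancy);
    }
    return 0;
}
```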
