
What's the exact meaning of 8.2x enhancement in memory efficiency, prompting latency, and generation latency? Can you provide the evaluation code? #29

@xinhaoH

Thanks for your great work.

Q1:
As far as I can tell, the code first has to execute
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
and only then call
past_key_value.update(...) with either the pruned K/V or the full key_states/value_states,
because the pruning score is calculated separately for each attention head.
This is very different from the original GQA implementation, which caches K/V before they are repeated.
My concern is that the original GQA keeps the K/V cache at (bsz, num_key_value_heads=8, q_len, head_dim/pruned_dim), whereas your implementation stores (bsz, num_heads=32, q_len, head_dim/pruned_dim) and so gives up that advantage.
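To make the size comparison in Q1 concrete, here is a minimal back-of-the-envelope sketch. It assumes fp16 storage, head_dim=128, batch size 1, and the 6k prompt / 2k capacity from the example in Q2; kv_cache_bytes is a hypothetical helper written for this issue, not a function from the repository.

```python
def kv_cache_bytes(bsz, num_heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes for one layer's K and V tensors, each of shape
    (bsz, num_heads, seq_len, head_dim). The leading 2 counts K plus V."""
    return 2 * bsz * num_heads * seq_len * head_dim * bytes_per_elem

# Assumed sizes: 8 KV heads, 32 attention heads, head_dim=128, fp16.
gqa_full = kv_cache_bytes(1, 8, 6144, 128)          # unpruned cache, K/V not repeated: 24 MiB
expanded_pruned = kv_cache_bytes(1, 32, 2048, 128)  # pruned to 2k after repeat_kv: 32 MiB
gqa_pruned = kv_cache_bytes(1, 8, 2048, 128)        # pruned to 2k, K/V not repeated: 8 MiB

for name, size in [("gqa_full", gqa_full),
                   ("expanded_pruned", expanded_pruned),
                   ("gqa_pruned", gqa_pruned)]:
    print(f"{name}: {size / 2**20:.0f} MiB per layer")
```

Under these assumed sizes, the cache pruned after repeat_kv is 4x larger than a pruned cache that kept the 8 KV heads, and it is even larger than the unpruned GQA cache for the 6k prompt.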

Q2:
I also noticed that in the prefill stage, although the cached tokens are pruned to max_capacity_prompt (2k), the attention weights are still computed with full attention.
For example, given a 6k-token prompt, the prefill stage selects the 2k most important tokens as key_states_compress / value_states_compress.
However, the attention weights are still computed over the full 6k sequence as query_states @ key_states.T, not over the 2k pruned keys as query_states @ key_states_compress.T.
Why not use the pruned 2k (seq_len dim) keys to compute the attention weights?
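The two alternatives in Q2 can be compared on a toy example. This is only a sketch under stated assumptions: a single head, tiny sizes, no causal mask or softmax, and a made-up importance score (column sums of the raw scores) standing in for the repository's actual selection criterion. matmul_t is a hypothetical helper.

```python
import random

def matmul_t(a, b):
    """Return a @ b.T for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b]
            for row_a in a]

random.seed(0)
q_len, capacity, head_dim = 6, 2, 4   # toy stand-ins for 6k, 2k, and head_dim
query = [[random.gauss(0, 1) for _ in range(head_dim)] for _ in range(q_len)]
key = [[random.gauss(0, 1) for _ in range(head_dim)] for _ in range(q_len)]

# Full prefill scores: every query position against every key position.
full_scores = matmul_t(query, key)                  # q_len x q_len

# Assumed importance score: total raw score each key position receives.
importance = [sum(row[j] for row in full_scores) for j in range(q_len)]
keep = sorted(sorted(range(q_len), key=lambda j: importance[j],
                     reverse=True)[:capacity])
key_compress = [key[j] for j in keep]

# Scores against the pruned keys only.
compressed_scores = matmul_t(query, key_compress)   # q_len x capacity

print(len(full_scores), "x", len(full_scores[0]))
print(len(compressed_scores), "x", len(compressed_scores[0]))
```

In this sketch the compressed scores are just the selected columns of the full score matrix, and the selection itself is derived from the full scores, so the full product is computed before any key can be dropped. Whether the same dependency holds for the repository's scoring is part of what I am asking.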

Thanks a lot!
