[SP] add SP deny list instead of allow #7887
kashif wants to merge 20 commits into deepspeedai:master from
Conversation
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
tohtana
left a comment
Hi @kashif,
Thank you for opening this PR! I think supporting HF hub kernels is a significant update.
Regarding the approach: we check whether core_attn_implementation is in ALL_ATTENTION_FUNCTIONS, but HF hub kernels like kernels-community/flash-attn2 are not in that list. So HF hub kernels still won't be available with this fix.
We probably need to do the proper registration steps:
- Reject known-bad impls explicitly: eager, flex_attention, and probably paged|eager.
- If core_attn_implementation is an HF hub kernel string, call the HF registration path first (using lazy_import_flash_attention(…)).
- Then read core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation].
- Build uattn from that original function.
- Replace that key with uattn_wrapper.
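The steps above could be sketched roughly as follows. This is illustrative only: ALL_ATTENTION_FUNCTIONS is modeled as a plain dict (in HF Transformers it is the real attention registry), and register_hub_kernel stands in for the actual HF hub registration path (e.g. lazy_import_flash_attention); all function names here are hypothetical.

```python
# Minimal sketch of the registration steps above, with the HF registry
# modeled as a plain dict; names are illustrative, not DeepSpeed's API.
ALL_ATTENTION_FUNCTIONS = {"sdpa": lambda *args, **kwargs: "sdpa_out"}

DENY_LIST = {"eager", "flex_attention", "paged|eager"}


def register_ulysses_attn(core_attn_implementation, register_hub_kernel):
    # 1. reject known-bad implementations explicitly
    if core_attn_implementation in DENY_LIST:
        raise ValueError(f"{core_attn_implementation} is not supported by Ulysses SP")
    # 2. HF hub kernel strings like "kernels-community/flash-attn2" are not
    #    pre-registered, so run the HF registration path first
    if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS:
        register_hub_kernel(core_attn_implementation)
    # 3. read the original function from the registry
    core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation]

    # 4. build the Ulysses wrapper from that original function
    def uattn_wrapper(*args, **kwargs):
        # a real wrapper would do the Ulysses all-to-all around this call
        return core_attn_function(*args, **kwargs)

    # 5. replace the registry key with the wrapper
    ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = uattn_wrapper
    return uattn_wrapper
```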
Does it make sense to you?
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
thanks @tohtana I have tried to fix all the issues raised; could you kindly check again?
We actually don't know if flex_attention is bad, we just haven't tried it out. Do you have resources to try it out, Kashif? Same for the others on the list. That's why we started with an approve list rather than a deny list. The only reason eager is denied is that it requires a 4D attention_mask, which is a bad idea for long sequences. BTW, SDPA is silently broken with packed samples: when there is no attn mask, it ignores pos ids and attends to the whole sequence instead. Expect bad results. Not sure how to flag that to users - probably we need to inspect pos ids, see if they reset at least once, and disallow sdpa then.
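The "see if they reset at least once" check suggested above could be sketched like this; position_ids are plain nested lists for illustration, and the function name is hypothetical (the real check would operate on tensors):

```python
def has_packed_samples(position_ids):
    """Detect packed samples via position_ids resets.

    Within a single sample, positions strictly increase; any non-increase
    (e.g. ... 510, 511, 0, 1 ...) marks a packed sample boundary.
    """
    for seq in position_ids:  # one list of position ids per batch row
        for prev, cur in zip(seq, seq[1:]):
            if cur < prev:
                return True
    return False
```

If this returns True and no attention mask is provided, sdpa/eager could then be disallowed rather than silently attending across sample boundaries.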
Hi @kashif, I also think Stas's comment makes sense. Can you try implementing such a validation?
sure @tohtana i can check
to make things more exact - it's packed samples + pos ids + 4D
oh, Kashif, I'm being told
I ran some experiments comparing flash_attention_2, sdpa, and flex_attention with SP=4 on Qwen3-4B (GQA: 32 Q
Without SP (1 GPU baseline): flash_attention_2 and sdpa produce identical losses, confirming the backends are
With SP=4 (4 GPUs): sdpa and flex_attention match each other, but both diverge significantly from
@stas00 any ideas on what flash_attention_2 might be doing differently after the all-to-all that
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
ok @stas00 I now generate position_ids if missing from the batch, build a causal BlockMask for flex_attention, and do a one-time validation for packed samples + sdpa/eager. Now the outputs are matching:
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Thank you for running those quality comparison experiments, Kashif. I'm a bit unclear about your last "success" comment - what was missing to make FA2 match? Are you saying the mismatch was from missing position_ids? But we already said that SDPA (and now most likely FlexAttention) has trouble with no-attn-mask / yes-pos-id and will ignore packed samples. FA2 on the other hand does the right thing here. And it's great to hear Flex Attention works as well with Ulysses, so we could add it to the allow list.
if has_packed_samples and self.core_attn_implementation in ("sdpa", "eager"):
    raise ValueError(
heh, I thought we were discussing that it's HF Transformers that has to do that, not Ulysses SP. It affects all users regardless of whether they use Ulysses or not. Unless HF Transformers disallows not providing attn-mask with sdpa/eager, which I don't think is the case.
agree, removed from DeepSpeed side
So, FA2 was the one producing correct results, while SDPA/flex were wrong. Here's what was happening:
When FA2 "accidentally" handles this correctly
SDPA with
The fix: generate
With this fix, all three backends match within numerical precision:
For
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
great explanations, Kashif - thank you!
Thank you, Kashif
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@stas00, regarding point 2, we added
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Thank you very much, Kashif. Do you think all this amazing tooling you added should live here and not in HF Transformers?
checking
So some SP-specific things tied to the all-to-all make sense to be here...
Agree that in Transformers:
On the TRL side:
Thank you for the detailed summary, Kashif. I agree with everything, except:
I think it should assert. Warnings don't work, and allowing invalid training can be so, so costly to the user who missed the warning in the sea of warnings. I wonder how many people will discover their model has been mistrained when they had no clue that was the case, other than by getting bad outcomes.
Please let us know when things are ready for the final review, Kashif.
thanks @stas00 yes we are asserting the
yes ready for review, thanks!
I meant inside transformers. Currently transformers may provide a disservice to users if they use packed samples w/o attention mask w/ sdpa/eager - or is it the case that transformers enforces a 4D attention mask?
stas00
left a comment
Kashif,
overall looks great - added a few suggestions
- do you think we should discuss the different supported attn types in the tutorial as well?
if so let's add a brief section there?
- also we can now probably test fa4 and add it to the list - fa4 support has been merged in transformers a few days ago.
f"{core_attn_implementation} attn_implementation isn't currently supported by Ulysses sequence"
f" parallelism. Set core_attn_implementation arg to one of {supported_attn_implementation}.")
f" parallelism because it requires a 4D attention_mask (O(n²) memory)."
f" Use 'flash_attention_2', 'flash_attention_3', 'flex_attention', 'sdpa',"
Should we future-proof this for FA and say any official flash attention version?
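One way to future-proof the check, as a sketch: accept any flash_attention_N via a pattern instead of hardcoding versions. The helper name and the static set are hypothetical, not DeepSpeed's actual code:

```python
import re

# statically supported non-FA implementations (illustrative set)
SUPPORTED_STATIC = {"flex_attention", "sdpa"}


def is_supported_attn_implementation(impl):
    # accept any official flash_attention_<N> (2, 3, 4, ...) without
    # hardcoding each version in the error message / allow list
    if re.fullmatch(r"flash_attention_\d+", impl):
        return True
    return impl in SUPPORTED_STATIC
```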
- Use lazy imports for BlockMask/create_block_mask instead of storing on instance attributes, fixing multiprocessing pickle errors in tests
- Future-proof error message for unsupported attn implementations
- Add TestUlyssesSPHFFlexAttention test class with non_daemonic_procs and a model with head_dim >= 16 (flex_attention compiled kernel requirement)

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…rward()
- Move `attention_mask = self._flex_block_mask_cached` inside the `isinstance(attention_mask, BlockMask)` guard to prevent stale cache from leaking when attention_mask is not a BlockMask
- Add warning_once in forward() when position_ids are missing, so users who bypass UlyssesSPDataLoaderAdapter are alerted to potential incorrect causal masking

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
the test is failing... i am investigating
- Replace Felladrin/Llama-160M-Chat-v1 with a locally-created LlamaConfig (head_dim=16, 2 layers, 2 heads) to match DeepSpeed's convention of using tiny models in tests, avoiding external model downloads
- Remove _compile=True from create_block_mask; it caused gradient explosion in the backward pass through torch.compile
- Set random seed for reproducible model initialization
- Use torch_assert_close for loss (flex_attention + torch.compile introduces tiny numerical differences vs exact match)
- Parametrize over zero_stage [2, 3] matching existing test convention

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
ok tests passing now
Compiling create_block_mask (via _compile=True or torch.compile) inside the model forward causes gradient explosion when the resulting BlockMask is used with flex_attention's own torch.compile. The nested compilation contexts conflict in the backward pass. Since the BlockMask is already cached and only rebuilt when dimensions change, the creation cost is negligible without compilation. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
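The rebuild-only-on-dimension-change caching described in this commit message could be sketched as follows; the class and attribute names are illustrative, not DeepSpeed's actual implementation:

```python
class BlockMaskCache:
    """Cache a BlockMask, rebuilding only when sequence dims change.

    Since creation is cheap without _compile=True, rebuilding on a
    dimension change costs little, and caching avoids the nested
    torch.compile contexts that conflicted in the backward pass.
    """

    def __init__(self):
        self._key = None
        self._mask = None

    def get(self, q_len, kv_len, build_fn):
        key = (q_len, kv_len)
        if key != self._key:
            # dims changed: rebuild the mask (e.g. via create_block_mask)
            self._key = key
            self._mask = build_fn(q_len, kv_len)
        return self._mask
```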
position_ids are required for Ulysses SP — without them each rank generates local [0..chunk_len-1] positions which break causal masking after the all_gather. A warning is useless since training silently produces wrong results. Make it a hard assert with actionable guidance. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
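The failure mode described above can be made concrete with a small sketch: each SP rank must use positions offset by its chunk start, not a local [0..chunk_len-1] range. The helper name is hypothetical:

```python
def global_position_ids(sp_rank, chunk_len):
    # Each SP rank holds one contiguous chunk of the full sequence, so its
    # positions must start at sp_rank * chunk_len. Local [0..chunk_len-1]
    # positions on every rank would break causal masking after the all_gather.
    start = sp_rank * chunk_len
    return list(range(start, start + chunk_len))
```

For a 16-token sequence split over SP=4, rank 1 gets positions [4, 5, 6, 7] rather than [0, 1, 2, 3], which is why the hard assert (instead of a warning) matters: local positions train silently but wrongly.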
this way one can register kernels-based flash-attn as well with SP