Added inter document masking for manual and flash attention. #434
BlueCrescent wants to merge 17 commits into `main`.
Conversation
Pull request overview
Adds inter-document masking support to the GPT-2 attention stack so sequences containing multiple concatenated documents can prevent cross-document attention for both manual attention and DAO flash attention (via varlen).
Changes:
- Added `CausalSelfAttention.prepare_inter_document_masking()` and threaded optional masking info through attention execution paths.
- Implemented the DAO flash varlen execution path to support document-wise masking/splitting without padding leakage.
- Added extensive unit tests for inter-document masking behaviors and edge cases.
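The masking idea can be sketched in plain Python (illustrative only; the names here are not from the modalities codebase, and the real implementation works on batched torch tensors): a token may attend to another token only if both are in the same document and the key position is not in the future.

```python
def inter_document_causal_mask(sub_seq_lengths):
    """Build a T x T boolean mask (True = may attend) from per-document lengths.

    Hypothetical sketch of the combined causal + inter-document constraint.
    """
    total = sum(sub_seq_lengths)
    # Document id for every position in the packed sequence.
    doc_ids = []
    for doc, length in enumerate(sub_seq_lengths):
        doc_ids.extend([doc] * length)
    # Causal within each document, no attention across documents.
    return [
        [doc_ids[q] == doc_ids[k] and k <= q for k in range(total)]
        for q in range(total)
    ]

mask = inter_document_causal_mask([2, 3])
# Token 2 starts a new document, so it cannot see tokens 0 and 1.
assert mask[2] == [False, False, True, False, False]
```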
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| `tests/models/test_causal_self_attention.py` | Adds comprehensive tests for inter-document masking across manual and DAO flash attention implementations. |
| `src/modalities/models/gpt2/gpt2_model.py` | Implements inter-document masking preparation, DAO flash varlen execution, and integrates masking into GPT2LLM/GPT2Block/CausalSelfAttention. |
…ce lengths required for inter document attention masking.
… manual attention. - Also applied some review comments.
…ed eod_token_id to eos_token_id
…doc string
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 11 comments.
`tests/models/test_gpt2_collator.py` (Outdated)
```python
def test_gpt2_collate_sub_seq_lengths_without_eos():
    collator = GPT2LLMCollateFn(
        sample_key="input_ids",
        target_key="labels",
        sub_seq_lengths_key="sub_seq_lengths",
        eos_token_id=99,
    )
    batch = [
        {"input_ids": torch.tensor([10, 11, 12, 13, 14])},
        {"input_ids": torch.tensor([20, 21, 22, 23, 24])},
    ]

    result = collator(batch)

    assert result.samples["sub_seq_lengths"] == [[5], [5]]


def test_gpt2_collate_sub_seq_lengths_with_eos():
    collator = GPT2LLMCollateFn(
        sample_key="input_ids",
        target_key="labels",
        sub_seq_lengths_key="sub_seq_lengths",
        eos_token_id=99,
    )
    batch = [
        {"input_ids": torch.tensor([1, 99, 2, 3, 99])},
        {"input_ids": torch.tensor([7, 8, 9, 99, 10])},
    ]

    result = collator(batch)

    assert result.samples["sub_seq_lengths"] == [[2, 3], [4, 1]]
```
Please verify the test expectations are correct. For the input [10, 11, 12, 13, 14], after the shift operation (sample_tensor[:, :-1]) the sequence becomes [10, 11, 12, 13] with length 4. When there's no EOS token, the collator returns [len(seq)] which should be [4], but the test expects [[5], [5]]. Similarly, for test_gpt2_collate_sub_seq_lengths_with_eos, trace through the logic to ensure the expected values match the implementation. If the tests are passing, please add comments explaining the counter-intuitive behavior.
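To make the concern concrete, here is a standalone trace of the shift-then-split logic the comment describes (a hypothetical helper, not the actual collator code):

```python
def shifted_sub_seq_lengths(input_ids, eos_token_id):
    """Trace the behavior questioned above: drop the last token (the shift),
    then split on EOS tokens. Hypothetical sketch for illustration only."""
    seq = input_ids[:-1]  # per-sample equivalent of sample_tensor[:, :-1]
    lengths, current = [], 0
    for tok in seq:
        current += 1
        if tok == eos_token_id:
            lengths.append(current)
            current = 0
    if current > 0:  # trailing partial document
        lengths.append(current)
    return lengths

# No EOS: the shifted sequence [10, 11, 12, 13] has length 4, not 5.
assert shifted_sub_seq_lengths([10, 11, 12, 13, 14], eos_token_id=99) == [4]
# With EOS: the shifted sequence [1, 99, 2, 3] splits into lengths 2 and 2.
assert shifted_sub_seq_lengths([1, 99, 2, 3, 99], eos_token_id=99) == [2, 2]
```

If the collator splits after the shift as sketched here, the expected values would differ from those in the tests, which is why the comment asks to either verify the expectations or document the counter-intuitive behavior.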
```python
elif attn_mask.dtype == torch.bool:
    if attn_mask.dim() == 3:
        combined_mask = temp_mask.unsqueeze(0) & attn_mask
    else:
        combined_mask = temp_mask & attn_mask
    fully_masked = ~combined_mask.any(dim=-1)
    attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
else:
    if attn_mask.dim() == 3:
        temp_mask = temp_mask.unsqueeze(0)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    attn_bias += attn_mask
```
The fully_masked variable is computed to identify rows that have no valid attention positions after combining causal and inter-document masks. However, this is only used when attn_mask.dtype is torch.bool within the is_causal branch. If attn_mask is a float mask, fully_masked will remain None, which means the special handling for fully masked rows won't apply. This could lead to NaN values in attention weights after softmax on fully masked rows when using float masks. Consider computing fully_masked for float masks as well, or document this limitation.
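The NaN risk can be reproduced without torch: softmax over a row whose entries are all `-inf` divides zero by zero. A guard analogous to `fully_masked` avoids it (a sketch of the idea, not the library code):

```python
import math

def softmax_with_guard(row):
    """Softmax that zeroes out fully masked rows instead of producing NaN."""
    exps = [math.exp(x) for x in row]  # math.exp(-inf) == 0.0
    total = sum(exps)
    if total == 0.0:  # fully masked row: every score was -inf
        return [0.0] * len(row)  # same effect as the fully_masked handling
    return [e / total for e in exps]

# A fully masked row would otherwise hit 0.0 / 0.0 after exponentiation.
assert softmax_with_guard([float("-inf"), float("-inf")]) == [0.0, 0.0]
# A partially masked row behaves like normal softmax.
assert softmax_with_guard([0.0, float("-inf")]) == [1.0, 0.0]
```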
```python
if len(eos_positions) == 0:
    assert (
        self.padding_token_id is None or seq[0] != self.padding_token_id
    ), "Sequence starts with padding token"
```
The assertion message "Sequence starts with padding token" is not very informative. It doesn't explain why this is a problem or what the user should do to fix it. Consider improving the error message to explain that sequences cannot start with padding tokens because it would result in invalid sub-sequence length computation, and suggest how to fix the data (e.g., "Invalid sequence: cannot start with padding token. Please ensure padding is only at the end of sequences after EOS tokens.").
Suggested change:

```diff
- ), "Sequence starts with padding token"
+ ), (
+     "Invalid sequence: cannot start with padding token. This prevents valid "
+     "sub-sequence length computation when no EOS token is present. Please ensure "
+     "padding is only applied at the end of sequences, typically after EOS tokens."
+ )
```
Why is this a problem? Because of the assumption in `_has_cutoff_final_sequence`?
Yes, this assertion detects the case where a whole batch sequence consists of padding tokens. I changed it to actually check the whole sequence as well if the first token is a padding token.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
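A minimal version of the stricter check described above might look like this (hypothetical sketch, not the merged code):

```python
def assert_not_all_padding(seq, padding_token_id):
    """Reject only sequences that consist entirely of padding tokens.

    Hypothetical: starting with a padding token is suspicious, but only an
    all-padding sequence is actually invalid for sub-sequence length
    computation, so the whole sequence is inspected before failing.
    """
    if padding_token_id is None:
        return
    if seq[0] == padding_token_id and all(t == padding_token_id for t in seq):
        raise AssertionError(
            "Invalid sequence: consists entirely of padding tokens."
        )

# Passes: real tokens follow the leading padding token.
assert_not_all_padding([0, 5, 6, 7], padding_token_id=0)
```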
```python
if len(eos_positions) == 0:
    assert (
        self.padding_token_id is None or seq[0] != self.padding_token_id
    ), "Sequence starts with padding token"
    sub_seq_lengths.append([len(seq)])
else:
    subseq_lengths = self._compute_subsequence_length(seq, eos_positions)
    sub_seq_lengths.append(subseq_lengths)
```
The naming is confusing, since `sub_seq_lengths` and `subseq_lengths` are basically the same name. Maybe better to use e.g. `batch_sub_seq_lengths` and `sample_sub_seq_lengths`.
There was a problem hiding this comment.
Improved naming in collator
```python
batch = [{"input_ids": torch.tensor([0, 1, 2, 3])}]

with pytest.raises(AssertionError, match="Sequence starts with padding token"):
    collator(batch)
```
Unrelated to this test, but I wanted to note it somewhere:
2 tests are failing now:
`tests/conversion/gpt2/test_conversion_model.py::test_convert_model_checkpoint_produces_same_logits_as_original[gpt2_config_test.yaml-False]`
fails with:
```
/workspaces/modalities/tests/conversion/gpt2/test_conversion_model.py::test_convert_model_checkpoint_produces_same_logits_as_original[gpt2_config_test.yaml-False] failed:
gpt2_config_path = PosixPath('/tmp/pytest-of-richard-rutmann/pytest-38/gpt2_model2/gpt2_config_test.yaml')

    def test_convert_model_checkpoint_produces_same_logits_as_original(gpt2_config_path: Path):
        modalities_config = load_app_config_dict(gpt2_config_path)
        hf_model, modalities_model = convert_model_checkpoint(modalities_config)
        vocab_size = modalities_config["model_raw" if "model_raw" in modalities_config else "model"]["config"]["vocab_size"]
        check_converted_model(hf_model, modalities_model, num_testruns=1, vocab_size=vocab_size)

tests/conversion/gpt2/test_conversion_model.py:32:
src/modalities/conversion/gpt2/conversion_model.py:84: in check_converted_model
    llama_logits = hf_model(input_ids=input_ids).logits.to("cpu")
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1783: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1794: in _call_impl
    return forward_call(*args, **kwargs)
/home/richard-rutmann/.local/lib/python3.11/site-packages/transformers/utils/generic.py:918: in wrapper
    output = func(self, *args, **kwargs)
src/modalities/conversion/gpt2/modeling_gpt2.py:455: in forward
    outputs: BaseModelOutputWithPast = self.model(
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1783: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)

self = GPT2Model(
  (embed_tokens): Embedding(50304, 256)
  (layers): ModuleList(
    (0-2): 3 x GPT2DecoderLayer(
      (sel...rue)
    )
  )
  (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (rotary_emb): LlamaRotaryEmbedding()
)
args = ()
kwargs = {'attention_mask': None, 'cache_position': None, 'input_ids': tensor([[14510, 17260, 49404, 12709, 8274, 42238, 20462...4, 33543, 41703, 1661,
    24343, 43573, 30272, 6699, 26169, 50070, 27464, 23769]]), 'inputs_embeds': None, ...}

    def _call_impl(self, *args, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
                or _global_backward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
            return forward_call(*args, **kwargs)
E   TypeError: check_model_inputs..wrapped_fn() got an unexpected keyword argument 'input_ids'

/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1794: TypeError
```
And
`tests/conversion/gpt2/test_convert_gpt2.py::test_converting_gpt2_does_not_change_outputs[gpt2_config_test.yaml-False]`
fails with:
```
/workspaces/modalities/tests/conversion/gpt2/test_convert_gpt2.py::test_converting_gpt2_does_not_change_outputs[gpt2_config_test.yaml-False] failed:
converted_model = GPT2ForCausalLM(
  (model): GPT2Model(
    (embed_tokens): Embedding(50304, 256)
    (layers): ModuleList(
      (0-2)...ue)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=256, out_features=50304, bias=False)
)
original_model = GPT2LLM(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 256)
    (wpe): Identity()
    (drop): Dropout(p=0.0...256,), eps=1e-05, elementwise_affine=True)
    (lm_head): Linear(in_features=256, out_features=50304, bias=False)
  )
)
vocab_size = 50304

    def test_converting_gpt2_does_not_change_outputs(
        converted_model: PreTrainedModel, original_model: GPT2LLM, vocab_size: int
    ):
        check_converted_model(
            hf_model=converted_model, modalities_model=original_model, num_testruns=1, vocab_size=vocab_size
        )

tests/conversion/gpt2/test_convert_gpt2.py:22:
src/modalities/conversion/gpt2/conversion_model.py:84: in check_converted_model
    llama_logits = hf_model(input_ids=input_ids).logits.to("cpu")
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1783: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1794: in _call_impl
    return forward_call(*args, **kwargs)
/home/richard-rutmann/.local/lib/python3.11/site-packages/transformers/utils/generic.py:918: in wrapper
    output = func(self, *args, **kwargs)
/home/richard-rutmann/.cache/huggingface/modules/transformers_modules/output/modeling_gpt2.py:455: in forward
    outputs: BaseModelOutputWithPast = self.model(
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1783: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)

self = GPT2Model(
  (embed_tokens): Embedding(50304, 256)
  (layers): ModuleList(
    (0-2): 3 x GPT2DecoderLayer(
      (sel...rue)
    )
  )
  (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (rotary_emb): LlamaRotaryEmbedding()
)
args = ()
kwargs = {'attention_mask': None, 'cache_position': None, 'input_ids': tensor([[ 6008, 4000, 348, 4355, 37595, 29409, 18783...4, 36419, 28141, 34024,
    37034, 16807, 47184, 40878, 41777, 40289, 14762, 50230]]), 'inputs_embeds': None, ...}

    def _call_impl(self, *args, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
                or _global_backward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
            return forward_call(*args, **kwargs)
E   TypeError: check_model_inputs..wrapped_fn() got an unexpected keyword argument 'input_ids'

/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1794: TypeError
```
I think that is a bug caused by transformers version 4.57.2 or 4.57.3. It does not happen with 4.57.6.
rrutmann left a comment:
Looks good and well tested. Please have a look at the two failing tests and check whether they are related to your changes.
Co-authored-by: Richard Rutmann <97447451+rrutmann@users.noreply.github.com>
…ing with pytorch sdpa
…h consisting completely of padding tokens
Co-authored-by: Richard Rutmann <97447451+rrutmann@users.noreply.github.com>
…rted pp + inter doc masking
What does this PR do?
Adds inter document masking for manual and flash attention.
General Changes
- Added `prepare_inter_document_masking()` to `CausalSelfAttention`, which computes 3D attention masks for manual attention and `cu_seqlens` for DAO flash attention. The input is the sub-sequence lengths for each sequence; thus, padded sequences are also supported.
- When sub-sequence lengths are passed to the `forward()` call, inter-document masking is applied.

Breaking Changes
Checklist before submitting final PR
- Ran all tests (`python tests/tests.py`)
- Updated the changelog (`CHANGELOG_DEV.md`)