
Added inter document masking for manual and flash attention. #434

Open

BlueCrescent wants to merge 17 commits into main from inter_document_masking_for_attention

Conversation

@BlueCrescent (Member)

What does this PR do?

Adds inter document masking for manual and flash attention.

General Changes

  • Added prepare_inter_document_masking() to CausalSelfAttention, which computes 3D attention masks for manual attention and cu_seqlens for DAO flash attention. Its inputs are the sub-sequence lengths of each sequence, so padded sequences are also supported.
  • When these are provided to the attention's forward() call, inter-document masking is applied.
  • Added thorough tests for CausalSelfAttention.
  • Integrated into GPT2Model.
  • TODO: Test GPT2Model, support PP, create corresponding dataloader
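The cu_seqlens mentioned above follow flash attention's varlen convention: cumulative token boundaries over the flattened sub-sequences. A minimal plain-Python sketch of that conversion (the helper name is illustrative, not the PR's actual code):

```python
def to_cu_seqlens(sub_seq_lengths):
    """Convert per-sample sub-sequence lengths into cumulative sequence
    boundaries (cu_seqlens) as consumed by varlen flash attention.

    sub_seq_lengths: one list per batch sample, e.g. [[2, 3], [4, 1]]
    for two samples of total length 5 each.
    """
    cu_seqlens = [0]
    for sample in sub_seq_lengths:
        for length in sample:
            # Each document ends where the previous one ended plus its length.
            cu_seqlens.append(cu_seqlens[-1] + length)
    return cu_seqlens

# Two samples, each split into documents of the given lengths:
print(to_cu_seqlens([[2, 3], [4, 1]]))  # [0, 2, 5, 9, 10]
```

Because padding can simply be reported as a final sub-sequence length, padded samples need no special casing in this representation.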

Breaking Changes

  • None. If no inter-document sequence lengths are provided, the behavior remains unchanged.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

Copilot AI (Contributor) left a comment

Pull request overview

Adds inter-document masking support to the GPT-2 attention stack so sequences containing multiple concatenated documents can prevent cross-document attention for both manual attention and DAO flash attention (via varlen).

Changes:

  • Added CausalSelfAttention.prepare_inter_document_masking() and threaded optional masking info through attention execution paths.
  • Implemented DAO flash varlen execution path to support document-wise masking/splitting without padding leakage.
  • Added extensive unit tests for inter-document masking behaviors and edge cases.
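For the manual attention path, inter-document masking amounts to intersecting the causal mask with a block-diagonal document mask per sample. A plain-Python sketch of the idea (illustrative only, not the PR's implementation):

```python
def inter_document_mask(sub_seq_lengths):
    """Build a (T, T) boolean mask for one sample: position i may attend
    to position j iff j <= i (causal) and both positions fall inside the
    same document (block-diagonal structure)."""
    doc_ids = []
    for doc_id, length in enumerate(sub_seq_lengths):
        doc_ids.extend([doc_id] * length)
    t = len(doc_ids)
    return [
        [j <= i and doc_ids[i] == doc_ids[j] for j in range(t)]
        for i in range(t)
    ]

mask = inter_document_mask([2, 2])
# Token 2 starts a new document, so it attends only to itself:
print(mask[2])  # [False, False, True, False]
```

Stacking one such mask per sample yields the 3D (batch, T, T) mask shape the overview refers to; the varlen path achieves the same effect by splitting at the cu_seqlens boundaries instead.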

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.

File Description
tests/models/test_causal_self_attention.py Adds comprehensive tests for inter-document masking across manual and DAO flash attention implementations.
src/modalities/models/gpt2/gpt2_model.py Implements inter-document masking preparation, DAO flash varlen execution, and integrates masking into GPT2LLM/GPT2Block/CausalSelfAttention.


@BlueCrescent BlueCrescent marked this pull request as ready for review February 25, 2026 10:42
@BlueCrescent BlueCrescent requested a review from Copilot February 25, 2026 10:42
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 11 comments.



Comment on lines 20 to 51
def test_gpt2_collate_sub_seq_lengths_without_eos():
    collator = GPT2LLMCollateFn(
        sample_key="input_ids",
        target_key="labels",
        sub_seq_lengths_key="sub_seq_lengths",
        eos_token_id=99,
    )
    batch = [
        {"input_ids": torch.tensor([10, 11, 12, 13, 14])},
        {"input_ids": torch.tensor([20, 21, 22, 23, 24])},
    ]

    result = collator(batch)

    assert result.samples["sub_seq_lengths"] == [[5], [5]]


def test_gpt2_collate_sub_seq_lengths_with_eos():
    collator = GPT2LLMCollateFn(
        sample_key="input_ids",
        target_key="labels",
        sub_seq_lengths_key="sub_seq_lengths",
        eos_token_id=99,
    )
    batch = [
        {"input_ids": torch.tensor([1, 99, 2, 3, 99])},
        {"input_ids": torch.tensor([7, 8, 9, 99, 10])},
    ]

    result = collator(batch)

    assert result.samples["sub_seq_lengths"] == [[2, 3], [4, 1]]
Copilot AI (Feb 25, 2026):
Please verify the test expectations are correct. For the input [10, 11, 12, 13, 14], after the shift operation (sample_tensor[:, :-1]) the sequence becomes [10, 11, 12, 13] with length 4. When there's no EOS token, the collator returns [len(seq)] which should be [4], but the test expects [[5], [5]]. Similarly, for test_gpt2_collate_sub_seq_lengths_with_eos, trace through the logic to ensure the expected values match the implementation. If the tests are passing, please add comments explaining the counter-intuitive behavior.

Copilot uses AI. Check for mistakes.
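The shift-then-split reading that this comment describes can be traced with a small sketch (this mirrors the comment's interpretation of the collator, not necessarily the actual implementation; the function name is hypothetical):

```python
def sub_seq_lengths_after_shift(input_ids, eos_token_id):
    """Drop the last token (the teacher-forcing shift sample_tensor[:, :-1]),
    then report the length of each EOS-terminated chunk; trailing tokens
    without an EOS form a final chunk of their own."""
    seq = input_ids[:-1]
    lengths, current = [], 0
    for tok in seq:
        current += 1
        if tok == eos_token_id:
            lengths.append(current)
            current = 0
    if current:
        lengths.append(current)
    # No EOS at all: the whole shifted sequence is one sub-sequence.
    return lengths or [len(seq)]

print(sub_seq_lengths_after_shift([10, 11, 12, 13, 14], 99))  # [4]
print(sub_seq_lengths_after_shift([1, 99, 2, 3, 99], 99))     # [2, 2]
```

Under this reading, [10, 11, 12, 13, 14] yields [4] rather than the test's expected [5], and [1, 99, 2, 3, 99] yields [2, 2] rather than [2, 3], which is exactly the discrepancy the review comment asks to verify or document.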
Comment on lines +1284 to +1295
elif attn_mask.dtype == torch.bool:
    if attn_mask.dim() == 3:
        combined_mask = temp_mask.unsqueeze(0) & attn_mask
    else:
        combined_mask = temp_mask & attn_mask
    fully_masked = ~combined_mask.any(dim=-1)
    attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
else:
    if attn_mask.dim() == 3:
        temp_mask = temp_mask.unsqueeze(0)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    attn_bias += attn_mask
Copilot AI (Feb 25, 2026):
The fully_masked variable is computed to identify rows that have no valid attention positions after combining causal and inter-document masks. However, this is only used when attn_mask.dtype is torch.bool within the is_causal branch. If attn_mask is a float mask, fully_masked will remain None, which means the special handling for fully masked rows won't apply. This could lead to NaN values in attention weights after softmax on fully masked rows when using float masks. Consider computing fully_masked for float masks as well, or document this limitation.
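The NaN risk described here is easy to see: softmax over a row whose every entry is -inf is 0/0. A minimal pure-Python illustration of the usual fix, zeroing fully masked rows instead of letting them turn into NaNs (a sketch of the general technique, not the PR's code):

```python
import math

def masked_softmax(scores, mask):
    """Softmax over one attention row with a boolean mask. Rows with no
    valid position are zeroed rather than producing NaNs (exp(-inf) = 0
    for every entry, so the normalizer would be 0)."""
    if not any(mask):
        return [0.0] * len(scores)  # fully masked row: emit zeros, not NaN
    masked = [s if m else float("-inf") for s, m in zip(scores, mask)]
    mx = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - mx) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

print(masked_softmax([1.0, 2.0, 3.0], [True, True, False]))
print(masked_softmax([1.0, 2.0, 3.0], [False, False, False]))  # [0.0, 0.0, 0.0]
```

With a float additive mask, the "fully masked" condition is no longer a cheap boolean reduction, which is why the comment suggests either computing it there as well or documenting the limitation.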

if len(eos_positions) == 0:
    assert (
        self.padding_token_id is None or seq[0] != self.padding_token_id
    ), "Sequence starts with padding token"
Copilot AI (Feb 25, 2026):
The assertion message "Sequence starts with padding token" is not very informative. It doesn't explain why this is a problem or what the user should do to fix it. Consider improving the error message to explain that sequences cannot start with padding tokens because it would result in invalid sub-sequence length computation, and suggest how to fix the data (e.g., "Invalid sequence: cannot start with padding token. Please ensure padding is only at the end of sequences after EOS tokens.").

Suggested change
-    ), "Sequence starts with padding token"
+    ), (
+        "Invalid sequence: cannot start with padding token. This prevents valid "
+        "sub-sequence length computation when no EOS token is present. Please ensure "
+        "padding is only applied at the end of sequences, typically after EOS tokens."
+    )

Collaborator:
+1

Collaborator:
Why is this a problem? Because of the assumption in _has_cutoff_final_sequence?

BlueCrescent (Member, Author):
Yes, this assertion is there to detect the case where a whole batch sequence consists of padding tokens. I changed it to check the whole sequence as well if the first token is a padding token.


    sub_seq_lengths.append([len(seq)])
else:
    subseq_lengths = self._compute_subsequence_length(seq, eos_positions)
    sub_seq_lengths.append(subseq_lengths)
Collaborator:
The naming is confusing, since sub_seq_lengths and subseq_lengths are basically the same name. Better to use e.g. batch_sub_seq_lengths and sample_sub_seq_lengths.

BlueCrescent (Member, Author):
Improved naming in collator


batch = [{"input_ids": torch.tensor([0, 1, 2, 3])}]

with pytest.raises(AssertionError, match="Sequence starts with padding token"):
    collator(batch)
Collaborator:

Unrelated to this test, but I wanted to note it somewhere:
Two tests are failing now:

tests/conversion/gpt2/test_conversion_model.py::test_convert_model_checkpoint_produces_same_logits_as_original[gpt2_config_test.yaml-False]

fails with:
/workspaces/modalities/tests/conversion/gpt2/test_conversion_model.py::test_convert_model_checkpoint_produces_same_logits_as_original[gpt2_config_test.yaml-False] failed: gpt2_config_path = PosixPath('/tmp/pytest-of-richard-rutmann/pytest-38/gpt2_model2/gpt2_config_test.yaml')

def test_convert_model_checkpoint_produces_same_logits_as_original(gpt2_config_path: Path):
    modalities_config = load_app_config_dict(gpt2_config_path)
    hf_model, modalities_model = convert_model_checkpoint(modalities_config)
    vocab_size = modalities_config["model_raw" if "model_raw" in modalities_config else "model"]["config"]["vocab_size"]
>   check_converted_model(hf_model, modalities_model, num_testruns=1, vocab_size=vocab_size)

tests/conversion/gpt2/test_conversion_model.py:32:


src/modalities/conversion/gpt2/conversion_model.py:84: in check_converted_model
llama_logits = hf_model(input_ids=input_ids).logits.to("cpu")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1783: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1794: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/richard-rutmann/.local/lib/python3.11/site-packages/transformers/utils/generic.py:918: in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/modalities/conversion/gpt2/modeling_gpt2.py:455: in forward
outputs: BaseModelOutputWithPast = self.model(
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1783: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


self = GPT2Model(
(embed_tokens): Embedding(50304, 256)
(layers): ModuleList(
(0-2): 3 x GPT2DecoderLayer(
(sel...rue)
)
)
(norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(rotary_emb): LlamaRotaryEmbedding()
)
args = ()
kwargs = {'attention_mask': None, 'cache_position': None, 'input_ids': tensor([[14510, 17260, 49404, 12709, 8274, 42238, 20462...4, 33543, 41703, 1661,
24343, 43573, 30272, 6699, 26169, 50070, 27464, 23769]]), 'inputs_embeds': None, ...}

def _call_impl(self, *args, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
            or _global_backward_pre_hooks or _global_backward_hooks
            or _global_forward_hooks or _global_forward_pre_hooks):
>       return forward_call(*args, **kwargs)

E TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'

/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1794: TypeError

And:
tests/conversion/gpt2/test_convert_gpt2.py::test_converting_gpt2_does_not_change_outputs[gpt2_config_test.yaml-False]

fails with:
/workspaces/modalities/tests/conversion/gpt2/test_convert_gpt2.py::test_converting_gpt2_does_not_change_outputs[gpt2_config_test.yaml-False] failed: converted_model = GPT2ForCausalLM(
(model): GPT2Model(
(embed_tokens): Embedding(50304, 256)
(layers): ModuleList(
(0-2)...ue)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=256, out_features=50304, bias=False)
)
original_model = GPT2LLM(
(transformer): ModuleDict(
(wte): Embedding(50304, 256)
(wpe): Identity()
(drop): Dropout(p=0.0...256,), eps=1e-05, elementwise_affine=True)
(lm_head): Linear(in_features=256, out_features=50304, bias=False)
)
)
vocab_size = 50304

def test_converting_gpt2_does_not_change_outputs(
    converted_model: PreTrainedModel, original_model: GPT2LLM, vocab_size: int
):
>   check_converted_model(
        hf_model=converted_model, modalities_model=original_model, num_testruns=1, vocab_size=vocab_size
    )

tests/conversion/gpt2/test_convert_gpt2.py:22:


src/modalities/conversion/gpt2/conversion_model.py:84: in check_converted_model
llama_logits = hf_model(input_ids=input_ids).logits.to("cpu")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1783: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1794: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/richard-rutmann/.local/lib/python3.11/site-packages/transformers/utils/generic.py:918: in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/richard-rutmann/.cache/huggingface/modules/transformers_modules/output/modeling_gpt2.py:455: in forward
outputs: BaseModelOutputWithPast = self.model(
/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1783: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


self = GPT2Model(
(embed_tokens): Embedding(50304, 256)
(layers): ModuleList(
(0-2): 3 x GPT2DecoderLayer(
(sel...rue)
)
)
(norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(rotary_emb): LlamaRotaryEmbedding()
)
args = ()
kwargs = {'attention_mask': None, 'cache_position': None, 'input_ids': tensor([[ 6008, 4000, 348, 4355, 37595, 29409, 18783...4, 36419, 28141, 34024,
37034, 16807, 47184, 40878, 41777, 40289, 14762, 50230]]), 'inputs_embeds': None, ...}

def _call_impl(self, *args, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
            or _global_backward_pre_hooks or _global_backward_hooks
            or _global_forward_hooks or _global_forward_pre_hooks):
>       return forward_call(*args, **kwargs)

E TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'

/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1794: TypeError

BlueCrescent (Member, Author):

I think that is the bug caused by transformers version 4.57.2 or 4.57.3. It does not happen with 4.57.6.

@rrutmann (Collaborator) left a comment:

Looks good and well tested. Please have a look at the two failing tests and check whether they are related to your changes.
