
[New Feature]: Add a FLOPs collection interface #1302

Open
mahaocong90 wants to merge 2 commits into modelscope:main from mahaocong90:dev-caculate-flops-for-wan-training

Conversation


@mahaocong90 mahaocong90 commented Feb 19, 2026

This PR adds a FLOPs collection interface that supports real-time collection of floating-point operations (FLOPs) during the training process for WAN models.

Description
Wrap nn.Module instances and use Python decorators to estimate the forward-pass FLOPs from each module's inputs, outputs, and parameters (a minimal sketch of the idea follows).
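A minimal sketch of that idea, assuming a Linear-only counter; the helper names here are illustrative and not the PR's actual flops_counter API:

import torch
import torch.nn as nn

def linear_flops(module: nn.Linear, inp: torch.Tensor) -> int:
    # One multiply-add per weight element per input row: 2 * rows * in_features * out_features.
    rows = inp.numel() // module.in_features
    return 2 * rows * module.in_features * module.out_features

def wrap_with_flops_counter(module: nn.Module) -> nn.Module:
    # Illustrative wrapper: swap forward for a version that accumulates __flops__ on the module.
    original_forward = module.forward

    def profiled_forward(*args, **kwargs):
        output = original_forward(*args, **kwargs)
        if isinstance(module, nn.Linear):
            # Assumes the input tensor is the first positional argument (fine for this demo).
            module.__flops__ = getattr(module, "__flops__", 0) + linear_flops(module, args[0])
        return output

    module._original_forward = original_forward  # kept so the wrapper can be undone later
    module.forward = profiled_forward
    return module

layer = wrap_with_flops_counter(nn.Linear(1024, 4096))
_ = layer(torch.randn(8, 1024))
print(f"{layer.__flops__ / 1e9:.3f} GFLOPs")  # 2 * 8 * 1024 * 4096 ≈ 0.067 GFLOPs

A Linear layer is shown only to make the formula concrete; the PR's flops_counter presumably dispatches on the layer type (Conv, Norm, Attention, and so on), as listed in the changelog below.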

Environment version
OS: Ubuntu 24.04
CUDA driver: 550.163.01 + CUDA 12.9
Python: 3.12.3
torch: 2.8.0
xfuser: 0.4.5
transformers: 4.55.2
GPU: A800 x 8, one node

Print result
Once the FLOPs for the forward pass are estimated, the TFLOPS for a training step can be roughly estimated as (fwd + bwd) / time ≈ 3 * fwd / time, since the backward pass costs roughly twice the forward pass (a worked check follows the log below).

[Step      1] | Rank: 0 | loss: 0.02763 | lr: 3.33333e-06 | step/t: 3.335s | [t5] Tflops: 4.846 | [dit] Tflops: 55.759 | [vae] Tflops: 1.867 | TFLOPS: 56.193
[Step      2] | Rank: 0 | loss: 0.01775 | lr: 3.33333e-06 | step/t: 3.235s | [t5] Tflops: 4.846 | [dit] Tflops: 55.759 | [vae] Tflops: 1.867 | TFLOPS: 57.942
[Step      3] | Rank: 0 | loss: 0.07406 | lr: 1.00000e-05 | step/t: 2.545s | [t5] Tflops: 4.846 | [dit] Tflops: 55.759 | [vae] Tflops: 1.867 | TFLOPS: 73.648
[Step      4] | Rank: 0 | loss: 0.33582 | lr: 1.00000e-05 | step/t: 2.538s | [t5] Tflops: 4.846 | [dit] Tflops: 55.759 | [vae] Tflops: 1.867 | TFLOPS: 73.852
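As a sanity check of that formula against the log (my own arithmetic, not part of the PR): at step 3 the measured forward work is 4.846 + 55.759 + 1.867 ≈ 62.47 TFLOPs, and 3 × 62.47 / 2.545 s ≈ 73.6 TFLOPS, which matches the reported 73.648.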

@gemini-code-assist

Summary of Changes

Hello @mahaocong90, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a robust system for real-time profiling of Floating Point Operations (FLOPs) during the training of deep learning models, specifically targeting WAN models. By wrapping neural network modules and employing decorators, it accurately estimates computational costs for various layers, providing valuable insights into model efficiency. The integration into the training loop allows for immediate feedback on TFLOPS per step, aiding in performance monitoring and optimization efforts.

Highlights

  • FLOPs Collection Interface: Implemented a new FLOPs collection interface to estimate floating-point operations in real-time during model training, specifically for WAN models.
  • Real-time Profiling Integration: Integrated FLOPs profiling into the training runner, enabling per-step TFLOPS logging for different model components (t5, dit, vae) and total TFLOPS.
  • Dynamic Module Wrapping: Introduced a flexible profiling mechanism using Python decorators to wrap nn.modules and calculate FLOPs based on inputs, outputs, and model parameters for various layer types (a rough sketch of such wrapping follows this list).
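A rough sketch, under the assumption that the profile_entire_model / unprofile_entire_model helpers listed in the changelog below simply swap and restore each submodule's forward; the wrapper body here is only a placeholder, and the real implementation presumably adds a per-layer estimate (see get_module_flops):

import torch.nn as nn

def profile_entire_model(model: nn.Module) -> None:
    # Swap every submodule's forward for a counting wrapper, keeping the original around.
    for module in model.modules():
        if hasattr(module, "_original_forward"):
            continue  # already wrapped
        original = module.forward

        def profiled_forward(*args, _orig=original, _mod=module, **kwargs):
            output = _orig(*args, **kwargs)
            # A real counter would accumulate a per-layer FLOPs estimate here.
            _mod.__flops__ = getattr(_mod, "__flops__", 0)
            return output

        module._original_forward = original
        module.forward = profiled_forward

def unprofile_entire_model(model: nn.Module) -> None:
    # Restore the saved forward methods and drop the wrapper state.
    for module in model.modules():
        if hasattr(module, "_original_forward"):
            module.forward = module._original_forward
            del module._original_forward

The default-argument binding (_orig, _mod) avoids the usual late-binding pitfall when wrapping many modules in a single loop.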


Changelog
  • diffsynth/diffusion/runner.py
    • Imported FLOPs profiling utilities.
    • Initialized FLOPs profiling for the model at the start of training and unprofiled at the end.
    • Added per-step FLOPs calculation and logging for T5, WAN, and VAE components, along with total TFLOPS.
    • Enhanced the tqdm progress bar to display detailed FLOPs metrics.
  • diffsynth/models/wan_video_text_encoder.py
    • Stored query, key, and value tensor shapes (q_shape, k_shape, v_shape) within the attention module for FLOPs calculation.
  • diffsynth/utils/profiling/__init__.py
    • Created an __init__.py file to expose FLOPs profiling functions.
  • diffsynth/utils/profiling/flops_profiler.py
    • Added a new module for FLOPs profiling, including a flops_counter decorator and get_module_flops function to estimate FLOPs for various PyTorch layers (e.g., Linear, Conv, Norms, Activations, Attention).
    • Implemented profile_entire_model and unprofile_entire_model functions to dynamically wrap and unwrap module forward passes.
    • Included functions to retrieve and print model-wide FLOPs and duration.
  • diffsynth/utils/xfuser/xdit_context_parallel.py
    • Stored query, key, and value tensor shapes (q_shape, k_shape, v_shape) within the usp_attn_forward method for FLOPs calculation (see the attention-FLOPs sketch after this list).
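Because q_shape / k_shape / v_shape follow the view(b, -1, n, c) layout used in the encoder, attention cost can be recovered from them with the usual estimate. A hedged sketch of that calculation (my own formulation; the PR's get_module_flops may differ):

def attention_flops(q_shape, k_shape) -> int:
    # Shapes assumed to be (batch, seq_len, num_heads, head_dim), matching view(b, -1, n, c).
    b, s_q, n, d = q_shape
    s_k = k_shape[1]
    # QK^T costs ~2*b*n*s_q*s_k*d FLOPs; softmax(QK^T) @ V costs about the same again.
    return 4 * b * n * s_q * s_k * d

# Example: batch 1, 512 tokens, 12 heads of dim 64 -> ~0.8 GFLOPs per attention call.
print(attention_flops((1, 512, 12, 64), (1, 512, 12, 64)) / 1e9)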
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable FLOPs collection interface for real-time performance monitoring during model training. The implementation correctly uses decorators and module wrapping to estimate computation. My review includes suggestions to improve the robustness and maintainability of the new profiling code, addressing a potential bug in the module wrapping logic, removing unused code, and enhancing clarity in a few areas.

Comment on lines 228 to 229
def profiled_forward(self, x, *args, **kwargs):
    return module._original_forward(x, *args, **kwargs)

Severity: high

The signature of profiled_forward is (self, x, *args, **kwargs), which assumes that every wrapped module's forward method has a first positional argument x. This is not always true and can lead to a TypeError for modules with different forward signatures (e.g., no arguments, or keyword-only arguments). The signature should be (self, *args, **kwargs) to be generic and robust.

Suggested change
-def profiled_forward(self, x, *args, **kwargs):
-    return module._original_forward(x, *args, **kwargs)
+def profiled_forward(self, *args, **kwargs):
+    return module._original_forward(*args, **kwargs)

Comment on lines 79 to 81
def format_time(key: str) -> str:
    value = timing.get(key, 0.0)
    return f"{value:.3f}s"

Severity: medium

This function is redefined on every iteration of the training loop, which is inefficient. It's better to define it once outside the loop. To do so, you'll need to pass the timing dictionary as an argument.

For example, you could define it before the loop:

def format_time(timing_dict: dict, key: str) -> str:
    value = timing_dict.get(key, 0.0)
    return f"{value:.3f}s"

And then call it inside the loop as format_time(timing, "step").

Since I cannot suggest changes outside of the current diff hunk, I'm leaving this as a comment for you to refactor.

k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)

# For caculate flops

Severity: medium

There is a typo in the comment. "caculate" should be "calculate".

Suggested change
-# For caculate flops
+# For calculate flops

Comment on lines 5 to 7
from collections import defaultdict
import flash_attn
from einops import rearrange

Severity: medium

These imports (defaultdict, flash_attn, einops) do not appear to be used in this file and can be removed to improve code clarity.

Comment on lines 10 to 25
def get_dit_flops(model):
    def get_dit_flops(dit_block_model):
        total_flops = 0
        for sub_model in dit_block_model.modules():
            total_flops += getattr(sub_model, '__flops__', 0)
        return total_flops

    total_flops = 0
    total_duration = 0
    for sub_module in model.modules():
        if sub_module.__class__.__name__ == 'DiTBlock':
            total_flops += get_dit_flops(sub_module)
            total_duration += getattr(sub_module, '__duration__', 0)

    Tflops = total_flops / 1e12
    return Tflops

Severity: medium

This function has a couple of issues affecting readability and maintainability:

  1. The nested function at line 11 has the same name as the outer function, which can be confusing. It's better to give it a more descriptive name, like _get_dit_flops_recursive.
  2. The total_duration variable is calculated but never used. It should be removed.
Suggested change
-def get_dit_flops(model):
-    def get_dit_flops(dit_block_model):
-        total_flops = 0
-        for sub_model in dit_block_model.modules():
-            total_flops += getattr(sub_model, '__flops__', 0)
-        return total_flops
-    total_flops = 0
-    total_duration = 0
-    for sub_module in model.modules():
-        if sub_module.__class__.__name__ == 'DiTBlock':
-            total_flops += get_dit_flops(sub_module)
-            total_duration += getattr(sub_module, '__duration__', 0)
-    Tflops = total_flops / 1e12
-    return Tflops
+def get_dit_flops(model):
+    def _get_dit_flops_recursive(dit_block_model):
+        total_flops = 0
+        for sub_model in dit_block_model.modules():
+            total_flops += getattr(sub_model, '__flops__', 0)
+        return total_flops
+    total_flops = 0
+    for sub_module in model.modules():
+        if sub_module.__class__.__name__ == 'DiTBlock':
+            total_flops += _get_dit_flops_recursive(sub_module)
+    Tflops = total_flops / 1e12
+    return Tflops
