`backends/arm/_passes/ANALYSIS_expensive_transposes.md`

# Analysis: Expensive Transposes in Control Ceres Model

## Executive Summary

The Control Ceres model has expensive transpose operations that Vela implements as long sequences of `NPU_OP_POOL` (1x1 AvgPool) operations on Ethos-U55. These transposes are **NOT** fused by the current passes because they surround `Reshape` (view_copy) operations.

## Key Finding: Reshape Requires Transposes for NCHW ↔ NHWC Conversion

The most expensive transposes in the model are generated by patterns involving `Reshape` operations:

```
Pattern: Transpose → Reshape → Transpose
[1,2,14,72] → [1,28,1,72] → [1,1,72,28]
(NHWC→NCHW) (reshape) (NCHW→NHWC)
```

The transposes are **required** because:
1. The TOSA/Vela backend requires NHWC layout for Conv2D operations
2. The PyTorch model uses NCHW layout internally
3. `view_copy` (Reshape) reinterprets the tensor's linear element order, so it must operate on a consistent (NCHW) layout
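
A minimal PyTorch sketch of point 3 (the shapes here are illustrative, not taken from the model): a reshape defined in NCHW element order gives a different result when applied directly to the NHWC copy, which is exactly why the lowering must bracket every Reshape with transposes.

```python
import torch

# Illustrative shapes only -- not the Control Ceres tensors.
nchw = torch.arange(1 * 2 * 3 * 4).reshape(1, 2, 3, 4)  # NCHW source
nhwc = nchw.permute(0, 2, 3, 1)                          # same data, NHWC layout

# The reshape the model wants, defined on NCHW element order:
want = nchw.reshape(1, 6, 4)

# Reshaping the NHWC tensor directly walks a different element order:
wrong = nhwc.reshape(1, 6, 4)
assert not torch.equal(want, wrong)

# So the backend emits Transpose (NHWC->NCHW) -> Reshape:
right = nhwc.permute(0, 3, 1, 2).reshape(1, 6, 4)
assert torch.equal(want, right)
```

A second transpose back to NHWC then follows whenever the consumer (e.g. Conv2D) requires NHWC, producing the `Transpose → Reshape → Transpose` sandwich above.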

## Expensive Transpose Inventory

### Highest Priority (Around Reshape Operations)

| Transpose ID | Tensor Size | % Time | Cycles | Pattern |
|--------------|-------------|--------|--------|---------|
| `tosa_transpose_default_8` | 54,048 bytes | 2.51% | 252 | T→Reshape→**T** |
| `tosa_transpose_default_7` | 54,048 bytes | 2.51% | 252 | **T**→Reshape→T |
| `tosa_transpose_default_9` | 53,984 bytes | 2.51% | 756 | Rescale→**T**→Reshape |
| `tosa_transpose_default_10` | 53,984 bytes | 2.51% | 252 | Reshape→**T** |
| `tosa_transpose_default_11` | 51,680 bytes | 2.40% | 126 | Reshape→**T** |
| `tosa_transpose_default_12` | 51,616 bytes | 2.40% | 378 | **T**→Reshape |

**Total time for these 6 transposes: ~15% of model execution**

### Medium Priority (Other Transposes)

| Transpose ID | Tensor Size | % Time | Cycles |
|--------------|-------------|--------|--------|
| `tosa_transpose_default_4` | 49,664 bytes | 2.31% | 189 |
| `tosa_transpose_default_3` | 47,008 bytes | 2.18% | 95 |
| `tosa_transpose_default_2` | 44,992 bytes | 2.09% | 95 |
| `tosa_transpose_default_1` | 40,960 bytes | 1.90% | 95 |
| `tosa_transpose_default_6` | 33,312 bytes | 1.55% | 178 |
| `tosa_transpose_default_5` | 32,320 bytes | 1.50% | 224 |
| `tosa_transpose_default` | 32,416 bytes | 1.50% | 133 |

### Lower Priority (Final Output Transposes)

| Transpose ID | Tensor Size | % Time | Cycles | Location |
|--------------|-------------|--------|--------|----------|
| `tosa_transpose_default_13` | 28,032 bytes | 1.30% | 62 | Near end |
| `tosa_transpose_default_18` | 24,160 bytes | 1.12% | 252 | Final |
| `tosa_transpose_default_17` | 21,648 bytes | 1.00% | 126 | Final |
| `tosa_transpose_default_16` | 17,616 bytes | 0.82% | 126 | Final |
| `tosa_transpose_default_15` | 9,552 bytes | 0.44% | 126 | Final |
| `tosa_transpose_default_14` | 8,864 bytes | 0.41% | 66 | Final |

## Graph Pattern Analysis

### Pattern 1: Around Reshape Operations (indices ~797-802)
```
Transpose (permute_copy_6)
Clamp (aten_clamp_default_2)
Rescale (tosa_rescale_default_24)
Transpose (tosa_transpose_default_7) ← EXPENSIVE
Reshape (aten_view_copy_default_2)
Transpose (tosa_transpose_default_8) ← EXPENSIVE
```

### Pattern 2: Rescale → Transpose → Reshape (indices ~838-841)
```
Rescale (tosa_rescale_default_37)
Transpose (tosa_transpose_default_9) ← EXPENSIVE
Reshape (aten_view_copy_default_5)
Transpose (tosa_transpose_default_10) ← EXPENSIVE
```

### Pattern 3: Clamp → Rescale → Reshape → Transpose (indices ~856-859)
```
Clamp (aten_clamp_default_3)
Rescale (tosa_rescale_default_45)
Reshape (aten_view_copy_default_8)
Transpose (tosa_transpose_default_11) ← EXPENSIVE
```

## Investigation Questions

### 1. Are Order Annotations Correct?
- **Question**: Are the TOSA transposes using the correct permutation orders?
- **Investigation**: Check `ToTosaMemoryFormatPass` to verify permutation logic

### 2. Where Are Transposes Inserted?
- **Source**: `ToTosaMemoryFormatPass` in `executorch/backends/arm/_passes/to_tosa_memory_format_pass.py`
- **Purpose**: Convert NCHW (PyTorch default) to NHWC (TOSA/Ethos requirement)

### 3. Can We Eliminate Earlier?
- **Option A**: Modify model to use NHWC throughout (training change)
- **Option B**: Fuse Transpose through Reshape mathematically
- **Option C**: Handle reshape in NHWC space directly

## Root Cause Analysis: Why FuseTransposeReshapeTransposePass Doesn't Help

### Finding: The Reshapes Are Not Simple Dimension Combinations

The `FuseTransposeReshapeTransposePass` can only fuse patterns where the reshape is a simple dimension combination/split (e.g., `[1, 2, 14, 72]` → `[1, 28, 72]`, which merges dims 1 and 2 into a single dim of size 28).

However, the expensive transposes in Control Ceres have **complex reshapes that reorder dimensions**:

```
tosa_transpose_default_7: OFM: [1, 2, 14, 72]
reshape (view_copy_2): [1, 2, 14, 72] → [1, 1, 72, 28] ← This is NOT a simple combine/split!
tosa_transpose_default_8: IFM: [1, 1, 72, 28]
```

The reshape `[1, 2, 14, 72]` → `[1, 1, 72, 28]` involves:
- Combining `2 * 14 = 28`
- But also **moving** the 72 channel dimension to a different position

This is equivalent to:
1. Flatten: `[1, 2, 14, 72]` → `[1, 2016]` (total elements: 2016)
2. Reshape: `[1, 2016]` → `[1, 1, 72, 28]`

The `_get_shape_indices` function returns `None` for such reshapes, so the fusion is skipped.
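
A hypothetical simplification of that check (not the real `_get_shape_indices`, whose implementation differs) makes the failure mode concrete: a reshape is a pure combine/split only if consecutive input dims can be grouped, in order, to produce each output dim.

```python
def simple_combine_split(in_shape, out_shape):
    """Return per-output-dim groups of input-dim indices if the reshape is a
    pure, order-preserving combine/split; otherwise None.

    Hypothetical sketch only -- the real ``_get_shape_indices`` differs.
    """
    groups, i = [], 0
    for out_dim in out_shape:
        prod, group = 1, []
        while prod < out_dim and i < len(in_shape):
            prod *= in_shape[i]
            group.append(i)
            i += 1
        if prod != out_dim:
            return None  # no run of consecutive input dims yields this output dim
        groups.append(group)
    # Every input dim must be consumed, or the mapping is not order-preserving.
    return groups if i == len(in_shape) else None


# Simple combine: [1, 2, 14, 72] -> [1, 28, 72] merges dims in order -- fusable.
assert simple_combine_split((1, 2, 14, 72), (1, 28, 72)) is not None

# The Control Ceres pattern also *moves* the 72, so no in-order grouping
# exists and the fusion must be skipped.
assert simple_combine_split((1, 2, 14, 72), (1, 1, 72, 28)) is None
```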

### Implication

**The transposes cannot be fused with the current approach** because the reshape involves both dimension combining AND reordering. These transposes are mathematically necessary for the reshape to work correctly.

### Possible Alternative Strategies

1. **Modify the model architecture** to avoid such reshape patterns
2. **Use NHWC-native operations** in the model to eliminate the need for transposes
3. **Investigate Vela optimizations** to make transposes more efficient
4. **Create a more sophisticated fusion pass** that can handle arbitrary reshapes (complex mathematical analysis required)

## Next Steps

1. [x] Read `ToTosaMemoryFormatPass` to understand transpose insertion logic
2. [ ] Identify where reshapes are created in the model
3. [ ] Investigate if `Transpose → Reshape → Transpose` can be mathematically fused
4. [ ] Check if the model can use NHWC-compatible reshape shapes
5. [ ] Consider creating `FuseTransposeThroughReshapePass` if mathematically feasible

---

## 🚀 New Strategy: Compile-Time Transpose Folding for Static Tensors

### Critical Finding: FuseConstantArgsPass SKIPS Transposes

The `FuseConstantArgsPass` (lines 142-148 in `fuse_constant_ops_pass.py`) **explicitly SKIPS** `TRANSPOSE.default` operations:

```python
if node.target in [
    exir_ops.backend.tosa.MATMUL.default,
    exir_ops.backend.tosa.RESCALE.default,
    exir_ops.backend.tosa.RESIZE.default,
    exir_ops.backend.tosa.TABLE.default,
    exir_ops.backend.tosa.TRANSPOSE.default,  # <-- SKIPPED!
]:
    continue
```

This means that even when a transpose operates on a static tensor (weight/constant), it is NOT folded at compile time.

### Why Transposes Are Currently Not Folded

The comment history doesn't explain why transposes were excluded. Possible reasons:
1. **Concern about tensor size increase** - But transposes preserve tensor size
2. **Special handling needed for shape metadata** - Transposing changes `tosa_dim_order`
3. **Simply not implemented yet**

### Proposed Solution: FoldConstantTransposePass

Create a new pass that specifically folds transposes on constant/static tensors at compile time.

#### How It Would Work

1. **Identify transpose nodes** on static tensors (parameters, buffers, lifted tensor constants)
2. **Actually permute the tensor data** at compile time using `tensor.permute(order).contiguous()`
3. **Create a new constant placeholder** with the permuted data
4. **Remove the transpose node** and rewire users to the new constant

#### Pattern Before:
```
static_weight (placeholder) -> TRANSPOSE [0,2,3,1] -> Conv2D
```

#### Pattern After:
```
static_weight_nhwc (placeholder, data already permuted) -> Conv2D
```
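
Numerically, the rewrite is safe by construction (a minimal check; the shape and permutation below are illustrative, not taken from the model): permuting the constant once at compile time yields exactly what the runtime TRANSPOSE would have produced.

```python
import torch

# Hypothetical NCHW-style conv weight and the [0, 2, 3, 1] permutation.
static_weight = torch.randn(8, 3, 5, 5)
order = (0, 2, 3, 1)

at_runtime = static_weight.permute(order)               # what TRANSPOSE computes
at_compile = static_weight.permute(order).contiguous()  # folded constant data

assert torch.equal(at_runtime, at_compile)
assert at_compile.is_contiguous()  # data is physically reordered, not just viewed
```

The `.contiguous()` call matters: it materializes the new memory layout so the serialized constant carries the NHWC data directly.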

### Example Implementation (Conceptual)

```python
class FoldConstantTransposePass(ArmPass):
    """Folds transposes on static tensors at compile time."""

    def __init__(self, exported_program: ExportedProgram, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exported_program = exported_program

    def call(self, graph_module):
        modified = False
        for node in list(graph_module.graph.nodes):
            if node.target != exir_ops.backend.tosa.TRANSPOSE.default:
                continue

            input_node = node.all_input_nodes[0]
            if not is_param_node(self.exported_program, input_node):
                continue  # Not a static tensor

            # Get the static tensor data
            tensor = get_param_tensor(self.exported_program, input_node)
            perm = node.args[1]  # Permutation order

            # Actually permute the data at compile time
            permuted_tensor = tensor.permute(perm).contiguous()

            # Create a new constant placeholder with the permuted data
            with graph_module.graph.inserting_before(input_node):
                const_node = create_constant_placeholder(
                    self.exported_program,
                    graph=graph_module.graph,
                    kind=get_constant_placeholder_kind(self.exported_program, input_node),
                    name=f"{input_node.name}_permuted",
                    data=permuted_tensor,
                    persistent_buffer=is_persistent_buffer(self.exported_program, input_node),
                )

            # Propagate layout metadata from the transpose onto the new constant
            const_node.meta["tosa_dim_order"] = node.meta.get(
                "tosa_dim_order", tuple(range(len(perm)))
            )

            # Rewire users; dead-code elimination then drops the transpose
            node.replace_all_uses_with(const_node)
            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified)
```

### Impact Analysis

#### Transposes That Could Be Folded:
- Weight transposes for Conv2D (NCHW→NHWC)
- Constant tensor transposes for MatMul
- Lifted tensor constants used in reshape patterns

#### Transposes That CANNOT Be Folded:
- Transposes on dynamic activations (runtime data)
- Transposes on model inputs

### Precedent: RewriteConvPass._reshape_weights()

The ARM backend already has precedent for compile-time weight reordering in `RewriteConvPass._reshape_weights()` (lines 115-161):

```python
def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None:
    weight_tensor = get_param_tensor(self.exported_program, weight_node)

    reshaped_weight_tensor = (
        weight_tensor.permute(HWCM_ORDER)
        .reshape(...)
        .permute(NHWC_INVERSE_ORDER)
    )

    # Update state_dict with permuted tensor
    self.exported_program.state_dict[param_name] = reshaped_weight_tensor
```

### Questions to Investigate

1. **Why was TRANSPOSE.default excluded from FuseConstantArgsPass?**
- Search for git history or comments

2. **Does folding transposes break any downstream passes?**
- `ToTosaMemoryFormatPass` annotations
- TOSA serialization

3. **Are there edge cases?**
- Transpose Conv2D weights (special handling at line 430)
- Multi-user constant nodes

### Next Steps for Implementation

1. [ ] Search for why transposes were excluded from FuseConstantArgsPass
2. [ ] Create FoldConstantTransposePass
3. [ ] Add to pass pipeline AFTER FuseConstantArgsPass
4. [ ] Test on Control Ceres model to measure impact
5. [ ] Measure Vela cycle reduction

## Related Files

- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/to_tosa_memory_format_pass.py`
- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/arm_pass_manager.py`
- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/fuse_transpose_sandwich_pass.py`
- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/propagate_transposes_through_rescale_pass.py`
- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/fuse_constant_ops_pass.py` (FuseConstantArgsPass)
- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/rewrite_conv_pass.py` (_reshape_weights precedent)

---

**Date**: 2026-03-06
**Author**: Eli Amesefe
**Related Work**: Transpose fusion passes for Ethos-U55 optimization