-
Notifications
You must be signed in to change notification settings - Fork 16
Hf checkpoint conversion for distributed checkpoints #424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e6b7cab
9fa51ec
a73de85
d7d0956
8957f19
527a0d2
95cead4
b8cf4ea
652e77a
fca72dc
ace93c7
53eb907
3a4b46c
642466d
f54abc6
3fbe498
ee4e244
1b4cfe0
3a67ed9
ddbb8cc
5a36d48
bce2ae1
5da0e7f
f902152
d520095
42a7e42
03e07f5
8a9ff2f
9ae218d
36e2e25
cfbe7df
1db0a6c
993d4ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,111 @@ | ||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner | ||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.filesystem import FileSystemReader | ||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE | ||||||||||||||||||||||||||||||||||||
| from torch.distributed.checkpoint.state_dict_loader import _load_state_dict | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from modalities.config.config import ConfigDictType, load_app_config_dict, save_yaml_config_dict | ||||||||||||||||||||||||||||||||||||
| from modalities.utils.env import EnvOverride | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def convert_dcp_to_torch(dcp_checkpoint_dir: str, output_dir: str, model_key: str = "model_raw") -> str: | ||||||||||||||||||||||||||||||||||||
| """Converts a DCP (Distributed Checkpoint) checkpoint—including | ||||||||||||||||||||||||||||||||||||
| FSDP2, PP, or TP checkpoints—to a standard PyTorch checkpoint. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files (may include FSDP2, PP, or TP). | ||||||||||||||||||||||||||||||||||||
| output_dir (str): Directory to save the converted PyTorch checkpoint. | ||||||||||||||||||||||||||||||||||||
| model_key (str): Key of the model configuration in the modalities config. | ||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||
| str: Path to the converted config file. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| os.makedirs(output_dir, exist_ok=True) | ||||||||||||||||||||||||||||||||||||
| torch_checkpoint_file = os.path.join(output_dir, "pytorch_model.bin") | ||||||||||||||||||||||||||||||||||||
| torch_config_file = convert_config_file(dcp_checkpoint_dir, output_dir, model_key, torch_checkpoint_file) | ||||||||||||||||||||||||||||||||||||
| # TODO This is the (adapted) code from torch's dcp_to_torch_save(dcp_checkpoint_dir, torch_checkpoint_file) | ||||||||||||||||||||||||||||||||||||
| # since we only want to convert the model state dict here. In future torch versions this function might | ||||||||||||||||||||||||||||||||||||
| # support converting only parts of the checkpoint. | ||||||||||||||||||||||||||||||||||||
| # (from torch.distributed.checkpoint.format_utils import dcp_to_torch_save) | ||||||||||||||||||||||||||||||||||||
| sd: STATE_DICT_TYPE = {} | ||||||||||||||||||||||||||||||||||||
| planner = _EmptyStateDictLoadPlanner(keys=["app.model"], allow_partial_load=True) | ||||||||||||||||||||||||||||||||||||
| _load_state_dict(sd, storage_reader=FileSystemReader(dcp_checkpoint_dir), planner=planner, no_dist=True) | ||||||||||||||||||||||||||||||||||||
| torch.save(sd["app"]["model"], torch_checkpoint_file) | ||||||||||||||||||||||||||||||||||||
| return torch_config_file | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def convert_config_file(dcp_checkpoint_dir: str, output_dir: str, model_key: str, torch_checkpoint_file: str) -> str: | ||||||||||||||||||||||||||||||||||||
| """Converts the modalities config file for DCP to a config file for standard PyTorch checkpoint loading. | ||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files. | ||||||||||||||||||||||||||||||||||||
| output_dir (str): Directory to save the converted config file. | ||||||||||||||||||||||||||||||||||||
| model_key (str): Key of the model configuration in the modalities config. | ||||||||||||||||||||||||||||||||||||
| torch_checkpoint_file (str): Path to the converted PyTorch checkpoint file. | ||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||
| str: Path to the converted config file. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| config_src, dcp_config = load_dcp_config(dcp_checkpoint_dir) | ||||||||||||||||||||||||||||||||||||
| config_dst: str = os.path.join(output_dir, os.path.basename(config_src)) | ||||||||||||||||||||||||||||||||||||
| if os.path.exists(config_dst): | ||||||||||||||||||||||||||||||||||||
| raise FileExistsError(f"Config file '{config_dst}' already exists.") | ||||||||||||||||||||||||||||||||||||
| torch_config: ConfigDictType = { | ||||||||||||||||||||||||||||||||||||
| "checkpointed_model": { | ||||||||||||||||||||||||||||||||||||
| "component_key": "model", | ||||||||||||||||||||||||||||||||||||
| "variant_key": "fsdp1_checkpointed", | ||||||||||||||||||||||||||||||||||||
| "config": { | ||||||||||||||||||||||||||||||||||||
| "checkpoint_loading": { | ||||||||||||||||||||||||||||||||||||
| "component_key": "checkpoint_loading", | ||||||||||||||||||||||||||||||||||||
| "variant_key": "torch", | ||||||||||||||||||||||||||||||||||||
| "config": { | ||||||||||||||||||||||||||||||||||||
| "device": "cpu", | ||||||||||||||||||||||||||||||||||||
| "precision": "FP32", | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| "model": { | ||||||||||||||||||||||||||||||||||||
| "instance_key": "model", | ||||||||||||||||||||||||||||||||||||
| "pass_type": "BY_REFERENCE", | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| "checkpoint_path": torch_checkpoint_file, | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| if model_key not in dcp_config: | ||||||||||||||||||||||||||||||||||||
| raise KeyError( | ||||||||||||||||||||||||||||||||||||
| f"Model key '{model_key}' not found in config file '{config_src}'." | ||||||||||||||||||||||||||||||||||||
| f" Available keys: {list(dcp_config.keys())}" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| torch_config["model"] = dcp_config[model_key] | ||||||||||||||||||||||||||||||||||||
BlueCrescent marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| torch_config["model"]["config"]["use_meta_device"] = False | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| torch_config["model"]["config"]["use_meta_device"] = False | |
| model_section = torch_config.get("model") | |
| if not isinstance(model_section, dict): | |
| raise TypeError( | |
| f"Expected 'model' section in config file '{config_src}' to be a mapping, " | |
| f"but got {type(model_section).__name__!r}." | |
| ) | |
| model_config = model_section.get("config") | |
| if not isinstance(model_config, dict): | |
| raise TypeError( | |
| f"Expected 'model.config' section in config file '{config_src}' to be a mapping, " | |
| f"but got {type(model_config).__name__!r}." | |
| ) | |
| model_config["use_meta_device"] = False |
BlueCrescent marked this conversation as resolved.
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,40 +1,49 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tqdm import tqdm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.checkpointing.convert_dcp_to_torch import load_dcp_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.config.config import ConfigDictType, PrecisionEnum, ProcessGroupBackendType | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.conversion.gpt2.modeling_gpt2 import GPT2DecoderLayer, GPT2ForCausalLM | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.models.components.layer_norms import LayerNormConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block, PositionTypes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.models.model import SwiGLU | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.models.utils import ModelTypeEnum, get_model_from_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.running_env.cuda_env import MultiProcessingCudaEnv | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modalities.running_env.env_utils import PyTorchDtypes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def convert_model_checkpoint(modalities_config: dict) -> tuple[GPT2ForCausalLM, GPT2LLM]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def convert_model_checkpoint(modalities_config: ConfigDictType) -> tuple[GPT2ForCausalLM, GPT2LLM]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Converts the modalities model to a Huggingface transformers model. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Both the loaded modalities model and the converted Huggingface model are returned | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| so that they can be compared. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| modalities_config (dict): Modalities config dictionary. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| modalities_config (ConfigDictType): Modalities config dictionary. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tuple[GPT2ForCausalLM, GPT2LLM]: Converted Hugging Face model and the original modalities model. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gpt2_config = convert_model_config(modalities_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_model = GPT2ForCausalLM(gpt2_config).to(dtype=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype = PrecisionEnum( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| modalities_config["checkpointed_model"]["config"]["checkpoint_loading"]["config"]["precision"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_model = GPT2ForCausalLM(gpt2_config).to(dtype=dtype.value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| modalities_model = get_model_from_config(modalities_config, model_type=ModelTypeEnum.CHECKPOINTED_MODEL) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _copy_weights_model(hf_model, modalities_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return hf_model, modalities_model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def convert_model_config(modalities_config: dict) -> GPT2Config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def convert_model_config(modalities_config: ConfigDictType) -> GPT2Config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Converts the modalities model configuration to a Huggingface transformers configuration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| For this the model_raw or model section of the modalities config is used. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Corresponding entries are mapped to the Huggingface configuration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| modalities_config (dict): Modalities config dictionary. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| modalities_config (ConfigDictType): Modalities config dictionary. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| GPT2Config: Converted Huggingface model configuration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -43,6 +52,12 @@ def convert_model_config(modalities_config: dict) -> GPT2Config: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _check_conversion_criteria(config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ffn_norm_key = "ffn_norm_config" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_type = _map_attention_type(config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if attention_type != "sdpa": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"transformers checkpoint will not save the attention implementation " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"(set to {attention_type}) and use sdpa by default." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return GPT2Config( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vocab_size=config["vocab_size"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -62,11 +77,24 @@ def convert_model_config(modalities_config: dict) -> GPT2Config: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer_norm_bias=_get_layer_norm_value(config[ffn_norm_key]["config"], "bias"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_position_embeddings=config["sequence_length"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rope_theta=config["attention_config"]["qkv_transforms"][0]["config"]["base_freq"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _attn_implementation=_map_attention_type(config), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attn_implementation=attention_type, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_attentions=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def check_converted_dcp_model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_model_dir: str, dcp_dir: str, num_testruns: int, device_id_modalities: str | int, device_hf: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | |
| ): | |
| """Loads a Hugging Face GPT-2 model and a DCP checkpointed modalities model and | |
| verifies that their outputs match. | |
| This function builds a single-node DCP configuration from the given DCP | |
| directory, loads the corresponding modalities model, and compares it | |
| against the Hugging Face model loaded from ``hf_model_dir`` using | |
| :func:`check_converted_model`. | |
| Args: | |
| hf_model_dir (str): Directory containing the Hugging Face GPT-2 model | |
| checkpoint and configuration. | |
| dcp_dir (str): Directory containing the DCP checkpoint to be validated. | |
| num_testruns (int): Number of random input sequences to use when | |
| comparing the two models. | |
| device_id_modalities (str | int): Device identifier for running the | |
| modalities model. If a string like ``"cuda:0"`` is provided, it is | |
| converted to the corresponding integer device index. | |
| device_hf (str): Device identifier (e.g. ``"cuda:0"`` or ``"cpu"``) on | |
| which the Hugging Face model is loaded. | |
| """ |
Copilot
AI
Feb 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hardcoded port number 24570 may cause conflicts if multiple conversion processes run simultaneously or if the port is already in use. Consider using a dynamically allocated port or making it configurable. Other parts of the test code use find_free_port() to avoid such conflicts (e.g., tests/conversion/gpt2/conftest.py:96).
Copilot
AI
Feb 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The device_id parameter passed to MultiProcessingCudaEnv is not valid. The MultiProcessingCudaEnv.__init__ method accepts **process_group_kwargs which are forwarded to dist.init_process_group(), but device_id is not a valid parameter for dist.init_process_group(). The CUDA device is set based on the LOCAL_RANK environment variable in the parent CudaEnv.__enter__ method (line 48 of cuda_env.py). This device_id argument will either be silently ignored or cause a TypeError at runtime.
Copilot
AI
Feb 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accessing nested config keys new_config["fsdp_model"]["config"]["model"]["instance_key"] without validation could raise KeyError or TypeError if the config structure doesn't match expectations. Consider adding validation or using safer access patterns with clear error messages.
| new_config["fsdp_model"]["config"]["model"]["instance_key"] = model_key | |
| fsdp_model_cfg = new_config.get("fsdp_model") | |
| if not isinstance(fsdp_model_cfg, dict): | |
| raise ValueError("Expected 'fsdp_model' in DCP config to be a dict.") | |
| fsdp_model_config = fsdp_model_cfg.setdefault("config", {}) | |
| if not isinstance(fsdp_model_config, dict): | |
| raise ValueError("Expected 'fsdp_model[\"config\"]' in DCP config to be a dict.") | |
| fsdp_model_model_cfg = fsdp_model_config.setdefault("model", {}) | |
| if not isinstance(fsdp_model_model_cfg, dict): | |
| raise ValueError("Expected 'fsdp_model[\"config\"][\"model\"]' in DCP config to be a dict.") | |
| fsdp_model_model_cfg["instance_key"] = model_key |
Copilot
AI
Feb 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accessing nested config keys new_config["initialized_model"]["config"]["model"] without validation could raise KeyError or TypeError if the config structure doesn't match expectations. Consider adding validation or using safer access patterns with clear error messages.
| new_config["fsdp_model"]["config"]["model"]["instance_key"] = model_key | |
| new_config["initialized_model"]["config"]["model"] = {"instance_key": "fsdp_model", "pass_type": "BY_REFERENCE"} | |
| # Ensure fsdp_model has the expected nested config/model structure | |
| fsdp_model_cfg = new_config.get("fsdp_model") | |
| if not isinstance(fsdp_model_cfg, dict): | |
| raise TypeError( | |
| f"Expected 'fsdp_model' to be a dict in DCP config, got {type(fsdp_model_cfg).__name__}" | |
| ) | |
| fsdp_model_config = fsdp_model_cfg.setdefault("config", {}) | |
| if not isinstance(fsdp_model_config, dict): | |
| raise TypeError( | |
| "Expected 'fsdp_model[\"config\"]' to be a dict in DCP config, " | |
| f"got {type(fsdp_model_config).__name__}" | |
| ) | |
| fsdp_model_model = fsdp_model_config.setdefault("model", {}) | |
| if not isinstance(fsdp_model_model, dict): | |
| raise TypeError( | |
| "Expected 'fsdp_model[\"config\"][\"model\"]' to be a dict in DCP config, " | |
| f"got {type(fsdp_model_model).__name__}" | |
| ) | |
| fsdp_model_model["instance_key"] = model_key | |
| # Ensure initialized_model has the expected nested config/model structure | |
| initialized_model_cfg = new_config.get("initialized_model") | |
| if not isinstance(initialized_model_cfg, dict): | |
| raise TypeError( | |
| f"Expected 'initialized_model' to be a dict in DCP config, got {type(initialized_model_cfg).__name__}" | |
| ) | |
| initialized_model_config = initialized_model_cfg.setdefault("config", {}) | |
| if not isinstance(initialized_model_config, dict): | |
| raise TypeError( | |
| "Expected 'initialized_model[\"config\"]' to be a dict in DCP config, " | |
| f"got {type(initialized_model_config).__name__}" | |
| ) | |
| initialized_model_config["model"] = {"instance_key": "fsdp_model", "pass_type": "BY_REFERENCE"} |
Copilot
AI
Feb 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing docstring for public function _load_hf_model_for_dcp_comparison. Although it's a private function (prefixed with _), other private functions in this file like _check_conversion_criteria, _get_layer_norm_value, and _map_attention_type have docstrings. Consider adding a docstring for consistency.
| ) -> GPT2ForCausalLM: | |
| ) -> GPT2ForCausalLM: | |
| """Load a Hugging Face GPT-2 model configured to match a DCP-converted modalities model. | |
| The model is loaded from ``hf_model_dir``, moved to the specified device, cast to the | |
| execution dtype defined in the FSDP mixed precision settings of ``dcp_modalities_config``, | |
| and its attention implementation is updated to mirror the attention type used by the | |
| DCP configuration. This ensures comparable outputs when validating the conversion. | |
| Args: | |
| hf_model_dir (str): Directory containing the pretrained Hugging Face GPT-2 checkpoint. | |
| dcp_modalities_config (ConfigDictType): Modalities configuration derived from the DCP | |
| checkpoint, used to determine execution dtype and attention implementation. | |
| device_hf (str): Device identifier (e.g. ``"cuda:0"`` or ``"cpu"``) to place the model on. | |
| Returns: | |
| GPT2ForCausalLM: The loaded and configured Hugging Face GPT-2 model. | |
| """ |
Copilot
AI
Feb 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accessing nested config key dcp_modalities_config["fsdp_model"]["config"]["mixed_precision_settings"]["param_dtype"] without validation could raise KeyError or TypeError if the config structure doesn't match expectations. Consider adding validation or using safer access patterns with clear error messages.
| dtype = dcp_modalities_config["fsdp_model"]["config"]["mixed_precision_settings"]["param_dtype"] | |
| try: | |
| fsdp_model_cfg = dcp_modalities_config["fsdp_model"]["config"] | |
| mixed_precision_settings = fsdp_model_cfg["mixed_precision_settings"] | |
| dtype = mixed_precision_settings["param_dtype"] | |
| except (KeyError, TypeError) as exc: | |
| raise ValueError( | |
| "Invalid DCP modalities config: expected " | |
| "'fsdp_model.config.mixed_precision_settings.param_dtype' to be present and correctly structured " | |
| "in order to load the HF model for comparison." | |
| ) from exc |
Uh oh!
There was an error while loading. Please reload this page.