PySDK Version
- [ ] PySDK V2 (2.x)
- [x] PySDK V3 (3.x)
Describe the bug
When HyperparameterTuner.tune() receives InputData objects as inputs, it converts them to Channel objects internally but drops the content_type field during the conversion. As a result, built-in algorithms such as XGBoost fail in validate_data_file_path, because the training container receives no content type for the channel and cannot tell what format the data is in.
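The mechanism can be illustrated without SageMaker at all. The sketch below uses two hypothetical, simplified stand-ins, InputDataStub and ChannelStub, in place of the real InputData and Channel classes (the real Channel nests the S3 URI inside DataSource/S3DataSource, which is flattened here). convert_current mirrors the conversion as described in this report, and convert_fixed mirrors the proposed fix.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class InputDataStub:
    """Hypothetical stand-in for InputData (only the fields relevant here)."""
    channel_name: str
    data_source: str
    content_type: Optional[str] = None


@dataclass
class ChannelStub:
    """Hypothetical stand-in for Channel, with the S3 data source flattened to a URI."""
    channel_name: str
    s3_uri: str
    content_type: Optional[str] = None


def convert_current(inp: InputDataStub) -> ChannelStub:
    # Mirrors the reported behavior: content_type is never forwarded,
    # so the channel silently falls back to the default (None).
    return ChannelStub(channel_name=inp.channel_name, s3_uri=inp.data_source)


def convert_fixed(inp: InputDataStub) -> ChannelStub:
    # Mirrors the suggested fix: content_type is copied onto the channel.
    return ChannelStub(
        channel_name=inp.channel_name,
        s3_uri=inp.data_source,
        content_type=inp.content_type,
    )


train_input = InputDataStub("train", "s3://my-bucket/train/train.csv", content_type="csv")
print(convert_current(train_input).content_type)  # None: no format hint for the container
print(convert_fixed(train_input).content_type)    # csv
```

The point of the comparison is that nothing fails at conversion time; the missing field only surfaces later, inside the training container.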
To reproduce
```python
from sagemaker.train.configs import InputData
from sagemaker.train.tuner import HyperparameterTuner

# model_trainer (configured for the built-in XGBoost algorithm) and
# hyperparameter_ranges are assumed to be defined beforehand (omitted here).

train_input = InputData(
    channel_name="train",
    data_source="s3://my-bucket/train/train.csv",
    content_type="csv",  # <-- this gets dropped
)

tuner = HyperparameterTuner(
    model_trainer=model_trainer,
    objective_metric_name="validation:auc",
    hyperparameter_ranges=hyperparameter_ranges,
    objective_type="Maximize",
    max_jobs=12,
    max_parallel_jobs=3,
    strategy="Bayesian",
)

tuner.tune(inputs=[train_input])
# All training jobs fail with:
# AlgorithmError: validate_data_file_path(train_path, content_type)
```
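One way to confirm that the content type never reaches the service is to inspect the tuning job's training job definition. The sketch below is a hypothetical helper (channels_missing_content_type is not part of the SDK) that assumes the response shape of boto3's describe_hyper_parameter_tuning_job: a TrainingJobDefinition (or a TrainingJobDefinitions list) whose InputDataConfig entries carry an optional ContentType key. It is exercised here on a synthetic response rather than a live job.

```python
from typing import Any, Dict, List


def channels_missing_content_type(describe_response: Dict[str, Any]) -> List[str]:
    """Return the names of channels that have no ContentType.

    Expects a dict shaped like the boto3 describe_hyper_parameter_tuning_job
    response; handles both the single TrainingJobDefinition key and the
    TrainingJobDefinitions list.
    """
    definitions = []
    if "TrainingJobDefinition" in describe_response:
        definitions.append(describe_response["TrainingJobDefinition"])
    definitions.extend(describe_response.get("TrainingJobDefinitions", []))

    missing = []
    for definition in definitions:
        for channel in definition.get("InputDataConfig", []):
            # ContentType is optional on a channel; if it is absent,
            # the algorithm container gets no hint about the data format.
            if not channel.get("ContentType"):
                missing.append(channel["ChannelName"])
    return missing


# Synthetic response fragment mimicking what the bug produces.
example = {
    "TrainingJobDefinition": {
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "DataSource": {"S3DataSource": {"S3Uri": "s3://my-bucket/train/train.csv"}},
            },
        ]
    }
}
print(channels_missing_content_type(example))  # ['train']
```

Against a real job, pass it the result of boto3.client("sagemaker").describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=<your tuning job name>). If the bug is present, the "train" channel would be expected to appear in the returned list.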
Root Cause
In sagemaker/train/tuner.py, the _create_hyperparameter_tuning_job method converts InputData → Channel without passing content_type:
```python
# tuner.py lines 1362-1373
if isinstance(inp, InputData):
    input_data_config.append(Channel(
        channel_name=inp.channel_name,
        data_source=DataSource(
            s3_data_source=S3DataSource(
                s3_data_type="S3Prefix",
                s3_uri=inp.data_source,
                s3_data_distribution_type="FullyReplicated"
            )
        )
        # content_type is missing here!
    ))
```

Suggested Fix
```python
if isinstance(inp, InputData):
    input_data_config.append(Channel(
        channel_name=inp.channel_name,
        content_type=inp.content_type,  # <-- add this
        data_source=DataSource(
            s3_data_source=S3DataSource(
                s3_data_type="S3Prefix",
                s3_uri=inp.data_source,
                s3_data_distribution_type="FullyReplicated"
            )
        )
    ))
```

Workaround
Pass Channel objects directly instead of InputData:
```python
from sagemaker.core.shapes import Channel, DataSource, S3DataSource

train_input = Channel(
    channel_name="train",
    content_type="csv",
    data_source=DataSource(
        s3_data_source=S3DataSource(
            s3_data_type="S3Prefix",
            s3_uri="s3://my-bucket/train/train.csv",
            s3_data_distribution_type="FullyReplicated",
        )
    ),
)

tuner.tune(inputs=[train_input])  # works correctly
```

Environment
SageMaker Python SDK version: 3.0.1
Python version: 3.12
Built-in algorithm: XGBoost 1.7-1