Skip to content
Merged
21 changes: 15 additions & 6 deletions sagemaker-core/src/sagemaker/core/workflow/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from _hashlib import HASH as Hash
except ImportError:
import typing

Hash = typing.Any

from sagemaker.core.common_utils import base_from_name
Expand Down Expand Up @@ -227,7 +228,9 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
return None


def get_training_code_hash(entry_point: str, source_dir: str, dependencies: List[str]) -> str:
def get_training_code_hash(
entry_point: str, source_dir: str, dependencies: Optional[str] = None
) -> str:
"""Get the hash of a training step's code artifact(s).

Args:
Expand All @@ -236,9 +239,9 @@ def get_training_code_hash(entry_point: str, source_dir: str, dependencies: List
training
source_dir (str): Path to a directory with any other training source
code dependencies aside from the entry point file
dependencies (str): A list of paths to directories (absolute
or relative) with any additional libraries that will be exported
to the container
dependencies Optional[str]: The relative path within ``source_dir`` to a
``requirements.txt`` file with any additional libraries that
will be exported to the container
Returns:
str: A hash string representing the unique code artifact(s) for the step
"""
Expand All @@ -248,11 +251,17 @@ def get_training_code_hash(entry_point: str, source_dir: str, dependencies: List
if source_dir:
source_dir_url = urlparse(source_dir)
if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
return hash_files_or_dirs([source_dir] + dependencies)
if dependencies:
return hash_files_or_dirs([source_dir] + [dependencies])
else:
return hash_files_or_dirs([source_dir])
elif entry_point:
entry_point_url = urlparse(entry_point)
if entry_point_url.scheme == "" or entry_point_url.scheme == "file":
return hash_files_or_dirs([entry_point] + dependencies)
if dependencies:
return hash_files_or_dirs([entry_point] + [dependencies])
else:
return hash_files_or_dirs([entry_point])
return None


Expand Down
43 changes: 30 additions & 13 deletions sagemaker-core/tests/unit/workflow/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,27 +269,44 @@ def test_get_training_code_hash_with_source_dir(self):
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")
requirements_file = Path(temp_dir, "requirements.txt")
requirements_file.write_text("numpy==1.21.0")

result = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir, dependencies=[]
result_no_deps = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir, dependencies=None
)
result_with_deps = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir, dependencies=str(requirements_file)
)

assert result is not None
assert len(result) == 64
assert result_no_deps is not None
assert result_with_deps is not None
assert len(result_no_deps) == 64
assert len(result_with_deps) == 64
assert result_no_deps != result_with_deps

def test_get_training_code_hash_entry_point_only(self):
"""Test get_training_code_hash with entry_point only"""
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
f.write("print('training')")
temp_file = f.name
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")
requirements_file = Path(temp_dir, "requirements.txt")
requirements_file.write_text("numpy==1.21.0")

try:
result = get_training_code_hash(entry_point=temp_file, source_dir=None, dependencies=[])
# Without dependencies
result_no_deps = get_training_code_hash(
entry_point=str(entry_file), source_dir=None, dependencies=None
)
# With dependencies
result_with_deps = get_training_code_hash(
entry_point=str(entry_file), source_dir=None, dependencies=str(requirements_file)
)

assert result is not None
assert len(result) == 64
finally:
os.unlink(temp_file)
assert result_no_deps is not None
assert result_with_deps is not None
assert len(result_no_deps) == 64
assert len(result_with_deps) == 64
assert result_no_deps != result_with_deps

def test_get_training_code_hash_s3_uri(self):
"""Test get_training_code_hash with S3 URI returns None"""
Expand Down
Loading