Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [Unreleased]

- Added support for loading arbitrary YAMLs given a dataclass definition
- Added custom Omegaconf resolver for `${git_hash:}`
- Added Omegaconf resolvers when parsing yamls; added custom resolver for `${now:}`
- Support `URI / URI_LIKE`, similar to `pathlib.Path`.

Expand Down
28 changes: 28 additions & 0 deletions docs/user_guide/config_guides/task_config_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,31 @@ tomorrow_plus_5_hours_30_min_15_sec: "2023-12-16 20:00:37"
next_week: "2023-12-22"
multiple_args: "20231130:20231214"
```

### Git Hash Resolver

This resolver returns the current git hash if one is available. It takes no arguments and returns the git hash as a string.
Specifically, it returns the SHA printed when the following command is run in the active working directory:

```bash
git rev-parse HEAD
```

If no git repository is found, or an error occurs, the resolver returns an empty string.

Examples:

```yaml
experiment:
  commit: "${git_hash:}"
  model_version: "model_${git_hash:}"
```

Assuming you are scheduling workflows from an active git repo whose current commit hash is
`9d42b423b65961692ffc650a0714a63a1b695b12`, this would resolve to:

```yaml
experiment:
  commit: "9d42b423b65961692ffc650a0714a63a1b695b12"
  model_version: "model_9d42b423b65961692ffc650a0714a63a1b695b12"
```
49 changes: 47 additions & 2 deletions python/gigl/common/omegaconf_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
files to provide dynamic values during configuration loading.
"""

import subprocess
from datetime import datetime, timedelta

from omegaconf import OmegaConf
Expand Down Expand Up @@ -120,16 +121,60 @@ def now_resolver(*args: str) -> str:
return target_time.strftime(format_str)


def git_hash_resolver() -> str:
    """Resolver that returns the current git commit hash.

    Runs ``git rev-parse HEAD`` in the current working directory and returns
    its output with surrounding whitespace stripped. Takes no arguments.

    Returns:
        Current git hash as a string, or an empty string if the git command
        exits with a non-zero status, times out, or cannot be launched.

    Example:
        In YAML config (the trailing colon is required: it makes OmegaConf
        call the resolver instead of looking up a config key named
        ``git_hash``):
        ```yaml
        model_version: "model_${git_hash:}"
        experiment_id: "exp_${git_hash:}_${now:%Y%m%d}"
        ```
    """
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
            timeout=5,
        )
    except (
        subprocess.CalledProcessError,
        subprocess.TimeoutExpired,
        # OSError covers FileNotFoundError (git not installed) as well as
        # other launch failures such as PermissionError, so that every
        # failure mode degrades to the documented empty-string result.
        OSError,
    ):
        logger.info(
            "Could not retrieve git hash - git command failed or not in a git repository"
        )
        return ""
    return result.stdout.strip()


def register_resolvers() -> None:
    """Register all custom OmegaConf resolvers.

    This function should be called at application startup to register
    all custom resolvers with OmegaConf. It is safe to call more than once:
    a resolver whose name is already registered is left untouched.
    """
    logger.info("Registering OmegaConf resolvers")

    # Resolver name as used in YAML (``${name:...}``) -> implementing function.
    resolvers = {
        "now": now_resolver,
        "git_hash": git_hash_resolver,
    }
    for name, resolver in resolvers.items():
        if OmegaConf.has_resolver(name):
            logger.debug(
                f"OmegaConf resolver '{name}' already registered, skipping registration"
            )
            continue
        logger.info(f"Registering OmegaConf resolver '{name}'")
        OmegaConf.register_new_resolver(name, resolver)
27 changes: 27 additions & 0 deletions python/gigl/common/utils/yaml_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Type, TypeVar, cast

from omegaconf import OmegaConf

from gigl.common import Uri
from gigl.common.logger import Logger
from gigl.common.omegaconf_resolvers import register_resolvers
from gigl.src.common.utils.file_loader import FileLoader

# Module-level logger instance.
logger = Logger()

# Type variable tying the ``type_of_object`` argument of ``load_resolved_yaml``
# to its return type.
T = TypeVar("T")

# Runs at import time so the custom OmegaConf resolvers are registered before
# any YAML is loaded through this module.
register_resolvers()
Comment thread
svij-sc marked this conversation as resolved.


def load_resolved_yaml(uri: Uri, type_of_object: Type[T]) -> T:
    """Load the YAML file at ``uri`` into an instance of ``type_of_object``.

    The file is first copied to a local temporary file via ``FileLoader`` and
    parsed with OmegaConf. The parsed data is then merged on top of a
    structured config built from ``type_of_object`` and converted into an
    object of that type.

    Args:
        uri: Location of the YAML file to load.
        type_of_object: Class passed to ``OmegaConf.structured`` to build the
            schema the YAML contents are merged into.

    Returns:
        The result of ``OmegaConf.to_object`` on the merged config, cast to
        ``type_of_object``.
    """
    with FileLoader().load_to_temp_file(uri) as local_copy:
        yaml_data = OmegaConf.load(local_copy.name)

    # Layer the loaded values over the structured schema so OmegaConf
    # validates them against the declared fields.
    schema = OmegaConf.structured(type_of_object)
    validated_config = OmegaConf.merge(schema, yaml_data)

    # Materialize the config as a strongly typed T object.
    typed_object = OmegaConf.to_object(validated_config)
    return cast(T, typed_object)
Comment thread
svij-sc marked this conversation as resolved.
55 changes: 54 additions & 1 deletion python/tests/unit/common/test_omegaconf_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
import unittest
from datetime import datetime
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import yaml
from omegaconf import OmegaConf
Expand Down Expand Up @@ -71,6 +72,58 @@ def test_now_resolver_with_invalid_format(self):
with self.assertRaises(ValueError):
OmegaConf.create(yaml_config).experiment.name

@patch("gigl.common.omegaconf_resolvers.subprocess.run")
def test_git_hash_resolver_success(self, mock_subprocess_run):
    """``${git_hash:}`` resolves to the stdout of the mocked git command."""
    # Mock successful git command
    mocked_hash = "abc123def456789012345678901234567890abcd"
    mock_subprocess_run.return_value = MagicMock(stdout=mocked_hash)

    yaml_config = """
    experiment:
        commit: "${git_hash:}"
        model_version: "model_${git_hash:}"
    """

    config = OmegaConf.create(yaml.safe_load(yaml_config))

    self.assertEqual(config.experiment.commit, mocked_hash)
    self.assertEqual(config.experiment.model_version, f"model_{mocked_hash}")

    # Verify subprocess was called with the correct git command. Only the
    # command itself (first positional argument) is checked so the test does
    # not depend on the exact keyword arguments passed to subprocess.run.
    mock_subprocess_run.assert_called()
    self.assertEqual(
        mock_subprocess_run.call_args[0][0], ["git", "rev-parse", "HEAD"]
    )

@patch("gigl.common.omegaconf_resolvers.subprocess.run")
def test_git_hash_resolver_command_not_found(self, mock_subprocess_run):
    """A missing ``git`` executable makes the resolver return an empty string."""
    # When subprocess.run is given an argument list (no shell), a missing
    # executable surfaces as FileNotFoundError rather than as a
    # CalledProcessError with exit code 127 (127 is what a shell reports).
    mock_subprocess_run.side_effect = FileNotFoundError(
        2, "No such file or directory", "git"
    )
    yaml_config = """
    experiment:
        commit: "${git_hash:}"
    """

    # Should return empty string when git is not available
    self.assertEqual(OmegaConf.create(yaml_config).experiment.commit, "")

@patch("gigl.common.omegaconf_resolvers.subprocess.run")
def test_git_hash_resolver_not_git_repo(self, mock_subprocess_run):
    """Outside a git repository the resolver yields an empty string."""
    # `git rev-parse HEAD` in a non-git directory fails with exit code 128,
    # which reaches the resolver as a CalledProcessError.
    not_a_repo_error = subprocess.CalledProcessError(
        128,
        ["git", "rev-parse", "HEAD"],
        "fatal: not a git repository (or any of the parent directories): .git",
    )
    mock_subprocess_run.side_effect = not_a_repo_error

    yaml_config = """
    experiment:
        commit: "${git_hash:}"
    """
    resolved_commit = OmegaConf.create(yaml_config).experiment.commit

    self.assertEqual(resolved_commit, "")


if __name__ == "__main__":
unittest.main()
89 changes: 89 additions & 0 deletions python/tests/unit/common/utils/yaml_loader_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import tempfile
import textwrap
import unittest
from dataclasses import dataclass, field
from datetime import datetime
from typing import List
from unittest.mock import MagicMock, patch

from gigl.common import LocalUri
from gigl.common.utils.yaml_loader import load_resolved_yaml


@dataclass
class _SubConfig:
    """Nested config section used as a schema by the YAML loader test."""

    # Required fields: must be supplied by the YAML being loaded.
    name: str
    value: int
    # Optional fields with defaults.
    enabled: bool = True
    # default_factory gives each instance its own empty list.
    tags: List[str] = field(default_factory=list)
Comment thread
svij-sc marked this conversation as resolved.


@dataclass
class _Complex_TestConfig:
    """Top-level test schema: a nested ``_SubConfig`` plus a plain string field."""

    # Nested dataclass section.
    basic_config: _SubConfig
    description: str


class YamlLoaderTest(unittest.TestCase):
    """Tests for ``load_resolved_yaml``."""

    def setUp(self):
        """Create a named temporary YAML file for the test to write into."""
        super().setUp()
        # delete=False: the file is reopened by name by the loader after it
        # has been closed here, so it is removed explicitly in tearDown.
        self.temp_file = tempfile.NamedTemporaryFile(
            mode="w", suffix=".yaml", delete=False
        )

    def tearDown(self):
        """Close the temporary file (if still open) and delete it."""
        self.temp_file.close()
        os.remove(self.temp_file.name)
        super().tearDown()

    def test_load_resolved_yaml_simple_config(self):
        """Load a nested YAML config whose values use resolvers and interpolation."""

        contents = textwrap.dedent(
            """
            basic_config:
                name: "experiment_${now:%Y%m%d}"
                value: 42
                enabled: true
                tags:
                    - "tag_${git_hash:}"
                    - "${basic_config.value}" # resolves to 42
            description: "This is a test description"
            """
        )
        # Leaving the ``with`` block closes the file, flushing the contents
        # to disk before the loader reads it.
        with self.temp_file:
            self.temp_file.write(contents)

        patch_commit_hash = "1234567890"
        patch_datetime = datetime(2023, 12, 15, 14, 30, 22)

        expected_result = _Complex_TestConfig(
            basic_config=_SubConfig(
                name=f"experiment_{patch_datetime.strftime('%Y%m%d')}",
                value=42,
                enabled=True,
                tags=[f"tag_{patch_commit_hash}", "42"],
            ),
            description="This is a test description",
        )
        # Pin both resolvers to fixed values so the expected result is
        # deterministic.
        with patch(
            "gigl.common.omegaconf_resolvers.subprocess.run"
        ) as mock_subprocess_run, patch(
            "gigl.common.omegaconf_resolvers.datetime"
        ) as mock_datetime:
            mock_result = MagicMock()
            mock_result.stdout = patch_commit_hash
            mock_subprocess_run.return_value = mock_result
            mock_datetime.now.return_value = patch_datetime

            uri = LocalUri(self.temp_file.name)
            result: _Complex_TestConfig = load_resolved_yaml(uri, _Complex_TestConfig)
            self.assertIsInstance(result, _Complex_TestConfig)
            self.assertEqual(result, expected_result)


if __name__ == "__main__":
unittest.main()