Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend list function #2950

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_helpers/build_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def run(self) -> None: # type: ignore

class SDistCommand(sdist.sdist):
def run(self) -> None:
if not self.dry_run: # type: ignore
if not self.dry_run:
self.run_command("clean")
run_antlr(self)
sdist.sdist.run(self)
Expand Down
9 changes: 9 additions & 0 deletions hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,15 @@ def _apply_overrides_to_config(overrides: List[Override], cfg: DictConfig) -> No
)
elif override.is_force_add():
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
elif override.is_list_extend():
config_val = OmegaConf.select(cfg, key, throw_on_missing=True)
if not OmegaConf.is_list(config_val):
raise ConfigCompositionException(
"Could not append to config list. The existing value of"
f" '{override.key_or_group}' is {config_val} which is not"
f" a list."
)
config_val.extend(value)
else:
try:
OmegaConf.update(cfg, key, value, merge=True)
Expand Down
8 changes: 8 additions & 0 deletions hydra/_internal/grammar/grammar_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ChoiceSweep,
Glob,
IntervalSweep,
ListExtensionOverrideValue,
ParsedElementType,
QuotedString,
RangeSweep,
Expand Down Expand Up @@ -399,3 +400,10 @@ def glob(
exclude = [exclude]

return Glob(include=include, exclude=exclude)


def extend_list(*args: Any) -> ListExtensionOverrideValue:
"""
Extends an existing list in the config with the given values.
"""
return ListExtensionOverrideValue(values=list(args))
1 change: 1 addition & 0 deletions hydra/core/override_parser/overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,5 @@ def create_functions() -> Functions:
functions.register(name="sort", func=grammar_functions.sort)
functions.register(name="shuffle", func=grammar_functions.shuffle)
functions.register(name="glob", func=grammar_functions.glob)
functions.register(name="extend_list", func=grammar_functions.extend_list)
return functions
8 changes: 8 additions & 0 deletions hydra/core/override_parser/overrides_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Glob,
IntervalSweep,
Key,
ListExtensionOverrideValue,
Override,
OverrideType,
ParsedElementType,
Expand Down Expand Up @@ -190,6 +191,13 @@ def visitOverride(self, ctx: OverrideParser.OverrideContext) -> Override:
value_type = ValueType.RANGE_SWEEP
else:
value_type = ValueType.ELEMENT
if isinstance(value, ListExtensionOverrideValue):
if not override_type == OverrideType.CHANGE:
raise HydraException(
"Trying to use override symbols when extending a list"
)
override_type = OverrideType.EXTEND_LIST
value = value.values

return Override(
type=override_type,
Expand Down
12 changes: 12 additions & 0 deletions hydra/core/override_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ class OverrideType(Enum):
ADD = 2
FORCE_ADD = 3
DEL = 4
EXTEND_LIST = 5


class ValueType(Enum):
Expand Down Expand Up @@ -227,6 +228,11 @@ def match(s: str, globs: List[str]) -> bool:
return res


@dataclass
class ListExtensionOverrideValue:
values: List["ParsedElementType"]


class Transformer:
@staticmethod
def identity(x: ParsedElementType) -> ParsedElementType:
Expand Down Expand Up @@ -286,6 +292,12 @@ def is_force_add(self) -> bool:
"""
return self.type == OverrideType.FORCE_ADD

def is_list_extend(self) -> bool:
"""
:return: True if this override represents appending to a list config value
"""
return self.type == OverrideType.EXTEND_LIST

@staticmethod
def _convert_value(value: ParsedElementType) -> Optional[ElementType]:
if isinstance(value, list):
Expand Down
1 change: 1 addition & 0 deletions news/1547.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add extend_list function to override syntax
49 changes: 48 additions & 1 deletion tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from hydra.core.config_search_path import SearchPathQuery
from hydra.core.config_store import ConfigStore
from hydra.core.global_hydra import GlobalHydra
from hydra.errors import ConfigCompositionException, HydraException
from hydra.errors import (
ConfigCompositionException,
HydraException,
OverrideParseException,
)
from hydra.test_utils.test_utils import chdir_hydra_root

chdir_hydra_root()
Expand Down Expand Up @@ -429,6 +433,49 @@ class Config:
compose(config_name="config", overrides=overrides)


@mark.usefixtures("initialize_hydra_no_path")
@mark.parametrize(
("overrides", "expected"),
[
param(
["list_key=extend_list(d, e)"],
{"list_key": ["a", "b", "c", "d", "e"]},
id="extend_list_with_str",
),
param(
["list_key=extend_list([d1, d2])"],
{"list_key": ["a", "b", "c", ["d1", "d2"]]},
id="extend_list_with_list",
),
param(
["list_key=extend_list(d, [e1])", "list_key=extend_list(f)"],
{"list_key": ["a", "b", "c", "d", ["e1"], "f"]},
id="extend_list_twice",
),
param(
["+list_key=extend_list([d1, d2])"],
raises(OverrideParseException),
id="extend_list_with_append_key",
),
],
)
def test_extending_list(
hydra_restore_singletons: Any, overrides: List[str], expected: Any
) -> None:
@dataclass
class Config:
list_key: Any = field(default_factory=lambda: ["a", "b", "c"])

ConfigStore.instance().store(name="config", node=Config)

if isinstance(expected, dict):
cfg = compose(config_name="config", overrides=overrides)
assert cfg == expected
else:
with expected:
compose(config_name="config", overrides=overrides)


@mark.parametrize("override", ["hydra.foo=bar", "hydra.job_logging.foo=bar"])
def test_hydra_node_validated(initialize_hydra_no_path: Any, override: str) -> None:
with raises(ConfigCompositionException):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_hydra_cli_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@
ValueError while evaluating 'choice()': empty choice is not legal""",
id="empty choice",
),
param(
"+key=extend_list(1, 2, 3)",
"""Error parsing override '+key=extend_list(1, 2, 3)'
Trying to use override symbols when extending a list""",
id="plus key extend_list",
),
param(
"key={inner_key=extend_list(1, 2, 3)}",
"no viable alternative at input '{inner_key='",
id="embedded extend_list",
),
param(
["+key=choice(choice(a,b))", "-m"],
"""Error parsing override '+key=choice(choice(a,b))'
Expand Down
57 changes: 57 additions & 0 deletions tests/test_overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Glob,
IntervalSweep,
Key,
ListExtensionOverrideValue,
Override,
OverrideType,
Quote,
Expand Down Expand Up @@ -174,6 +175,21 @@ def test_element(value: str, expected: Any) -> None:
ChoiceSweep(list=[1, 2, 3], shuffle=True),
id="shuffle(choice(1,2,3))",
),
param(
"extend_list(1,2,three)",
ListExtensionOverrideValue(values=[1, 2, "three"]),
id="extend_list(1,2,three)",
),
param(
"extend_list('5')",
ListExtensionOverrideValue(values=["5"]),
id="extend_list('5')",
),
param(
"extend_list([1,2,3], {a:1, b:2})",
ListExtensionOverrideValue(values=[[1, 2, 3], {"a": 1, "b": 2}]),
id="extend_list([1,2,3], {a:1, b:2})",
),
],
)
def test_value(value: str, expected: Any) -> None:
Expand Down Expand Up @@ -523,6 +539,15 @@ def test_interval_sweep(value: str, expected: Any) -> None:
raises(HydraException, match=re.escape("mismatched input '/'")),
id="error:dollar_in_group",
),
param(
"override",
"+key=extend_list(foobar)",
raises(
HydraException,
match=re.escape("Trying to use override symbols when extending a list"),
),
id="error:plus_in_extend_list_key",
),
],
)
def test_parse_errors(rule: str, value: str, expected: Any) -> None:
Expand Down Expand Up @@ -997,6 +1022,38 @@ def test_override(
assert ret == expected


@mark.parametrize(
"value,expected_key,expected_value",
[
param(
"key=extend_list([1,2])",
"key",
[[1, 2]],
id="extend_list_of_list",
),
param(
"key=extend_list(1,2,3)",
"key",
[1, 2, 3],
id="extend_list_with_multiple_vals",
),
],
)
def test_list_extend_override(
value: str,
expected_key: str,
expected_value: Any,
) -> None:
test_override(
"",
value,
OverrideType.EXTEND_LIST,
expected_key,
expected_value,
ValueType.ELEMENT,
)


def test_deprecated_name_package(hydra_restore_singletons: Any) -> None:
msg = (
"In override key@_name_=value: _name_ keyword is deprecated in packages, "
Expand Down
2 changes: 2 additions & 0 deletions website/docs/advanced/override_grammar/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ foo=[1,2,3]
nested=[a,[b,[c]]]
```

Lists are assigned, not merged. To extend an existing list, use the [`extend_list` function](extended.md#extending-lists).

### Dictionaries
```python
foo={a:10,b:20}
Expand Down
Loading
Loading