Skip to content

Commit

Permalink
Refactor for pydantic v2 and better typing (#165)
Browse files Browse the repository at this point in the history
* Refactor for pydantic v2 and better type hints

Some more code modification to replace some deprecated methods in
pydantic v2. I have also improved typing in the code. Some bugs were
found along the way, including the `match` method of the SpinSystem
class.

* Bump package versions
  • Loading branch information
gbouvignies authored Jul 11, 2023
1 parent 706db72 commit 1f55753
Show file tree
Hide file tree
Showing 78 changed files with 599 additions and 681 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,4 @@ Sandbox/

# PDM
.pdm-python
.pdm-build
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.275
rev: v0.0.277
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/psf/black
rev: 23.3.0 # Replace by any tag/version: https://github.com/psf/black/tags
rev: 23.7.0 # Replace by any tag/version: https://github.com/psf/black/tags
hooks:
- id: black
language_version: python3.10
24 changes: 16 additions & 8 deletions chemex/configuration/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import Any

from pydantic import BaseModel, ConfigDict, model_validator

if TYPE_CHECKING:
from collections.abc import MutableMapping

def ensure_list(variable: Any | list[Any] | None) -> list[Any]:
if isinstance(variable, list):
return variable
if variable is None:
return []
return [variable]


def to_lower(string: Any) -> Any:
if isinstance(string, str):
return string.lower()
return string


class BaseModelLowerCase(BaseModel):
model_config = ConfigDict(str_to_lower=True)

@model_validator(mode="before")
@classmethod
def to_lower_case(
cls, values: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
return {k.lower(): v for k, v in values.items()}
def key_to_lower(cls, model: dict[str, Any]) -> dict[str, Any]:
return {to_lower(k): v for k, v in model.items()}
51 changes: 23 additions & 28 deletions chemex/configuration/conditions.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
from __future__ import annotations

from functools import total_ordering
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal

from pydantic import (
BaseModel,
BeforeValidator,
Field,
PositiveFloat,
ValidationError,
field_validator,
model_validator,
)

from chemex.configuration.base import to_lower
from chemex.models.model import model

if TYPE_CHECKING:
from collections.abc import Hashable, MutableMapping
from collections.abc import Hashable

LabelType = Annotated[Literal["1h", "2h", "13c", "15n"], BeforeValidator(to_lower)]


@total_ordering
class Conditions(BaseModel, frozen=True):
h_larmor_frq: PositiveFloat | None = None
temperature: float | None = None
p_total: PositiveFloat | None = Field(default=None)
l_total: PositiveFloat | None = Field(default=None)
p_total: PositiveFloat | None = None
l_total: PositiveFloat | None = None
d2o: float | None = Field(gt=0.0, lt=1.0, default=None)
label: tuple[Literal["1h", "2h", "13c", "15n"], ...] = ()
label: tuple[LabelType, ...] = ()

def rounded(self) -> Conditions:
h_larmor_frq = round(self.h_larmor_frq, 1) if self.h_larmor_frq else None
Expand All @@ -49,7 +53,7 @@ def search_keys(self) -> set[Hashable]:

@property
def section(self) -> str:
parts = []
parts: list[str] = []
if self.temperature is not None:
parts.append(f"T->{self.temperature:.1f}C")
if self.h_larmor_frq is not None:
Expand All @@ -64,7 +68,7 @@ def section(self) -> str:

@property
def folder(self):
parts = []
parts: list[str] = []
if self.temperature is not None:
parts.append(f"{self.temperature:.1f}C")
if self.h_larmor_frq is not None:
Expand Down Expand Up @@ -112,39 +116,30 @@ def __lt__(self, other: Conditions) -> bool:

@total_ordering
class ConditionsFromFile(Conditions, frozen=True):
@model_validator(mode="before")
def key_to_lower(cls, model: dict[str, Any]) -> dict[str, Any]:
return {to_lower(k): v for k, v in model.items()}

@field_validator("d2o")
@classmethod
def validate_d2o(cls, d2o):
def validate_d2o(cls, d2o: float | None) -> float | None:
if "hd" in model.name and d2o is None:
raise ValidationError()
msg = 'To use the "hd" model, d2o must be provided'
raise ValidationError(msg)
return d2o

@field_validator("temperature")
@classmethod
def validate_temperature(cls, temperature):
def validate_temperature(cls, temperature: float | None) -> float | None:
if "eyring" in model.name and temperature is None:
raise ValidationError()
msg = 'To use the "eyring" model, "temperature" must be provided'
raise ValidationError(msg)
return temperature

@field_validator("label", mode="before")
@classmethod
def set_to_lower_case(cls, label):
return tuple(value.lower() for value in label)

@model_validator(mode="after")
@classmethod
def validate_p_total_l_total(
cls, conditions: ConditionsFromFile
) -> ConditionsFromFile:
are_not_both_set = conditions.p_total is None or conditions.l_total is None
if "binding" in model.name and are_not_both_set:
msg = "Either p_total or l_total must be provided"
raise ValueError(msg)
msg = 'To use the "binding" model, "p_total" and "l_total" must be provided'
raise ValidationError(msg)
return conditions

@model_validator(mode="before")
@classmethod
def set_keys_to_lower_case(
cls, values: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
return {k.lower(): v for k, v in values.items()}
45 changes: 24 additions & 21 deletions chemex/configuration/data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from pathlib import Path
from typing import Literal
from typing import Annotated, Any, Literal

from pydantic import field_validator
from pydantic import BeforeValidator, field_validator

from chemex.configuration.base import BaseModelLowerCase
from chemex.configuration.base import BaseModelLowerCase, ensure_list
from chemex.parameters.spin_system import PydanticSpinSystem


Expand All @@ -14,36 +14,33 @@ class DataSettings(BaseModelLowerCase):
scaled: bool = True


PathList = Annotated[list[Path], BeforeValidator(ensure_list)]


class RelaxationDataSettings(DataSettings):
error: Literal["file", "duplicates"] = "file"
filter_planes: list[int] = []
profiles: dict[PydanticSpinSystem, Path | list[Path]] = {}
profiles: dict[PydanticSpinSystem, PathList] = {}

@field_validator("profiles", mode="before")
@classmethod
def make_list(cls, v):
if isinstance(v, dict):
for key, value in v.items():
if isinstance(value, str):
v[key] = [value]
return v
@field_validator("error", mode="before")
def to_lower(cls, error: Any) -> str:
if isinstance(error, str):
return error.lower()
return error


class CestDataSettings(DataSettings):
error: Literal["file", "duplicates", "scatter"] = "file"
filter_planes: list[int] = []
filter_offsets: list[tuple[float, float]] = [(0.0, 0.0)]
filter_ref_planes: bool = False
profiles: dict[PydanticSpinSystem, list[Path]] = {}
profiles: dict[PydanticSpinSystem, PathList] = {}

@field_validator("profiles", mode="before")
@classmethod
def make_list(cls, v):
if isinstance(v, dict):
for key, value in v.items():
if isinstance(value, str):
v[key] = [value]
return v
@field_validator("error", mode="before")
def to_lower(cls, error: Any) -> str:
if isinstance(error, str):
return error.lower()
return error


class CestDataSettingsNoRef(CestDataSettings):
Expand All @@ -53,3 +50,9 @@ class CestDataSettingsNoRef(CestDataSettings):
class ShiftDataSettings(DataSettings):
error: Literal["file", "duplicates"] = "file"
scaled: bool = False

@field_validator("error", mode="before")
def to_lower(cls, error: Any) -> str:
if isinstance(error, str):
return error.lower()
return error
10 changes: 4 additions & 6 deletions chemex/configuration/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Annotated, Literal

from pydantic import ConfigDict, Field, ValidationError
from pydantic import BeforeValidator, Field, ValidationError
from pydantic.types import PositiveInt

from chemex.configuration.base import BaseModelLowerCase
Expand All @@ -18,7 +18,7 @@


# Type definitions
AllType = Literal["*", "all", "ALL", "All", "ALl", "AlL"]
AllType = Annotated[Literal["*", "all"], BeforeValidator(str.lower)]
SelectionType = list[PydanticSpinSystem] | AllType | None


Expand All @@ -35,8 +35,6 @@ class Selection:


class Method(BaseModelLowerCase):
model_config = ConfigDict(str_to_lower=True)

fitmethod: str = "leastsq"
include: SelectionType = None
exclude: SelectionType = None
Expand All @@ -55,7 +53,7 @@ def selection(self) -> Selection:


def read_methods(filenames: Iterable[Path]) -> Methods:
methods = {}
methods: Methods = {}

for filename in filenames:
methods_dict = read_toml(filename)
Expand Down
71 changes: 30 additions & 41 deletions chemex/configuration/parameters.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated

from annotated_types import Len
from pydantic import RootModel, field_validator
from pydantic import AfterValidator, BeforeValidator, RootModel

from chemex.configuration.base import ensure_list
from chemex.parameters.name import ParamName
from chemex.toml import read_toml_multi

if TYPE_CHECKING:
from collections.abc import Iterable
from pathlib import Path

# Type definitions
ValueListType = Annotated[list[float], Len(min_length=1, max_length=4)]
DefaultType = tuple[ParamName, "DefaultSetting"]
DefaultListType = list[DefaultType]

def rename_section(section_name: str) -> str:
if section_name == "global":
return ""
return f"{section_name},nuc->"


ValuesType = Annotated[list[float], Len(max_length=4), BeforeValidator(ensure_list)]
LowerCaseString = Annotated[str, BeforeValidator(str.lower)]
ValuesDictType = dict[LowerCaseString, ValuesType]
SectionType = Annotated[LowerCaseString, AfterValidator(rename_section)]
ParamsConfigType = dict[SectionType, ValuesDictType]
ParamsConfigModel = RootModel[ParamsConfigType]


@dataclass(frozen=True)
Expand All @@ -27,42 +37,21 @@ class DefaultSetting:
brute_step: float | None = None


class ParamsConfig(RootModel):
root: dict[str, dict[str, ValueListType]]

@field_validator("root", mode="before")
@classmethod
def to_lower(cls, values):
return {
k1.lower(): {k2.lower(): v2 for k2, v2 in v1.items()}
for k1, v1 in values.items()
}

@field_validator("root", mode="before")
@classmethod
def to_list(cls, values):
for values1 in values.values():
for key2, values2 in values1.items():
if not isinstance(values2, Iterable):
values1[key2] = [values2]
return values

@field_validator("root")
@classmethod
def reorder(cls, values):
return {"global": values.pop("global", {}), **values}

def to_defaults_list(self) -> DefaultListType:
defaults_list: DefaultListType = []
for section, settings in self.root.items():
prefix = f"{section}, NUC->" if section != "global" else ""
for key, values in settings.items():
pname = ParamName.from_section(f"{prefix}{key}")
default_values = DefaultSetting(*values)
defaults_list.append((pname, default_values))
return defaults_list
DefaultType = tuple[ParamName, DefaultSetting]
DefaultListType = list[DefaultType]


def build_default_list(params_config: ParamsConfigModel) -> DefaultListType:
defaults: DefaultListType = []
for section, params in params_config.root.items():
for key, values in params.items():
pname = ParamName.from_section(f"{section}{key}")
default_values = DefaultSetting(*values)
defaults.append((pname, default_values))
return defaults


def read_defaults(filenames: Iterable[Path]) -> DefaultListType:
config = read_toml_multi(filenames)
return ParamsConfig.model_validate(config).to_defaults_list()
param_config = ParamsConfigModel.model_validate(config)
return build_default_list(param_config)
Loading

0 comments on commit 1f55753

Please sign in to comment.