Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Bump pydantic to 2.0 #574

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import cbor2
from loguru import logger
from pydantic import BaseModel, root_validator
from pydantic import BaseModel

from cashu.core.json_rpc.base import JSONRPCSubscriptionKinds

Expand Down Expand Up @@ -63,12 +63,13 @@ class ProofState(LedgerEvent):
state: ProofSpentState
witness: Optional[str] = None

@root_validator()
def check_witness(cls, values):
state, witness = values.get("state"), values.get("witness")
if witness is not None and state != ProofSpentState.spent:
raise ValueError('Witness can only be set if the spent state is "SPENT"')
return values
# @model_validator(mode="wrap")
# @classmethod
# def check_witness(cls, values):
# state, witness = values.get("state"), values.get("witness")
# if witness is not None and state != ProofSpentState.spent:
# raise ValueError('Witness can only be set if the spent state is "SPENT"')
# return values

@property
def identifier(self) -> str:
Expand Down Expand Up @@ -132,8 +133,8 @@ class Proof(BaseModel):
reserved: Union[None, bool] = False
# unique ID of send attempt, used for grouping pending tokens in the wallet
send_id: Union[None, str] = ""
time_created: Union[None, str] = ""
time_reserved: Union[None, str] = ""
time_created: Optional[int] = None
time_reserved: Optional[int] = None
derivation_path: Union[None, str] = "" # derivation path of the proof
mint_id: Union[
None, str
Expand Down Expand Up @@ -163,7 +164,7 @@ def to_dict(self, include_dleq=False):
# optional fields
if include_dleq:
assert self.dleq, "DLEQ proof is missing"
return_dict["dleq"] = self.dleq.dict() # type: ignore
return_dict["dleq"] = self.dleq.model_dump() # type: ignore

if self.witness:
return_dict["witness"] = self.witness
Expand Down Expand Up @@ -195,11 +196,6 @@ def htlcpreimage(self) -> Union[str, None]:
return HTLCWitness.from_witness(self.witness).preimage


class Proofs(BaseModel):
# NOTE: not used in Pydantic validation
__root__: List[Proof]


class BlindedMessage(BaseModel):
"""
Blinded message or blinded secret or "output" which is to be signed by the mint
Expand Down Expand Up @@ -806,7 +802,7 @@ def deserialize(cls, tokenv3_serialized: str) -> "TokenV3":
token_base64 += "=" * (4 - len(token_base64) % 4)

token = json.loads(base64.urlsafe_b64decode(token_base64))
return cls.parse_obj(token)
return cls.model_validate(token)

def serialize(self, include_dleq=False) -> str:
"""
Expand Down Expand Up @@ -966,7 +962,7 @@ def from_tokenv3(cls, tokenv3: TokenV3):
return cls(t=cls.t, d=cls.d, m=cls.m, u=cls.u)

def serialize_to_dict(self, include_dleq=False):
return_dict: Dict[str, Any] = dict(t=[t.dict() for t in self.t])
return_dict: Dict[str, Any] = dict(t=[t.model_dump() for t in self.t])
# strip dleq if needed
if not include_dleq:
for token in return_dict["t"]:
Expand Down Expand Up @@ -1013,7 +1009,7 @@ def deserialize(cls, tokenv4_serialized: str) -> "TokenV4":
token_base64 += "=" * (4 - len(token_base64) % 4)

token = cbor2.loads(base64.urlsafe_b64decode(token_base64))
return cls.parse_obj(token)
return cls.model_validate(token)

def to_tokenv3(self) -> TokenV3:
tokenv3 = TokenV3()
Expand Down
2 changes: 1 addition & 1 deletion cashu/core/htlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class HTLCSecret(Secret):
@classmethod
def from_secret(cls, secret: Secret):
assert SecretKind(secret.kind) == SecretKind.HTLC, "Secret is not a HTLC secret"
# NOTE: exclude tags in .dict() because it doesn't deserialize it properly
# NOTE: exclude tags in .model_dump() because it doesn't deserialize it properly
# need to add it back in manually with tags=secret.tags
return cls(**secret.dict(exclude={"tags"}), tags=secret.tags)

Expand Down
56 changes: 28 additions & 28 deletions cashu/core/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, RootModel

from .base import (
BlindedMessage,
Expand Down Expand Up @@ -81,8 +81,8 @@ class KeysetsResponse(BaseModel):
keysets: list[KeysetsResponseKeyset]


class KeysResponse_deprecated(BaseModel):
__root__: Dict[str, str]
class KeysResponse_deprecated(RootModel[Dict[str, str]]):
root: Dict[str, str] = {}


class KeysetsResponse_deprecated(BaseModel):
Expand All @@ -102,16 +102,16 @@ class PostMintQuoteResponse(BaseModel):
request: str # input payment request
paid: Optional[
bool
] # whether the request has been paid # DEPRECATED as per NUT PR #141
] = None # whether the request has been paid # DEPRECATED as per NUT PR #141
state: str # state of the quote
expiry: Optional[int] # expiry of the quote
expiry: Optional[int] = None # expiry of the quote

@classmethod
def from_mint_quote(self, mint_quote: MintQuote) -> "PostMintQuoteResponse":
to_dict = mint_quote.dict()
to_dict = mint_quote.model_dump()
# turn state into string
to_dict["state"] = mint_quote.state.value
return PostMintQuoteResponse.parse_obj(to_dict)
return PostMintQuoteResponse.model_validate(to_dict)


# ------- API: MINT -------
Expand All @@ -120,7 +120,7 @@ def from_mint_quote(self, mint_quote: MintQuote) -> "PostMintQuoteResponse":
class PostMintRequest(BaseModel):
quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
..., max_length=settings.mint_max_request_length
)


Expand All @@ -135,7 +135,7 @@ class GetMintResponse_deprecated(BaseModel):

class PostMintRequest_deprecated(BaseModel):
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
..., max_length=settings.mint_max_request_length
)


Expand All @@ -151,7 +151,7 @@ class PostMeltRequestOptionMpp(BaseModel):


class PostMeltRequestOptions(BaseModel):
mpp: Optional[PostMeltRequestOptionMpp]
mpp: Optional[PostMeltRequestOptionMpp] = None


class PostMeltQuoteRequest(BaseModel):
Expand Down Expand Up @@ -182,50 +182,50 @@ class PostMeltQuoteResponse(BaseModel):
fee_reserve: int # input fee reserve
paid: bool # whether the request has been paid # DEPRECATED as per NUT PR #136
state: str # state of the quote
expiry: Optional[int] # expiry of the quote
expiry: Optional[int] = None # expiry of the quote
payment_preimage: Optional[str] = None # payment preimage
change: Union[List[BlindedSignature], None] = None

@classmethod
def from_melt_quote(self, melt_quote: MeltQuote) -> "PostMeltQuoteResponse":
to_dict = melt_quote.dict()
to_dict = melt_quote.model_dump()
# turn state into string
to_dict["state"] = melt_quote.state.value
return PostMeltQuoteResponse.parse_obj(to_dict)
return PostMeltQuoteResponse.model_validate(to_dict)


# ------- API: MELT -------


class PostMeltRequest(BaseModel):
quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id
inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
inputs: List[Proof] = Field(..., max_length=settings.mint_max_request_length)
outputs: Union[List[BlindedMessage], None] = Field(
None, max_items=settings.mint_max_request_length
None, max_length=settings.mint_max_request_length
)


class PostMeltResponse_deprecated(BaseModel):
paid: Union[bool, None]
preimage: Union[str, None]
paid: Union[bool, None] = None
preimage: Union[str, None] = None
change: Union[List[BlindedSignature], None] = None


class PostMeltRequest_deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
proofs: List[Proof] = Field(..., max_length=settings.mint_max_request_length)
pr: str = Field(..., max_length=settings.mint_max_request_length)
outputs: Union[List[BlindedMessage_Deprecated], None] = Field(
None, max_items=settings.mint_max_request_length
None, max_length=settings.mint_max_request_length
)


# ------- API: SPLIT -------


class PostSplitRequest(BaseModel):
inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
inputs: List[Proof] = Field(..., max_length=settings.mint_max_request_length)
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
..., max_length=settings.mint_max_request_length
)


Expand All @@ -235,10 +235,10 @@ class PostSplitResponse(BaseModel):

# deprecated since 0.13.0
class PostSplitRequest_Deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
proofs: List[Proof] = Field(..., max_length=settings.mint_max_request_length)
amount: Optional[int] = None
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
..., max_length=settings.mint_max_request_length
)


Expand All @@ -256,15 +256,15 @@ class PostSplitResponse_Very_Deprecated(BaseModel):


class PostCheckStateRequest(BaseModel):
Ys: List[str] = Field(..., max_items=settings.mint_max_request_length)
Ys: List[str] = Field(..., max_length=settings.mint_max_request_length)


class PostCheckStateResponse(BaseModel):
states: List[ProofState] = []


class CheckSpendableRequest_deprecated(BaseModel):
proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length)
proofs: List[Proof] = Field(..., max_length=settings.mint_max_request_length)


class CheckSpendableResponse_deprecated(BaseModel):
Expand All @@ -277,21 +277,21 @@ class CheckFeesRequest_deprecated(BaseModel):


class CheckFeesResponse_deprecated(BaseModel):
fee: Union[int, None]
fee: Union[int, None] = None


# ------- API: RESTORE -------


class PostRestoreRequest(BaseModel):
outputs: List[BlindedMessage] = Field(
..., max_items=settings.mint_max_request_length
..., max_length=settings.mint_max_request_length
)


class PostRestoreRequest_Deprecated(BaseModel):
outputs: List[BlindedMessage_Deprecated] = Field(
..., max_items=settings.mint_max_request_length
..., max_length=settings.mint_max_request_length
)


Expand Down
4 changes: 2 additions & 2 deletions cashu/core/p2pk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class P2PKSecret(Secret):
@classmethod
def from_secret(cls, secret: Secret):
assert SecretKind(secret.kind) == SecretKind.P2PK, "Secret is not a P2PK secret"
# NOTE: exclude tags in .dict() because it doesn't deserialize it properly
# NOTE: exclude tags in .model_dump() because it doesn't deserialize it properly
# need to add it back in manually with tags=secret.tags
return cls(**secret.dict(exclude={"tags"}), tags=secret.tags)
return cls(**secret.model_dump(exclude={"tags"}), tags=secret.tags)

def get_p2pk_pubkey_from_secret(self) -> List[str]:
"""Gets the P2PK pubkey from a Secret depending on the locktime.
Expand Down
30 changes: 15 additions & 15 deletions cashu/core/secret.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Union

from loguru import logger
from pydantic import BaseModel
from pydantic import BaseModel, RootModel

from .crypto.secp import PrivateKey

Expand All @@ -13,39 +13,39 @@ class SecretKind(Enum):
HTLC = "HTLC"


class Tags(BaseModel):
class Tags(RootModel[List[List[str]]]):
"""
Tags are used to encode additional information in the Secret of a Proof.
"""

__root__: List[List[str]] = []
root: List[List[str]] = []

def __init__(self, tags: Optional[List[List[str]]] = None, **kwargs):
super().__init__(**kwargs)
self.__root__ = tags or []
self.root = tags or []

def __setitem__(self, key: str, value: Union[str, List[str]]) -> None:
if isinstance(value, str):
self.__root__.append([key, value])
self.root.append([key, value])
elif isinstance(value, list):
self.__root__.append([key, *value])
self.root.append([key, *value])

def __getitem__(self, key: str) -> Union[str, None]:
return self.get_tag(key)

def get_tag(self, tag_name: str) -> Union[str, None]:
for tag in self.__root__:
for tag in self.root:
if tag[0] == tag_name:
return tag[1]
return None

def get_tag_all(self, tag_name: str) -> List[str]:
all_tags = []
for tag in self.__root__:
allroot = []
for tag in self.root:
if tag[0] == tag_name:
for t in tag[1:]:
all_tags.append(t)
return all_tags
allroot.append(t)
return allroot


class Secret(BaseModel):
Expand All @@ -61,9 +61,9 @@ def serialize(self) -> str:
"data": self.data,
"nonce": self.nonce or PrivateKey().serialize()[:32],
}
if self.tags.__root__:
logger.debug(f"Serializing tags: {self.tags.__root__}")
data_dict["tags"] = self.tags.__root__
if self.tags.root:
logger.debug(f"Serializing tags: {self.tags.root}")
data_dict["tags"] = self.tags.root
return json.dumps(
[self.kind, data_dict],
)
Expand All @@ -73,7 +73,7 @@ def deserialize(cls, from_proof: str):
kind, kwargs = json.loads(from_proof)
data = kwargs.pop("data")
nonce = kwargs.pop("nonce")
tags_list: List = kwargs.pop("tags", None)
tags_list: List[List[str]] = kwargs.pop("tags", None) or []
tags = Tags(tags=tags_list)
logger.debug(f"Deserialized Secret: {kind}, {data}, {nonce}, {tags}")
return cls(kind=kind, data=data, nonce=nonce, tags=tags)
Loading
Loading